pip_seq.py
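"""Sequential pipeline experiment.

Trains an ASR/SLT model on Flickr8K, replaces the dataset's transcriptions
with the model's own hypotheses, then trains a text-image retrieval model
on those hypotheses."""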
import logging
import os
import random
from shutil import copyfile
import torch
import platalea.asr as M1
import platalea.dataset as D
import platalea.text_image as M2
from platalea.utils.copy_best import copy_best
from platalea.utils.extract_transcriptions import extract_trn
from platalea.experiments.config import args
# Parsing arguments
args.add_argument(
    '--asr_model_dir',
    help='Path to the directory where the pretrained ASR/SLT model is stored',
    dest='asr_model_dir', type=str, action='store')
args.add_argument(
    '--downsampling_factor_text', default=None, type=float,
    help='Factor by which the amount of available transcriptions should be '
         'downsampled (affecting ASR only)')
args.enable_help()
args.parse()
# Setting general configuration
torch.manual_seed(args.seed)
random.seed(args.seed)
logging.basicConfig(level=logging.INFO)
# Logging the arguments
logging.info('Arguments: {}'.format(args))
batch_size = 8
logging.info('Loading data')
data = dict(
    train=D.flickr8k_loader(
        args.flickr8k_root, args.flickr8k_meta, args.flickr8k_language,
        args.audio_features_fn, split='train', batch_size=batch_size,
        shuffle=True, downsampling_factor=args.downsampling_factor),
    val=D.flickr8k_loader(
        args.flickr8k_root, args.flickr8k_meta, args.flickr8k_language,
        args.audio_features_fn, split='val', batch_size=batch_size,
        shuffle=False))
if args.downsampling_factor_text:
    ds_factor_text = args.downsampling_factor_text
    # The downsampling factor for text is applied on top of the main
    # downsampling factor that is applied to all data
    if args.downsampling_factor:
        ds_factor_text *= args.downsampling_factor
    data_asr = dict(
        train=D.flickr8k_loader(
            args.flickr8k_root, args.flickr8k_meta, args.flickr8k_language,
            args.audio_features_fn, split='train', batch_size=batch_size,
            shuffle=True, downsampling_factor=ds_factor_text),
        val=D.flickr8k_loader(
            args.flickr8k_root, args.flickr8k_meta, args.flickr8k_language,
            args.audio_features_fn, split='val', batch_size=batch_size,
            shuffle=False))
else:
    data_asr = data
if args.asr_model_dir:
    # A pretrained ASR/SLT model was provided; skip training it
    net = torch.load(os.path.join(args.asr_model_dir, 'net.best.pt'))
else:
    logging.info('Building ASR/SLT model')
    config = M1.get_default_config()
    net = M1.SpeechTranscriber(config)
    run_config = dict(max_norm=2.0, max_lr=2 * 1e-4, epochs=32)
    logging.info('Training ASR/SLT')
    if data_asr['train'].dataset.is_slt():
        M1.experiment(net, data_asr, run_config, slt=True)
        copy_best('.', 'result.json', 'asr.best.pt', experiment_type='slt')
    else:
        M1.experiment(net, data_asr, run_config)
        copy_best('.', 'result.json', 'asr.best.pt', experiment_type='asr')
    copyfile('result.json', 'result_asr.json')
    # Reload the best checkpoint so the strongest model produces the
    # transcriptions used downstream
    net = torch.load('asr.best.pt')
logging.info('Extracting ASR/SLT transcriptions')
for set_name in ['train', 'val']:
    ds = data[set_name].dataset
    hyp_asr, ref_asr = extract_trn(net, ds, use_beam_decoding=True)
    # Replacing original transcriptions with ASR/SLT's output
    for i in range(len(hyp_asr)):
        item = ds.split_data[i]
        if item[2] == ref_asr[i]:
            ds.split_data[i] = (item[0], item[1], hyp_asr[i])
        else:
            msg = 'Extracted reference #{} ({}) doesn\'t match dataset\'s ' \
                  'one ({}) for {} set.'
            msg = msg.format(i, ref_asr[i], item[2], set_name)
            logging.warning(msg)
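# The text-image model is trained on the ASR/SLT hypotheses substituted above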
logging.info('Building text-image model')
net = M2.TextImage(M2.get_default_config())
run_config = dict(max_lr=2 * 1e-4, epochs=32)
logging.info('Training text-image')
M2.experiment(net, data, run_config)
copyfile('result.json', 'result_text_image.json')
copy_best('.', 'result_text_image.json', 'ti.best.pt')
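# Expected artifacts: result_asr.json and result_text_image.json (training
# results) plus asr.best.pt and ti.best.pt (best checkpoints kept by copy_best)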