Skip to content

Commit

Permalink
Merge pull request #14 from Open-Speech-EkStep/dev/v2-hydra
Browse files Browse the repository at this point in the history
Merging for Torchscript
  • Loading branch information
harveenchadha authored Sep 4, 2021
2 parents 15d67eb + c438016 commit 16943fa
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 1 deletion.
11 changes: 10 additions & 1 deletion scripts/inference/infer.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ parentdir="$(dirname "$parentdir")"
### Values to change -start ###

w2l_decoder_viterbi=1 # 1 for viterbi, 0 for kenlm
inference_data_name=''
inference_data_name='toy_english'
beam=128 # 128 or 1024
subset='test'

Expand All @@ -15,6 +15,10 @@ lm_name=''
lm_model_path=${parentdir}'/lm/'${lm_name}'/lm.binary'
lexicon_lst_path=${parentdir}'/lm/'${lm_name}'/lexicon.lst'

# SAVE PREDICTED TEXT FILES
dest_folder='/home/ankurdhuriya/abc'
save_predicted=1

# FOR pretrained model
pretrained_model_path='../../checkpoints/pretraining/CLSRIL-23.pt'

Expand Down Expand Up @@ -58,4 +62,9 @@ else

python ../../utils/wer/wer_wav2vec.py -o ${kenlm_result_path}/ref.word-checkpoint_best.pt-test.txt -p ${kenlm_result_path}/hypo.word-checkpoint_best.pt-test.txt \
-t ${data_path}/${subset}.tsv -s save -n ${kenlm_result_path}/sentence_wise_wer.csv -e true

fi

# Optionally export every predicted transcript to its own text file.
# NOTE(review): the kenlm branch above writes sentence_wise_wer.csv to
# ${kenlm_result_path}; confirm ${result_path} is the intended source when
# w2l_decoder_viterbi=0.
if [ "${save_predicted}" = 1 ]; then
    # Quote the expansions so paths containing spaces stay single arguments.
    python ../../utils/inference/save_predicted_output.py -f "${result_path}/sentence_wise_wer.csv" -d "${dest_folder}"
fi
8 changes: 8 additions & 0 deletions scripts/torchscript/convert_hf.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
# Convert a Hugging Face Wav2Vec2 model into a quantized TorchScript model.

# Abort on the first failing command so the success message below is only
# printed when the conversion really finished.
set -e

# Shell assignments must not have spaces around '='; with spaces the shell
# tries to run a command named after the variable and leaves it unset.
# Hugging Face model name, forwarded to convert_hf.py as -i.
input_model_name='test'
# Folder that receives the converted model, forwarded as -o.
output_dir='../../checkpoints/ts/'

mkdir -p "${output_dir}"

# Quote the expansions so values containing spaces stay single arguments.
python ../../utils/torchscript/convert_hf.py -i "${input_model_name}" -o "${output_dir}"

echo "Torchscript Model saved"
32 changes: 32 additions & 0 deletions utils/inference/save_predicted_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pandas as pd
import os
from tqdm import tqdm
import argparse

def save_text_file(path, text):
    """Write ``text`` to the file at ``path`` as UTF-8.

    The file is created when it does not exist and truncated when it does.
    """
    with open(path, mode='w+', encoding='utf-8') as handle:
        handle.write(text)

def save_predicted_output(out_csv, dest):
    """Write every prediction of a sentence-wise WER CSV to its own text file.

    The CSV must provide a ``path`` column (audio file path) and a
    ``predicted`` column (transcript). Each transcript is stored below
    ``dest`` at the same relative location as its audio path, with the file
    extension replaced by ``.txt``.

    Args:
        out_csv: Path of the CSV file to read.
        dest: Root folder for the generated text files; created if missing.
    """
    df = pd.read_csv(out_csv)
    dest = os.path.abspath(dest)
    os.makedirs(dest, exist_ok=True)

    # iterrows() has no length, so pass the total for a real progress bar.
    for _, row in tqdm(df.iterrows(), total=len(df)):
        text = row['predicted']
        # pandas loads an empty cell as NaN (a float), which cannot be
        # written to a text file; store an empty transcript instead.
        if pd.isna(text):
            text = ''

        # Drop surrounding slashes so the audio path is joined below dest;
        # os.path.join would otherwise restart at the absolute path.
        relative_path = str(row['path']).strip('/')
        # Swap only the file extension. Replacing '.wav' across the whole
        # joined path would also rewrite matching text in folder names.
        stem, _ = os.path.splitext(relative_path)
        fpath = os.path.join(dest, stem + '.txt')

        os.makedirs(os.path.dirname(fpath), exist_ok=True)
        # str() guards against cells pandas parsed as numbers.
        save_text_file(fpath, str(text))
    print(f"predicted files created at {dest}")

if __name__ == '__main__':
    # Command-line entry point: read the sentence-wise WER CSV and write one
    # text file per prediction.
    parser = argparse.ArgumentParser(description='Run')
    # Both options are mandatory; making argparse enforce that gives a clear
    # usage error instead of a later failure on a None path.
    parser.add_argument('-f', '--csv-path', type=str, required=True, help="Sentence wer csv path")
    parser.add_argument('-d', '--dest-path', type=str, required=True, help="Path to save predicted output as text files")

    args_local = parser.parse_args()

    save_predicted_output(args_local.csv_path, args_local.dest_path)

76 changes: 76 additions & 0 deletions utils/torchscript/convert_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch
from torch import Tensor
from torch.utils.mobile_optimizer import optimize_for_mobile
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
from transformers import Wav2Vec2ForCTC
import json
import argparse
import os

class SpeechRecognizer(torch.nn.Module):
    """Acoustic model combined with a greedy CTC decoder.

    Args:
        model: Module called as ``model(waveforms)`` that returns a tuple
            whose first element holds the logits.
        vocab: Mapping from token string to its integer index in the logits.
    """

    def __init__(self, model, vocab):
        super().__init__()
        self.model = model
        # Order the tokens by id so that labels[i] is the token for logit
        # index i; the key order of the vocab dict need not match the ids.
        self.labels = [token for token, _ in sorted(vocab.items(), key=lambda item: item[1])]

    def forward(self, waveforms: Tensor) -> str:
        """Given a single channel speech data, return transcription.

        Args:
            waveforms (Tensor): Speech tensor. Shape `[1, num_frames]`.

        Returns:
            str: The resulting transcript
        """
        logits, _ = self.model(waveforms)  # [batch, num_seq, num_label]
        best_path = torch.argmax(logits[0], dim=-1)  # [num_seq,]
        prev = ''
        hypothesis = ''
        for i in best_path:
            char = self.labels[i]
            # Collapse consecutive repeats of the same token.
            if char == prev:
                continue
            # Blank token: emit nothing and reset `prev` so the same
            # character can appear again right after it. '<s>' was the only
            # blank handled originally; '<pad>' is the blank in standard
            # Hugging Face CTC vocabularies and must not leak into the text.
            if char == '<s>' or char == '<pad>':
                prev = ''
                continue
            hypothesis += char
            prev = char
        # '|' is the word delimiter in the vocabulary.
        return hypothesis.replace('|', ' ')

def read_vocab(hf_model_name):
    """Download ``vocab.json`` of a Hugging Face Hub model and return it.

    The file is stored as ``vocab.json`` in the current working directory,
    overwriting any existing copy.

    Args:
        hf_model_name: Model identifier on the Hugging Face Hub.

    Returns:
        The parsed JSON content of the downloaded ``vocab.json``.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with an error.
        FileNotFoundError: If ``wget`` is not installed.
    """
    # Imported locally so the function does not depend on the file header.
    import subprocess

    vocab_url = f'https://huggingface.co/{hf_model_name}/raw/main/vocab.json'
    # Pass an argument list (no shell) so the model name cannot inject shell
    # syntax. '-O' forces the target name: without it wget saves to
    # 'vocab.json.1' when 'vocab.json' already exists and a stale file would
    # be read below. check=True surfaces a failed download immediately.
    subprocess.run(['wget', '-O', 'vocab.json', vocab_url], check=True)
    with open('vocab.json', encoding='utf-8') as file:
        vocab = json.load(file)

    return vocab

def convert_model(hf_model_name, output_dir):
    """Convert a Hugging Face Wav2Vec2 CTC model to a quantized TorchScript file.

    The result is saved as ``<model name>_quant.pt`` in ``output_dir`` and the
    model's ``vocab.json`` is moved next to it.

    Args:
        hf_model_name: Model identifier on the Hugging Face Hub.
        output_dir: Folder for the converted model; created if missing.
    """
    # Imported locally so the function does not depend on the file header.
    import shutil

    # Load Wav2Vec2 pretrained model from Hugging Face Hub
    model = Wav2Vec2ForCTC.from_pretrained(hf_model_name)
    # Convert the model to torchaudio format, which supports TorchScript.
    model = import_huggingface_model(model)
    # Remove weight normalization which is not supported by quantization.
    model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
    model = model.eval()
    # Attach decoder
    model = SpeechRecognizer(model, read_vocab(hf_model_name))

    # Apply quantization / script / optimize for mobile
    quantized_model = torch.quantization.quantize_dynamic(
        model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
    scripted_model = torch.jit.script(quantized_model)
    optimized_model = optimize_for_mobile(scripted_model)

    # Keep only the part after the last '/' of the model identifier.
    quant_model_name = hf_model_name.split('/')[-1] + '_quant.pt'
    os.makedirs(output_dir, exist_ok=True)
    optimized_model.save(os.path.join(output_dir, quant_model_name))
    # read_vocab leaves vocab.json in the working directory. Move it with
    # shutil instead of a shell 'mv' so paths with spaces work and a failure
    # raises instead of being ignored.
    shutil.move('vocab.json', os.path.join(output_dir, 'vocab.json'))


if __name__ == "__main__":
    # Command-line entry point: both the model name and the output folder
    # are mandatory string options.
    cli = argparse.ArgumentParser()
    for flags in (('--hf-model', '-i'), ('--output', '-o')):
        cli.add_argument(*flags, type=str, required=True)
    options = cli.parse_args()

    convert_model(hf_model_name=options.hf_model, output_dir=options.output)

0 comments on commit 16943fa

Please sign in to comment.