File size: 6,566 Bytes
86e5dcc 5aece43 86e5dcc 5aece43 86e5dcc d30c046 86e5dcc d30c046 f3e64ec d30c046 86e5dcc f3e64ec 197a249 b89c1a3 86e5dcc f3e64ec 86e5dcc f3e64ec 86e5dcc f3e64ec 86e5dcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
import torch
import os
import numpy as np
import json
import onnxruntime as ort
from huggingface_hub import snapshot_download
class IndicASRConfig(PretrainedConfig):
    """Configuration for :class:`IndicASRModel`.

    Args:
        ts_folder: Model directory; all artifacts are read from its ``assets``
            sub-folder (TorchScript preprocessor, ONNX components,
            ``vocab.json`` and ``language_masks.json``).
        BLANK_ID: Token id treated as "blank" by the CTC and RNN-T decoders.
        RNNT_MAX_SYMBOLS: Maximum number of decoder steps per encoder frame in
            greedy RNN-T decoding; ``None`` removes the limit.
        PRED_RNN_LAYERS: Number of layers in the RNN-T prediction-network state.
        PRED_RNN_HIDDEN_DIM: Hidden size of the RNN-T prediction-network state.
        SOS: Token id used to seed the RNN-T hypothesis; it is removed from the
            final transcript.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "iasr"

    def __init__(self, ts_folder: str = "path", BLANK_ID: int = 256, RNNT_MAX_SYMBOLS: int = 10,
                 PRED_RNN_LAYERS: int = 2, PRED_RNN_HIDDEN_DIM: int = 640, SOS: int = 256, **kwargs):
        # Initialise the PretrainedConfig machinery (generic HF attributes,
        # serialisation support) before adding the model-specific fields.
        super().__init__(**kwargs)
        self.ts_folder = ts_folder
        # The decoders read every one of these from ``self.config``, so all of
        # them must be stored, not only ``ts_folder`` and ``SOS``.
        self.BLANK_ID = BLANK_ID
        self.RNNT_MAX_SYMBOLS = RNNT_MAX_SYMBOLS
        self.PRED_RNN_LAYERS = PRED_RNN_LAYERS
        self.PRED_RNN_HIDDEN_DIM = PRED_RNN_HIDDEN_DIM
        self.SOS = SOS
class IndicASRModel(PreTrainedModel):
    """Indic speech-recognition model assembled from exported inference artifacts.

    The feature preprocessor is a TorchScript module. Every other component
    (encoder, CTC decoder, RNN-T prediction/joint networks and the per-language
    joint post-nets) is an ONNX graph executed with ONNX Runtime. All artifacts
    are read from ``<config.ts_folder>/assets``. Only one utterance is decoded
    per call (no batching).
    """

    config_class = IndicASRConfig

    def __init__(self, config):
        """Load the model components, the vocabulary and the language masks.

        Args:
            config: An :class:`IndicASRConfig` whose ``ts_folder`` points at the
                model directory.

        A missing ONNX component is reported with a printed message and stored
        as ``None`` instead of raising; a decoding path that needs it will then
        fail when it is used.
        """
        # Must run first: it initialises nn.Module state and sets
        # ``self.config``, which the decoding methods read.
        super().__init__(config)
        self.d = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        providers = (['CUDAExecutionProvider', 'CPUExecutionProvider']
                     if torch.cuda.is_available() else ['CPUExecutionProvider'])
        assets_dir = f'{config.ts_folder}/assets'

        names = ['encoder', 'ctc_decoder', 'rnnt_decoder', 'joint_enc', 'joint_pred', 'joint_pre_net'] + [
            f'joint_post_net_{z}' for z in ['as', 'bn', 'brx', 'doi', 'gu', 'hi', 'kn', 'kok', 'ks', 'mai',
                                            'ml', 'mni', 'mr', 'ne', 'or', 'pa', 'sa', 'sat', 'sd', 'ta',
                                            'te', 'ur']]

        self.models = {}
        self.models['preprocessor'] = torch.jit.load(f'{assets_dir}/preprocessor.ts', map_location=self.d)
        for n in names:
            component_name = f'{assets_dir}/{n}.onnx'
            # Check the component file itself (not just the model folder) and
            # keep the loaded session instead of overwriting it with None.
            if os.path.exists(component_name):
                self.models[n] = ort.InferenceSession(component_name, providers=providers)
            else:
                self.models[n] = None
                print('Failed to load', component_name)

        # The vocabulary contains non-ASCII text (Indic scripts, '▁'), so do not
        # rely on the platform's default encoding.
        with open(f'{assets_dir}/vocab.json', encoding='utf-8') as reader:
            self.vocab = json.load(reader)
        with open(f'{assets_dir}/language_masks.json', encoding='utf-8') as reader:
            self.language_masks = json.load(reader)

    def forward(self, wav, lang, decoding='ctc'):
        """Transcribe a single utterance.

        Args:
            wav: Audio as a torch tensor whose last dimension is the sample axis
                (assumed to hold one utterance; expected sample rate is not
                visible here - TODO confirm).
            lang: Language code used to index ``vocab.json`` and
                ``language_masks.json`` (and to pick the RNN-T joint post-net).
            decoding: ``'ctc'`` or ``'rnnt'``.

        Returns:
            The decoded transcript as a string.

        Raises:
            ValueError: If ``decoding`` is neither ``'ctc'`` nor ``'rnnt'``.
        """
        # Validate before running the (expensive) encoder; previously an
        # unknown value silently returned None.
        if decoding not in ('ctc', 'rnnt'):
            raise ValueError(f"Unsupported decoding {decoding!r}; expected 'ctc' or 'rnnt'")
        encoder_outputs, encoded_lengths = self.encode(wav)
        if decoding == 'ctc':
            return self._ctc_decode(encoder_outputs, encoded_lengths, lang)
        return self._rnnt_decode(encoder_outputs, encoded_lengths, lang)

    def encode(self, wav):
        """Run the preprocessor and the ONNX encoder on ``wav``.

        Returns:
            ``(outputs, encoded_lengths)`` as produced by the ONNX encoder
            session (its ``'outputs'`` and ``'encoded_lengths'`` tensors).
        """
        # Both the audio and its length go to the device the TorchScript
        # preprocessor was loaded on.
        length = torch.tensor([wav.shape[-1]]).to(self.d)
        audio_signal, length = self.models['preprocessor'](wav.to(self.d), length=length)
        # ONNX Runtime consumes NumPy arrays, hence the move back to CPU.
        outputs, encoded_lengths = self.models['encoder'].run(
            ['outputs', 'encoded_lengths'],
            {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
        return outputs, encoded_lengths

    def _ctc_decode(self, encoder_outputs, encoded_lengths, lang):
        """Greedy CTC decoding restricted to ``lang``'s tokens.

        ``encoded_lengths`` is currently unused: every output frame is decoded.
        """
        logprobs = self.models['ctc_decoder'].run(['logprobs'], {'encoder_output': encoder_outputs})[0]
        # Keep only this language's columns, then renormalise over them.
        logprobs = torch.from_numpy(logprobs[:, :, self.language_masks[lang]]).log_softmax(dim=-1)
        # Best token per frame, merge consecutive repeats, drop blanks.
        # Only the first (and only) batch item is decoded.
        indices = torch.argmax(logprobs[0], dim=-1)
        collapsed = torch.unique_consecutive(indices, dim=-1).tolist()
        vocab = self.vocab[lang]
        blank = self.config.BLANK_ID
        return ''.join(vocab[i] for i in collapsed if i != blank).replace('▁', ' ').strip()

    def _rnnt_decode(self, encoder_outputs, encoded_lengths, lang):
        """Greedy RNN-T decoding restricted to ``lang``'s joint post-net.

        For each encoder frame, the decoder is stepped until it predicts blank
        or ``RNNT_MAX_SYMBOLS`` steps have been taken (no limit when ``None``).
        ``encoded_lengths`` is currently unused: every frame is decoded.
        """
        cfg = self.config
        # NOTE(review): assumes encoder output is (B, H, T); it is transposed
        # to (B, T, H) for the joint encoder - confirm against the export.
        joint_enc = self.models['joint_enc'].run(['output'], {'input': encoder_outputs.transpose(0, 2, 1)})[0]
        joint_enc = torch.from_numpy(joint_enc)
        post_net = self.models[f'joint_post_net_{lang}']

        # Hypothesis is seeded with SOS; the prediction-network state starts at zero.
        hyp = [cfg.SOS]
        state_shape = (cfg.PRED_RNN_LAYERS, 1, cfg.PRED_RNN_HIDDEN_DIM)
        prev_dec_state = (np.zeros(state_shape, dtype=np.float32),
                          np.zeros(state_shape, dtype=np.float32))

        for t in range(joint_enc.size(1)):
            f = joint_enc[:, t, :].unsqueeze(1)  # B x 1 x H
            symbols_added = 0
            while cfg.RNNT_MAX_SYMBOLS is None or symbols_added < cfg.RNNT_MAX_SYMBOLS:
                # Prediction network, conditioned on the last emitted token.
                g, _, dec_state_0, dec_state_1 = self.models['rnnt_decoder'].run(
                    ['outputs', 'prednet_lengths', 'states', '162'],
                    {'targets': np.array([[hyp[-1]]], dtype=np.int32),
                     'target_length': np.array([1], dtype=np.int32),
                     'states.1': prev_dec_state[0],
                     'onnx::Slice_3': prev_dec_state[1]})
                # Joint network: combine encoder frame and prediction output.
                g = self.models['joint_pred'].run(['output'], {'input': g.transpose(0, 2, 1)})[0]
                joint_out = f + torch.from_numpy(g)  # B x 1 x H
                joint_out = self.models['joint_pre_net'].run(['output'], {'input': joint_out.numpy()})[0]
                logits = post_net.run(['output'], {'input': joint_out})[0]
                pred_token = torch.from_numpy(logits).log_softmax(dim=-1).argmax(dim=-1).item()
                symbols_added += 1

                if pred_token == cfg.BLANK_ID:
                    # Blank: move on to the next encoder frame, keeping the
                    # decoder state of the last emitted token.
                    break
                # Non-blank: emit the token and advance the decoder state.
                hyp.append(pred_token)
                prev_dec_state = (dec_state_0, dec_state_1)

        return ''.join([self.vocab[lang][x] for x in hyp if x != cfg.SOS]).replace('▁', ' ').strip()

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path,
                        *,
                        force_download=False,
                        resume_download=None,
                        proxies=None,
                        token=None,
                        cache_dir=None,
                        local_files_only=False,
                        revision=None, **kwargs):
        """Build a model from a Hub repository id or a local model directory.

        Args:
            pretrained_model_name_or_path: Hub repo id, or a path to an existing
                local directory (used as-is, without downloading).
            force_download, proxies, token, cache_dir, local_files_only,
            revision: Forwarded to ``huggingface_hub.snapshot_download``.
            resume_download: Accepted for signature compatibility; not used.
            **kwargs: Forwarded to :class:`IndicASRConfig`.

        Returns:
            An :class:`IndicASRModel` loaded from the resolved folder.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            loc = pretrained_model_name_or_path
        else:
            loc = snapshot_download(repo_id=str(pretrained_model_name_or_path),
                                    revision=revision,
                                    token=token,
                                    cache_dir=cache_dir,
                                    force_download=force_download,
                                    local_files_only=local_files_only,
                                    proxies=proxies)
        return cls(IndicASRConfig(ts_folder=loc, **kwargs))
if __name__ == '__main__':
    # Executed only when this file is run as a script: make the custom classes
    # discoverable through the transformers Auto* factories.
    from transformers import AutoConfig, AutoModel

    # First map the "iasr" model-type string to its configuration class...
    AutoConfig.register("iasr", IndicASRConfig)
    # ...then map that configuration class to the model class that consumes it.
    AutoModel.register(IndicASRConfig, IndicASRModel)