Upload folder using huggingface_hub
Browse files- model_onnx.py +9 -4
model_onnx.py
CHANGED
@@ -55,16 +55,21 @@ class IndicASRModel(PreTrainedModel):
|
|
55 |
|
56 |
def encode(self, wav):
|
57 |
# pass through preprocessor
|
58 |
-
audio_signal, length = self.models['preprocessor'](input_signal=wav.to(self.
|
59 |
outputs, encoded_lengths = self.models['encoder'].run(['outputs', 'encoded_lengths'], {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
|
60 |
return outputs, encoded_lengths
|
61 |
|
62 |
def _ctc_decode(self, encoder_outputs, encoded_lengths, lang):
|
63 |
-
logprobs = self.models['ctc_decoder'](encoder_output
|
64 |
-
logprobs = logprobs[
|
|
|
|
|
65 |
indices = torch.argmax(logprobs[0],dim=-1)
|
66 |
collapsed_indices = torch.unique_consecutive(indices, dim=-1)
|
67 |
-
|
|
|
|
|
|
|
68 |
|
69 |
def _rnnt_decode(self, encoder_outputs, encoded_lengths, lang):
|
70 |
joint_enc = self.models['joint_enc'].run(['output'], {'input': encoder_outputs.transpose(0, 2, 1)})[0]
|
|
|
55 |
|
56 |
def encode(self, wav):
|
57 |
# pass through preprocessor
|
58 |
+
audio_signal, length = self.models['preprocessor'](input_signal=wav.to(self.device), length=torch.tensor([wav.shape[-1]]).to(self.device))
|
59 |
outputs, encoded_lengths = self.models['encoder'].run(['outputs', 'encoded_lengths'], {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
|
60 |
return outputs, encoded_lengths
|
61 |
|
62 |
def _ctc_decode(self, encoder_outputs, encoded_lengths, lang):
|
63 |
+
logprobs = self.models['ctc_decoder'].run(['logprobs'], {'encoder_output': encoder_outputs})[0]
|
64 |
+
logprobs = torch.from_numpy(logprobs[:, :, self.language_masks[lang]]).log_softmax(dim=-1)
|
65 |
+
|
66 |
+
# currently no batching
|
67 |
indices = torch.argmax(logprobs[0],dim=-1)
|
68 |
collapsed_indices = torch.unique_consecutive(indices, dim=-1)
|
69 |
+
|
70 |
+
hyp = ''.join([self.vocab[lang][x] for x in collapsed_indices if x != self.BLANK_ID]).replace('▁',' ').strip()
|
71 |
+
del logprobs, indices, collapsed_indices
|
72 |
+
return hyp
|
73 |
|
74 |
def _rnnt_decode(self, encoder_outputs, encoded_lengths, lang):
|
75 |
joint_enc = self.models['joint_enc'].run(['output'], {'input': encoder_outputs.transpose(0, 2, 1)})[0]
|