tahirjm commited on
Commit
d30c046
·
verified ·
1 Parent(s): 33ca12f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. model_onnx.py +9 -4
model_onnx.py CHANGED
@@ -55,16 +55,21 @@ class IndicASRModel(PreTrainedModel):
55
 
56
  def encode(self, wav):
57
  # pass through preprocessor
58
- audio_signal, length = self.models['preprocessor'](input_signal=wav.to(self.config.device), length=torch.tensor([wav.shape[-1]]).to(self.config.device))
59
  outputs, encoded_lengths = self.models['encoder'].run(['outputs', 'encoded_lengths'], {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
60
  return outputs, encoded_lengths
61
 
62
  def _ctc_decode(self, encoder_outputs, encoded_lengths, lang):
63
- logprobs = self.models['ctc_decoder'](encoder_output=encoder_outputs)
64
- logprobs = logprobs[:,:,self.language_masks[lang]].log_softmax(dim=-1)
 
 
65
  indices = torch.argmax(logprobs[0],dim=-1)
66
  collapsed_indices = torch.unique_consecutive(indices, dim=-1)
67
- return ''.join([self.vocab[lang][x] for x in collapsed_indices if x != self.config.BLANK_ID]).replace('▁',' ').strip()
 
 
 
68
 
69
  def _rnnt_decode(self, encoder_outputs, encoded_lengths, lang):
70
  joint_enc = self.models['joint_enc'].run(['output'], {'input': encoder_outputs.transpose(0, 2, 1)})[0]
 
55
 
56
  def encode(self, wav):
57
  # pass through preprocessor
58
+ audio_signal, length = self.models['preprocessor'](input_signal=wav.to(self.device), length=torch.tensor([wav.shape[-1]]).to(self.device))
59
  outputs, encoded_lengths = self.models['encoder'].run(['outputs', 'encoded_lengths'], {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
60
  return outputs, encoded_lengths
61
 
62
  def _ctc_decode(self, encoder_outputs, encoded_lengths, lang):
63
+ logprobs = self.models['ctc_decoder'].run(['logprobs'], {'encoder_output': encoder_outputs})[0]
64
+ logprobs = torch.from_numpy(logprobs[:, :, self.language_masks[lang]]).log_softmax(dim=-1)
65
+
66
+ # currently no batching
67
  indices = torch.argmax(logprobs[0],dim=-1)
68
  collapsed_indices = torch.unique_consecutive(indices, dim=-1)
69
+
70
+ hyp = ''.join([self.vocab[lang][x] for x in collapsed_indices if x != self.BLANK_ID]).replace('▁',' ').strip()
71
+ del logprobs, indices, collapsed_indices
72
+ return hyp
73
 
74
  def _rnnt_decode(self, encoder_outputs, encoded_lengths, lang):
75
  joint_enc = self.models['joint_enc'].run(['output'], {'input': encoder_outputs.transpose(0, 2, 1)})[0]