Update custom_interface.py
Browse files- custom_interface.py +0 -5
custom_interface.py
CHANGED
@@ -22,12 +22,10 @@ class ASR(Pretrained):
|
|
22 |
predictions = self.hparams.test_search(encoded_outputs, self.wav_lens)[0]
|
23 |
predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions]
|
24 |
prediction = []
|
25 |
-
print(predicted_words)
|
26 |
for sent in predicted_words:
|
27 |
sent = self.filter_repetitions(sent, 3)
|
28 |
prediction.append(sent)
|
29 |
predicted_words = prediction
|
30 |
-
print(predicted_words)
|
31 |
return predicted_words
|
32 |
|
33 |
def filter_repetitions(self, seq, max_repetition_length):
|
@@ -94,6 +92,3 @@ class ASR(Pretrained):
|
|
94 |
outputs = self.encode_batch(batch, rel_length)
|
95 |
|
96 |
return outputs
|
97 |
-
|
98 |
-
# def forward(self, wavs, wav_lens=None):
|
99 |
-
# return self.encode_batch(wavs=wavs, wav_lens=wav_lens)
|
|
|
22 |
predictions = self.hparams.test_search(encoded_outputs, self.wav_lens)[0]
|
23 |
predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions]
|
24 |
prediction = []
|
|
|
25 |
for sent in predicted_words:
|
26 |
sent = self.filter_repetitions(sent, 3)
|
27 |
prediction.append(sent)
|
28 |
predicted_words = prediction
|
|
|
29 |
return predicted_words
|
30 |
|
31 |
def filter_repetitions(self, seq, max_repetition_length):
|
|
|
92 |
outputs = self.encode_batch(batch, rel_length)
|
93 |
|
94 |
return outputs
|
|
|
|
|
|