File size: 1,448 Bytes
171bc27 641eeeb 171bc27 747d7b4 722014e 171bc27 722014e 747d7b4 722014e 171bc27 722014e 171bc27 722014e 171bc27 641eeeb 402022e 641eeeb 171bc27 722014e 171bc27 b1c4d32 641eeeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
from speechbrain.inference.interfaces import Pretrained
import librosa
class ASR(Pretrained):
    """Speech-recognition wrapper around a SpeechBrain ``Pretrained`` model.

    Relies on entries declared in the model's hparams file (not shown here):
    ``mods.encoder_w2v2`` (acoustic encoder), ``hparams.test_search``
    (decoding search) and ``hparams.tokenizer`` (token ids -> text).
    """

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Transcribe a batch of waveforms into lists of words.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms; presumably shaped (batch, time) -- TODO
            confirm against what ``encoder_w2v2`` expects.
        wav_lens : torch.Tensor, optional
            Per-utterance lengths relative to the longest one (1.0 means
            full length). When omitted, every utterance is treated as
            full length.
        normalize : bool
            Unused; kept only so existing callers keep working.

        Returns
        -------
        list of list of str
            For each utterance, the decoded transcript split on single
            spaces.
        """
        wavs = wavs.to(self.device)
        if wav_lens is None:
            # No lengths given: treat every utterance in the batch as full length.
            wav_lens = torch.ones(wavs.size(0), device=self.device)
        wav_lens = wav_lens.to(self.device)
        # Also exposed as an attribute, as before, in case callers read it.
        self.wav_lens = wav_lens

        encoded_outputs = self.mods.encoder_w2v2(wavs.detach())

        # The searcher decodes straight from the encoder outputs; element 0
        # of its result is assumed to hold one token-id sequence per
        # utterance -- TODO confirm against the configured searcher.
        predictions = self.hparams.test_search(encoded_outputs, wav_lens)[0]
        return [
            self.hparams.tokenizer.decode_ids(prediction).split(" ")
            for prediction in predictions
        ]

    def classify_file(self, path):
        """Load the audio file at ``path`` and transcribe it.

        The audio is loaded and resampled to 16 kHz with ``librosa.load``,
        then passed through ``encode_batch`` as a batch of one.

        Returns
        -------
        list of list of str
            A one-element list holding the transcript's words.
        """
        waveform, _ = librosa.load(path, sr=16000)
        # Add a leading batch dimension of size 1.
        batch = torch.tensor(waveform).unsqueeze(0)
        # The single utterance spans its full length.
        rel_length = torch.tensor([1.0])
        return self.encode_batch(batch, rel_length)
|