File size: 1,448 Bytes
171bc27
 
641eeeb
171bc27
 
 
 
 
 
 
747d7b4
722014e
171bc27
 
722014e
747d7b4
 
722014e
 
171bc27
 
722014e
 
171bc27
722014e
171bc27
 
 
641eeeb
 
402022e
641eeeb
171bc27
 
 
722014e
171bc27
 
 
b1c4d32
641eeeb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from speechbrain.inference.interfaces import Pretrained
import librosa


class ASR(Pretrained):
    """Custom SpeechBrain inference interface that transcribes speech.

    The interface relies on the following objects being provided by the
    loaded hyperparameters:

    * ``mods.encoder_w2v2`` -- module applied to the raw waveforms.
    * ``hparams.test_search`` -- searcher called with the encoder output and
      the relative lengths; the first element of its result is used as the
      predicted token-id sequences.
    * ``hparams.tokenizer`` -- object whose ``decode_ids`` turns one token-id
      sequence into a string.
    """

    # Sample rate (Hz) that ``classify_file`` resamples audio to on load.
    SAMPLE_RATE = 16000

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Transcribe a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Waveforms with shape ``[batch, time]``. A 1-D tensor is treated
            as a single waveform and given a batch dimension.
        wav_lens : torch.Tensor, optional
            Relative length of each waveform. When omitted, every waveform
            is given a relative length of 1.0.
        normalize : bool
            Accepted for interface compatibility; currently ignored.

        Returns
        -------
        list of list of str
            For each waveform, the decoded hypothesis split on single spaces.
        """
        # Allow callers to pass one unbatched waveform.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)
        wavs = wavs.to(self.device)

        # The signature advertises ``wav_lens=None``; make that default usable
        # instead of failing on ``None.to(...)``.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.size(0))
        wav_lens = wav_lens.to(self.device)

        # Inference only: no autograd graph is needed. Only the encoder output
        # feeds the searcher, so no separate decoder pass is run here.
        with torch.no_grad():
            encoded_outputs = self.mods.encoder_w2v2(wavs)
            predictions = self.hparams.test_search(encoded_outputs, wav_lens)[0]

        decode_ids = self.hparams.tokenizer.decode_ids
        return [decode_ids(prediction).split(" ") for prediction in predictions]

    def classify_file(self, path):
        """Transcribe a single audio file.

        Arguments
        ---------
        path : str
            Path of the audio file to load.

        Returns
        -------
        list of list of str
            Output of ``encode_batch`` for a batch holding only this file.
        """
        # librosa resamples to the requested rate while loading.
        waveform, _ = librosa.load(path, sr=self.SAMPLE_RATE)
        waveform = torch.tensor(waveform)

        # Fake a batch of one utterance that spans its full length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        return self.encode_batch(batch, rel_length)

    # def forward(self, wavs, wav_lens=None):
    #     return self.encode_batch(wavs=wavs, wav_lens=wav_lens)