import os
import gc

import torch
import librosa
import numpy as np
import gradio as gr
from transformers import (AutoProcessor, AutoModelForCTC,
                          AutoModelForTokenClassification, AutoTokenizer)
from speechbrain.inference.VAD import VAD

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Voice activity detector, loaded once at startup. It is not yet used in the
# transcription path below; see the sketch that follows.
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty", savedir="vad_model")
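
# A minimal sketch (not called anywhere yet) of how the loaded VAD could be
# wired in: keep only the detected speech regions of a file before sending it
# to the ASR model. get_speech_segments returns (start, end) pairs in seconds;
# the helper name trim_to_speech is our own, not part of SpeechBrain.
def trim_to_speech(audio_file):
    boundaries = vad_model.get_speech_segments(audio_file)
    audio, sr = librosa.load(audio_file, sr=16000)
    segments = [audio[int(s * sr):int(e * sr)] for s, e in boundaries.tolist()]
    return np.concatenate(segments) if segments else audio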


def clean_up_memory():
    # Free Python objects and release cached CUDA memory between requests.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Wav2Vec2 CTC model used for the raw English transcription.
asr_model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(asr_model_name)
w2v2_model = AutoModelForCTC.from_pretrained(asr_model_name).to(device)
w2v2_model.eval()

# Token-classification model that restores punctuation in the raw transcript.
recap_model_name = "kredor/punctuate-all"
recap_tokenizer = AutoTokenizer.from_pretrained(recap_model_name)
recap_model = AutoModelForTokenClassification.from_pretrained(recap_model_name).to(device)
recap_model.eval()
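
# The punctuation label set can be inspected directly if needed:
#   print(recap_model.config.id2label)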


def recap_sentence(string):
    # Predict one punctuation label per token, then map the labels back to
    # whitespace-separated words via the tokenizer's word alignment.
    tokens = recap_tokenizer(string, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        predictions = recap_model(**tokens).logits

    predicted_ids = torch.argmax(predictions, dim=-1)[0]
    word_ids = tokens.word_ids(batch_index=0)
    words = string.split()
    punctuated_text = []

    for word_index, word in enumerate(words):
        # Use the prediction for the first subword of each word. Label ids
        # must be decoded with id2label, not the tokenizer vocabulary.
        label = "0"
        for token_index, wid in enumerate(word_ids):
            if wid == word_index:
                label = recap_model.config.id2label[predicted_ids[token_index].item()]
                break
        # The model's "0" label means "no punctuation after this word".
        punctuated_text.append(word if label == "0" else word + label)

    return " ".join(punctuated_text)
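
# Example usage (the punctuation inserted depends on the model's predictions):
#   recap_sentence("hello everyone welcome back to the service")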


def transcribe_audio_stream(audio_file, chunk_size=2.0):
    # Decode the audio at the 16 kHz sampling rate wav2vec2 expects.
    audio, sr = librosa.load(audio_file, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)
    transcriptions = []

    # Transcribe fixed-length chunks so long recordings fit in GPU memory.
    # Hard chunk boundaries can split words; longer chunks reduce how often
    # that happens.
    for start in np.arange(0, duration, chunk_size):
        end = min(start + chunk_size, duration)
        chunk = audio[int(start * sr):int(end * sr)]

        input_values = processor(chunk, return_tensors="pt", sampling_rate=16000).input_values.to(w2v2_model.device)

        with torch.no_grad():
            logits = w2v2_model(input_values).logits

        # Greedy CTC decoding: take the most likely token at each frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        transcriptions.append(transcription)

    return " ".join(transcriptions)
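
# Example usage, assuming a local recording such as "sermon.wav" exists:
#   raw_text = transcribe_audio_stream("sermon.wav", chunk_size=10.0)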


def return_prediction_w2v2(file_or_mic):
    if not file_or_mic:
        # No audio was provided, so there is nothing to transcribe or download.
        return "", None

    transcription = transcribe_audio_stream(file_or_mic)
    recap_result = recap_sentence(transcription)

    # Write the punctuated transcript to disk so Gradio can offer it for download.
    download_path = "transcription.txt"
    with open(download_path, "w") as f:
        f.write(recap_result)

    clean_up_memory()
    return recap_result, download_path
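
# Note: the fixed "transcription.txt" path above is shared by all requests;
# if concurrent users are expected, tempfile.mkstemp() would give each
# request its own file.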


mic_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=[gr.Textbox(label="Real-Time Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=True,
)

file_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=[gr.Textbox(label="File Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=False,
)


with gr.Blocks() as transcriber_app:
    gr.Markdown("<h2>CCI Real-Time Sermon Transcription</h2>")
    gr.TabbedInterface([mic_transcribe, file_transcribe],
                       ["Real-Time (Microphone)", "Upload Audio"])


if __name__ == "__main__":
    transcriber_app.launch()