import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Available models
MODEL_OPTIONS = [
    "facebook/wav2vec2-large-960h",
    "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
    "jonatasgrosman/wav2vec2-large-xlsr-53-french",
]

# Default model
model_name = MODEL_OPTIONS[0]
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def transcribe_audio(audio_path, model_choice):
    try:
        global processor, model, model_name

        # If the selected model has changed, reload it
        if model_choice != model_name:
            processor = Wav2Vec2Processor.from_pretrained(model_choice)
            model = Wav2Vec2ForCTC.from_pretrained(model_choice).to(device)
            model.eval()
            model_name = model_choice

        # Load the audio and downmix multi-channel recordings to mono
        waveform, sample_rate = torchaudio.load(audio_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Wav2Vec2 models expect 16 kHz input; resample if necessary
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)

        # Convert the audio to the model's input format
        input_values = processor(
            waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
        ).input_values.to(device)

        # Model inference
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)

        # group_tokens=False bypasses the CTC-style collapsing of repeated
        # tokens, so this returns the raw, uncollapsed token stream
        raw_transcription = processor.batch_decode(predicted_ids, group_tokens=False)[0]

        return raw_transcription
    except Exception as e:
        return f"Error: {str(e)}"


# Gradio Interface
iface = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Audio(type="filepath"),
        gr.Dropdown(MODEL_OPTIONS, label="Choose Model", value=MODEL_OPTIONS[0]),
    ],
    outputs="text",
    title="Speech Transcriber",
    description="Upload an audio file and get a transcription using different models.",
)

iface.launch()
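
# Note: transcribe_audio deliberately decodes with group_tokens=False, which
# yields the raw per-frame token stream (repeated characters included). For a
# conventional transcription, the default decoding collapses repeated CTC
# tokens and strips blank tokens. A minimal sketch, assuming predicted_ids is
# computed exactly as inside transcribe_audio above:
#
#     transcription = processor.batch_decode(predicted_ids)[0]
#
# Any such programmatic call should run before iface.launch(), which blocks
# until the Gradio server is stopped.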