# whisper-tamil / app.py
import os

import gradio as gr
import librosa
import numpy as np
import torch
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# The Hugging Face access token is read from an environment variable / Space secret named "token"
hf_token = os.getenv("token")
if hf_token is None:
    raise ValueError('The "token" environment variable is not set.')
login(token=hf_token)
# Load the fine-tuned Whisper processor and model
processor = AutoProcessor.from_pretrained("ragunath-ravi/whisper-mini-ta")
model = AutoModelForSpeechSeq2Seq.from_pretrained("ragunath-ravi/whisper-mini-ta")

# Move the model to GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
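
# Optional sketch (assumption: the checkpoint is Whisper-based, so the processor
# exposes get_decoder_prompt_ids). Pinning the decoder prompt to Tamil
# transcription guards against the model drifting into another language:
#
#     forced_decoder_ids = processor.get_decoder_prompt_ids(
#         language="tamil", task="transcribe"
#     )
#     model.generation_config.forced_decoder_ids = forced_decoder_ids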
def transcribe_audio(audio):
    """Transcribe audio using the fine-tuned Whisper model."""
    if audio is None:
        return "No audio provided"

    # Load audio: either a path to an audio file, or a
    # (sample_rate, array) tuple as produced by the Gradio component
    if isinstance(audio, str):
        audio_array, sampling_rate = librosa.load(audio, sr=16000)
    else:
        sampling_rate, audio_array = audio

    # Downmix stereo recordings to mono; Whisper expects a 1-D signal
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    # Convert integer PCM (e.g. int16 or int32) to float32 in [-1.0, 1.0],
    # scaling by the full range of the source dtype
    if np.issubdtype(audio_array.dtype, np.integer):
        audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
    elif not np.issubdtype(audio_array.dtype, np.floating):
        audio_array = audio_array.astype(np.float32)

    # Resample to the 16 kHz rate Whisper expects
    if sampling_rate != 16000:
        audio_array = librosa.resample(audio_array, orig_sr=sampling_rate, target_sr=16000)
        sampling_rate = 16000

    # Extract log-Mel input features and generate token ids
    inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(input_features=inputs.input_features)

    # Decode token ids back to text
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return transcription
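
# Quick local sanity check (assumption: "sample_ta.wav" is a hypothetical Tamil
# clip next to app.py; the check is skipped when the file is absent):
#
#     if os.path.exists("sample_ta.wav"):
#         print(transcribe_audio("sample_ta.wav"))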
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Tamil Speech Transcription")
    gr.Markdown(
        "This app transcribes Tamil speech using a fine-tuned Whisper model. "
        "Upload an audio file or record using the microphone."
    )
    with gr.Row():
        with gr.Column():
            # max_length is measured in seconds, so 10 caps recordings at 10 s
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Audio Input (Record up to 10 seconds)",
                max_length=10,
            )
            transcribe_btn = gr.Button("Transcribe")
        with gr.Column():
            output_text = gr.Textbox(label="Transcription Output", lines=5)

    # Transcribe on button click, and also whenever the audio value changes
    transcribe_btn.click(fn=transcribe_audio, inputs=audio_input, outputs=output_text)
    audio_input.change(fn=transcribe_audio, inputs=audio_input, outputs=output_text)

    gr.Markdown("## Instructions")
    gr.Markdown("1. Record audio using the microphone or upload an audio file")
    gr.Markdown("2. Click 'Transcribe' to get the Tamil transcription")
    gr.Markdown("3. For best results, speak clearly in Tamil")
# Launch the app
if __name__ == "__main__":
    # share=True creates a temporary public link when running locally
    demo.launch(share=True)