# File-viewer header captured with this file (not part of the program):
#   Emmanuel08 - "Update app.py" - commit d63bba0 (verified) - 4.9 kB
import torch
import torchaudio
import gradio as gr
import time
import numpy as np
import scipy.io.wavfile
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# ✅ 1️⃣ Use "whisper-small" for better accuracy
# Runtime configuration: everything below runs on the CPU in float32.
MODEL_NAME = "openai/whisper-small"
device = "cpu"
torch_dtype = torch.float32

# ✅ 2️⃣ Load the Whisper checkpoint and place it on the CPU (no bitsandbytes)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_NAME,
    use_safetensors=True,
    torch_dtype=torch_dtype,
).to(device)

# ✅ 3️⃣ Wrap the model with torch.compile() (intended to speed up inference)
model = torch.compile(model)

# ✅ 4️⃣ Load the processor and assemble the speech-recognition pipeline
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Both transcription functions hand the pipeline 16 kHz audio.
processor.feature_extractor.sampling_rate = 16000

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=device,
    torch_dtype=torch_dtype,
    # Long inputs are processed in 5-second chunks.
    chunk_length_s=5,
    # Beam search with five beams, decoding as English.
    generate_kwargs=dict(num_beams=5, language="en"),
)
# ✅ 5️⃣ Real-Time Streaming Transcription (Microphone)
def stream_transcribe(stream, new_chunk):
    """Append a microphone chunk to the running buffer and transcribe it all.

    Args:
        stream: Audio accumulated by earlier calls (1-D array at 16 kHz), or
            ``None`` on the first call.
        new_chunk: ``(sample_rate, samples)`` pair for the newest chunk.
            ``samples`` is a NumPy array; when it has more than one
            dimension, axis 1 is treated as the channel axis.

    Returns:
        ``(stream, text, latency)`` where ``stream`` is the updated buffer,
        ``text`` the transcription of the whole buffer, and ``latency`` the
        elapsed wall-clock time formatted as ``"<seconds> sec"``. On any
        failure the tuple is ``(stream, <error message>, "Error")``.
    """
    start_time = time.time()
    try:
        sr, y = new_chunk

        # Down-mix multi-channel audio to mono.
        if y.ndim > 1:
            y = y.mean(axis=1)
        y = y.astype(np.float32)

        # Peak-normalise the chunk. A silent (all-zero) or empty chunk has no
        # peak to scale by: dividing by zero would fill it with NaNs, and
        # those NaNs would then live in the stream buffer and corrupt every
        # later transcription.
        peak = float(np.abs(y).max()) if y.size else 0.0
        if peak > 0.0:
            y /= peak

        # Resample the chunk to the 16 kHz rate passed to the pipeline below.
        y_tensor = torch.tensor(y)
        y_resampled = torchaudio.functional.resample(
            y_tensor, orig_freq=sr, new_freq=16000
        ).numpy()

        # Extend the running buffer (or start it on the first chunk).
        if stream is not None:
            stream = np.concatenate([stream, y_resampled])
        else:
            stream = y_resampled

        # NOTE: the entire buffer is re-transcribed on every call.
        transcription = pipe({"sampling_rate": 16000, "raw": stream})["text"]
        latency = time.time() - start_time
        return stream, transcription, f"{latency:.2f} sec"
    except Exception as e:
        # Best-effort UI callback: report the failure in the textbox instead
        # of breaking the stream.
        print(f"Error: {e}")
        return stream, str(e), "Error"
# ✅ 6️⃣ Transcription for File Upload
def transcribe(inputs, previous_transcription):
    """Transcribe an uploaded audio clip and append it to the existing text.

    Args:
        inputs: ``(sample_rate, samples)`` pair from the upload component, or
            ``None`` when nothing was uploaded. ``samples`` is a NumPy array;
            when it has more than one dimension, axis 1 is treated as the
            channel axis.
        previous_transcription: Text already shown in the output box.

    Returns:
        ``(text, latency)`` where ``text`` is ``previous_transcription`` with
        the new transcription appended and ``latency`` is the elapsed time
        formatted as ``"<seconds> sec"``. On failure the previous text is
        returned unchanged together with ``"Error"``.
    """
    start_time = time.time()
    try:
        if inputs is None:
            print("Error: no audio input provided")
            return previous_transcription, "Error"

        sample_rate, audio_data = inputs

        # Same preparation as the microphone path: mono, float32, peak-scaled.
        # Without it, integer PCM (presumably what the numpy upload delivers -
        # confirm against the Gradio docs) cannot be resampled, and
        # multi-channel clips are resampled along the wrong axis and then
        # rejected by the pipeline.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        audio_data = audio_data.astype(np.float32)
        peak = float(np.abs(audio_data).max()) if audio_data.size else 0.0
        if peak > 0.0:  # a silent clip has nothing to scale
            audio_data /= peak

        # Resample to the 16 kHz rate passed to the pipeline below.
        audio_tensor = torch.tensor(audio_data)
        resampled_audio = torchaudio.functional.resample(
            audio_tensor, orig_freq=sample_rate, new_freq=16000
        ).numpy()

        transcription = pipe({"sampling_rate": 16000, "raw": resampled_audio})["text"]
        previous_transcription += transcription
        latency = time.time() - start_time
        return previous_transcription, f"{latency:.2f} sec"
    except Exception as e:
        # Best-effort UI callback: keep the existing text and flag the failure.
        print(f"Error: {e}")
        return previous_transcription, "Error"
# ✅ 7️⃣ Clear Function
def clear():
    """Return the empty string used to blank a transcription textbox."""
    cleared_text = ""
    return cleared_text
# ✅ 8️⃣ Gradio Interface (Microphone Streaming)
with gr.Blocks() as microphone:
    gr.Markdown("# Whisper Small - Real-Time Transcription (Optimized CPU) 🎙️")
    gr.Markdown(f"Using [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) for ultra-fast speech-to-text with better accuracy.")
    with gr.Row():
        input_audio_microphone = gr.Audio(sources=["microphone"], type="numpy", streaming=True)
        output = gr.Textbox(label="Live Transcription", value="")
        latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0")
    with gr.Row():
        clear_button = gr.Button("Clear Output")

    # Per-session audio buffer threaded through stream_transcribe.
    state = gr.State()

    # Each streamed chunk is passed to stream_transcribe along with the
    # buffer; the callback returns the updated buffer, text and latency.
    input_audio_microphone.stream(
        stream_transcribe,
        [state, input_audio_microphone],
        [state, output, latency_textbox],
        time_limit=30,
        stream_every=1,
    )

    def _reset_stream():
        """Drop the buffered audio and blank the live transcript."""
        return None, ""

    # Clearing must also reset the buffer; otherwise the next chunk
    # re-transcribes the old audio and the cleared text reappears.
    clear_button.click(_reset_stream, outputs=[state, output])
# ✅ 9️⃣ Gradio Interface (File Upload)
with gr.Blocks() as file:
    # Heading emoji restored from its mis-encoded form.
    gr.Markdown("# Upload Audio File for Transcription 🎵")
    gr.Markdown(f"Using [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) for better transcription accuracy.")
    with gr.Row():
        input_audio = gr.Audio(sources=["upload"], type="numpy")
        output = gr.Textbox(label="Transcription", value="")
        latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0")
    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear Output")

    # Submit appends the new transcription to whatever the box already holds;
    # the current box content is passed in as the second argument.
    submit_button.click(transcribe, [input_audio, output], [output, latency_textbox])
    # Clear Output blanks the transcription box.
    clear_button.click(clear, outputs=[output])
# ✅ 🔟 Final Gradio App
# Top-level app: the two Blocks defined above, shown as tabs.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.TabbedInterface(
        interface_list=[microphone, file],
        tab_names=["Microphone", "Upload Audio"],
    )

# ✅ 1️⃣1️⃣ Run Gradio locally
if __name__ == "__main__":
    demo.launch()