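"""Streamlit speech-to-text demo.

Compares a fine-tuned Whisper model against the pretrained
openai/whisper-medium baseline, with an optional word error rate (WER)
calculation against a user-supplied ground-truth transcript.
"""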
import streamlit as st
import torch
import librosa
import numpy as np
import tempfile
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from jiwer import wer
import os
from pydub import AudioSegment
import time
import re
# Constants
WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
WHISPER_PRETRAINED = "openai/whisper-medium"
MAX_RECORDING_SECONDS = 12
def capitalize_sentences(text):
    """Capitalize the first letter of each sentence in the text."""
    sentences = re.split(r'(?<=[.!?]) +', text)
    capitalized = [s.strip().capitalize() for s in sentences]
    return ' '.join(capitalized)
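# Example: capitalize_sentences("hello there. how are you?") returns
# "Hello there. How are you?". Note that str.capitalize() also lowercases
# everything after the first character of each sentence, so proper nouns
# and acronyms in the transcription are lowercased as well.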
# Main title
st.title("🎙️ Speech-to-Text with Whisper")
# Session state initialization
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "📁 Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "📁 Upload Audio"
# Tab Selection
tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])
# Reset state if the active tab changed. st.tabs does not report which tab
# is active, so selected_tab is updated inside each tab body when that
# tab's input widget actually receives data.
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab
# Tab 1: Upload Audio
with tab1:
    uploaded_file = st.file_uploader("Upload a .wav or .mp3 file", type=["wav", "mp3"])
    if uploaded_file:
        try:
            # Mark this tab as active so the reset logic above can detect a switch
            st.session_state.selected_tab = "📁 Upload Audio"
            st.session_state.audio_bytes = uploaded_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name
            if uploaded_file.name.endswith(".mp3"):
                audio = AudioSegment.from_mp3(st.session_state.audio_path)
                wav_path = st.session_state.audio_path.replace(".mp3", ".wav")
                audio.export(wav_path, format="wav")
                os.unlink(st.session_state.audio_path)
                st.session_state.audio_path = wav_path
            # Validate that librosa can decode the file at 16 kHz
            librosa.load(st.session_state.audio_path, sr=16000)
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {str(e)}")
            # Clean up the temporary file if it was created
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
                st.session_state.audio_path = None
            st.session_state.audio_bytes = None
# Tab 2: Record Audio
with tab2:
    st.caption(f"Click microphone below to start recording (max {MAX_RECORDING_SECONDS} seconds)")
    audio_input = st.audio_input("🎙️ Record Audio")
    if audio_input:
        try:
            # Mark this tab as active so the reset logic above can detect a switch
            st.session_state.selected_tab = "🎤 Record Audio"
            # Get the audio bytes in the correct format
            audio_bytes = audio_input.read() if hasattr(audio_input, 'read') else audio_input.getvalue()
            # Save to a temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(audio_bytes)
                temp_path = tmp.name
            # Check duration
            audio_segment = AudioSegment.from_file(temp_path)
            duration_seconds = len(audio_segment) / 1000
            if duration_seconds > MAX_RECORDING_SECONDS:
                st.error(f"❌ Recording too long! Please keep it under {MAX_RECORDING_SECONDS} seconds.")
                os.unlink(temp_path)
            else:
                # Store in session state
                st.session_state.audio_bytes = audio_bytes
                st.session_state.audio_path = temp_path
                # Validate that librosa can decode the recording at 16 kHz
                librosa.load(st.session_state.audio_path, sr=16000)
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {str(e)}")
            if 'temp_path' in locals() and os.path.exists(temp_path):
                os.unlink(temp_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None
# Input ground truth for WER
st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input"
)
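# WER, as computed by jiwer.wer, is (substitutions + deletions + insertions)
# divided by the number of words in the reference. For example,
# wer("hello world", "hello word") == 0.5: one substitution over two
# reference words.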
# Whisper configuration
model_choice = st.selectbox(
    "Select Whisper Model",
    options=["Fine-tuned Model", "Pretrained Whisper-Medium Model"],
    help="Choose the Whisper model to transcribe the audio"
)
@st.cache_resource
def load_finetuned_model_and_processor():
    model = WhisperForConditionalGeneration.from_pretrained(
        WHISPER_FINETUNED,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
    )
    processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.use_cache = None
    model.config.suppress_tokens = []
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model, processor
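# Note: attn_implementation="flash_attention_2" requires the flash-attn
# package to be installed; without it, loading the fine-tuned model on GPU
# raises an error. Passing None falls back to the default attention
# implementation.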
@st.cache_resource
def load_pretrained_model_and_processor():
    model = WhisperForConditionalGeneration.from_pretrained(WHISPER_PRETRAINED)
    processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.use_cache = None
    model.config.suppress_tokens = []
    return model, processor
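# st.cache_resource keeps each (model, processor) pair in memory across
# reruns and sessions, so switching between the two options below does not
# reload the weights from the Hub every time.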
if model_choice == "Fine-tuned Model":
    model, processor = load_finetuned_model_and_processor()
else:
    model, processor = load_pretrained_model_and_processor()
# Transcription Button
if st.button("📝 Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("❌ Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            audio_input_data, _ = librosa.load(st.session_state.audio_path, sr=16000)
            input_features = processor(
                audio_input_data, sampling_rate=16000, return_tensors="pt"
            ).input_features
            # Match the model's device and dtype: the fine-tuned model runs in
            # float16 on GPU, so the features must be moved and cast as well
            input_features = input_features.to(model.device, dtype=model.dtype)
            predicted_ids = model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            transcription = capitalize_sentences(transcription)
            st.session_state.predicted_text = transcription
            st.markdown("### 🔊 Predicted Transcription")
            st.success(st.session_state.predicted_text)
            if st.session_state.ground_truth:
                st.session_state.wer_value = wer(
                    st.session_state.ground_truth.lower(),
                    st.session_state.predicted_text.lower()
                )
                st.markdown("### 🧮 Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`")
        except Exception as e:
            st.error(f"❌ Transcription failed: {str(e)}")
        finally:
            # Clean up temporary files and reset state for the next run
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None
            st.session_state.predicted_text = ""
            st.session_state.ground_truth = ""
            st.session_state.wer_value = None
            end_time = time.time()
            duration = end_time - start_time
            st.caption(f"🕒 Time taken: {duration:.2f}s")