import streamlit as st
import torch
import librosa
import numpy as np
import tempfile
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from jiwer import wer
import os
from pydub import AudioSegment
import time
import re

# Constants
WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
WHISPER_PRETRAINED = "openai/whisper-medium"
MAX_RECORDING_SECONDS = 12


def capitalize_sentences(text):
    """Uppercase the first letter of each sentence, preserving the casing
    Whisper already produced for the rest of the sentence."""
    sentences = re.split(r'(?<=[.!?]) +', text)
    capitalized = [s.strip()[:1].upper() + s.strip()[1:] for s in sentences]
    return ' '.join(capitalized)


# Main title
st.title("🎙️ Speech-to-Text with Whisper")

# Session state initialization
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "📁 Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "📁 Upload Audio"

# Tab selection
tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])

# Reset state if the tab changed. Note: st.tabs renders both tab bodies on
# every rerun and does not report which tab is active, so this check relies
# on the selected_tab flag set below and is best-effort only.
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab

# Tab 1: Upload Audio
with tab1:
    uploaded_file = st.file_uploader("Upload a .wav or .mp3 file", type=["wav", "mp3"])
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            suffix = f".{uploaded_file.name.split('.')[-1]}"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name

            # Convert MP3 uploads to WAV so downstream loading is uniform
            if uploaded_file.name.endswith(".mp3"):
                audio = AudioSegment.from_mp3(st.session_state.audio_path)
                wav_path = st.session_state.audio_path.replace(".mp3", ".wav")
                audio.export(wav_path, format="wav")
                os.unlink(st.session_state.audio_path)
                st.session_state.audio_path = wav_path

            # Validate that librosa can decode the file at 16 kHz
            librosa.load(st.session_state.audio_path, sr=16000)
            # Play back from the (possibly converted) WAV file on disk, so the
            # declared format matches the actual bytes even for MP3 uploads
            st.audio(st.session_state.audio_path, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {str(e)}")
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None

# Tab 2: Record Audio
with tab2:
    st.session_state.selected_tab = "🎤 Record Audio"
    st.caption(f"Click the microphone below to start recording (max {MAX_RECORDING_SECONDS} seconds)")
    audio_input = st.audio_input("🎙️ Record Audio")
    if audio_input:
        try:
            # Get the raw audio bytes from the recorder widget
            audio_bytes = audio_input.read() if hasattr(audio_input, 'read') else audio_input.getvalue()

            # Save to a temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(audio_bytes)
                temp_path = tmp.name

            # Check duration (pydub reports length in milliseconds)
            audio_segment = AudioSegment.from_file(temp_path)
            duration_seconds = len(audio_segment) / 1000
            if duration_seconds > MAX_RECORDING_SECONDS:
                st.error(f"❌ Recording too long! Please keep it under {MAX_RECORDING_SECONDS} seconds.")
                os.unlink(temp_path)
            else:
                # Store in session state
                st.session_state.audio_bytes = audio_bytes
                st.session_state.audio_path = temp_path

                # Validate and display
                librosa.load(st.session_state.audio_path, sr=16000)
                st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {str(e)}")
            if 'temp_path' in locals() and os.path.exists(temp_path):
                os.unlink(temp_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None
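
# Alternative to rejecting over-long recordings (not wired into the app): the
# clip could instead be trimmed to the limit. A minimal sketch using pydub,
# which is already imported above; the helper name is illustrative only.
def trim_to_max_duration(segment: AudioSegment, max_seconds: int = MAX_RECORDING_SECONDS) -> AudioSegment:
    """Return the first max_seconds of a clip (pydub slices in milliseconds)."""
    return segment[:max_seconds * 1000]
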
# Input ground truth for WER
st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input"
)

# Whisper configuration
model_choice = st.selectbox(
    "Select Whisper Model",
    options=["Fine-tuned Model", "Pretrained Whisper-Medium Model"],
    help="Choose the Whisper model to transcribe the audio"
)


@st.cache_resource
def load_finetuned_model_and_processor():
    model = WhisperForConditionalGeneration.from_pretrained(
        WHISPER_FINETUNED,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
    )
    processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.use_cache = None
    model.config.suppress_tokens = []
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model, processor


@st.cache_resource
def load_pretrained_model_and_processor():
    model = WhisperForConditionalGeneration.from_pretrained(WHISPER_PRETRAINED)
    processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.use_cache = None
    model.config.suppress_tokens = []
    return model, processor


if model_choice == "Fine-tuned Model":
    model, processor = load_finetuned_model_and_processor()
else:
    model, processor = load_pretrained_model_and_processor()

# Transcription button
if st.button("📝 Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("❌ Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            audio_input_data, _ = librosa.load(st.session_state.audio_path, sr=16000)
            input_features = processor(
                audio_input_data,
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features
            # Match the model's device and dtype: the fine-tuned model may be
            # float16 on CUDA, while the processor returns float32 on CPU
            input_features = input_features.to(model.device, dtype=model.dtype)

            predicted_ids = model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            transcription = capitalize_sentences(transcription)
            st.session_state.predicted_text = transcription

            st.markdown("### 🔊 Predicted Transcription")
            st.success(st.session_state.predicted_text)

            if st.session_state.ground_truth:
                st.session_state.wer_value = wer(
                    st.session_state.ground_truth.lower(),
                    st.session_state.predicted_text.lower()
                )
                st.markdown("### 🧮 Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`")
        except Exception as e:
            st.error(f"❌ Transcription failed: {str(e)}")
        finally:
            # Clean up the temporary file and reset per-run state
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None
            st.session_state.predicted_text = ""
            st.session_state.ground_truth = ""
            st.session_state.wer_value = None

        end_time = time.time()
        duration = end_time - start_time
        st.caption(f"🕒 Time taken: {duration:.2f}s")
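
# To run the app (assuming the standard Streamlit CLI; the filename here is
# illustrative):
#   streamlit run app.py
# ffmpeg must be available on PATH for pydub's MP3 decoding. A CUDA GPU is
# optional; the fine-tuned model falls back to float32 on CPU.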