import streamlit as st
import torch
import librosa
import numpy as np
import tempfile
from transformers import (
    AutoModelForCTC,
    Wav2Vec2Processor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download
from jiwer import wer
import json
import gzip
import shutil
import os
from pydub import AudioSegment
import time
import re

# Constants
WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
WHISPER_PRETRAINED = "openai/whisper-medium"
WAV2VEC_MODEL = "mesolitica/wav2vec2-xls-r-300m-mixed"
MAX_RECORDING_SECONDS = 12


def capitalize_sentences(text):
    """Capitalize the first letter of each sentence in the transcription."""
    sentences = re.split(r'(?<=[.!?]) +', text)
    capitalized = [s.strip().capitalize() for s in sentences]
    return ' '.join(capitalized)


# Main title
st.title("🎙️ Bahasa Rojak Speech-to-Text")

# Sidebar configuration
st.sidebar.title("Model Configuration")
model_type = st.sidebar.selectbox(
    "Select Model Type",
    ["Whisper", "Wav2Vec2"],
    index=0,
    help="Choose between Whisper or Wav2Vec2 models"
)

# Session state initialization
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "📁 Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "📁 Upload Audio"

# Tab selection
tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])

# Reset state if the tab has changed
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab

# Tab 1: Upload Audio
with tab1:
    uploaded_file = st.file_uploader("Upload a .wav or .mp3 file", type=["wav", "mp3"])
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            # Persist the upload to a temporary file so pydub/librosa can read it
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name
            # Convert MP3 uploads to WAV
            if uploaded_file.name.endswith(".mp3"):
                audio = AudioSegment.from_mp3(st.session_state.audio_path)
                wav_path = st.session_state.audio_path.replace(".mp3", ".wav")
                audio.export(wav_path, format="wav")
                os.unlink(st.session_state.audio_path)
                st.session_state.audio_path = wav_path
            # Validate that the file is readable at 16 kHz, then play it back
            librosa.load(st.session_state.audio_path, sr=16000)
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {str(e)}")
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None

# Tab 2: Record Audio
with tab2:
    st.session_state.selected_tab = "🎤 Record Audio"
    st.caption(f"Click microphone below to start recording (max {MAX_RECORDING_SECONDS} seconds)")
    audio_input = st.audio_input("🎙️ Record Audio")
    if audio_input:
        try:
            # Get the audio bytes regardless of whether a file-like or buffer object is returned
            audio_bytes = audio_input.read() if hasattr(audio_input, 'read') else audio_input.getvalue()
            # Save to a temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(audio_bytes)
                temp_path = tmp.name
            # Check the recording duration
            audio_segment = AudioSegment.from_file(temp_path)
            duration_seconds = len(audio_segment) / 1000
            if duration_seconds > MAX_RECORDING_SECONDS:
                st.error(f"❌ Recording too long! Please keep it under {MAX_RECORDING_SECONDS} seconds.")
                os.unlink(temp_path)
            else:
                # Store in session state
                st.session_state.audio_bytes = audio_bytes
                st.session_state.audio_path = temp_path
                # Validate that the recording is readable at 16 kHz, then play it back
                librosa.load(st.session_state.audio_path, sr=16000)
                st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {str(e)}")
            if 'temp_path' in locals() and os.path.exists(temp_path):
                os.unlink(temp_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None

# Ground truth input for WER calculation
st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input"
)

# Model-specific configuration
if model_type == "Wav2Vec2":
    # Wav2Vec2 configuration
    @st.cache_resource
    def load_model():
        processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_MODEL)
        model = AutoModelForCTC.from_pretrained(WAV2VEC_MODEL)
        model.eval()
        return processor, model

    @st.cache_resource
    def load_decoder():
        # Build a CTC beam-search decoder backed by a KenLM 4-gram language model
        vocab_path = hf_hub_download(
            repo_id="ashantharosary/wav2vec2-ngram-finetuned",
            filename="vocab.json",
            repo_type="model"
        )
        with open(vocab_path, "r") as f:
            vocab_dict = json.load(f)
        sorted_vocab = sorted(vocab_dict.items(), key=lambda item: item[1])
        vocab_list = [k.lower() for k, v in sorted_vocab]
        arpa_gz_path = hf_hub_download(
            repo_id="ashantharosary/wav2vec2-ngram-finetuned",
            filename="4gram.arpa.gz",
            repo_type="model"
        )
        arpa_path = "4gram.arpa"
        if not os.path.exists(arpa_path):
            with gzip.open(arpa_gz_path, 'rb') as f_in:
                with open(arpa_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)
        decoder = build_ctcdecoder(vocab_list, kenlm_model_path=arpa_path, alpha=0.2, beta=1.0)
        return decoder

    processor, model = load_model()
    decoder = load_decoder()
    mode = st.selectbox(
        "Choose transcription method:",
        ["Without N-gram LM", "With N-gram LM"],
        help="Select whether to use language model decoding"
    )
else:
    # Whisper configuration
    model_choice = st.selectbox(
        "Select Whisper Model",
        options=["Fine-tuned Model", "Pretrained Whisper-Medium Model"],
        help="Choose the Whisper model to transcribe the audio"
    )

    @st.cache_resource
    def load_finetuned_model_and_processor():
        # Use fp16 and FlashAttention 2 only on GPU (requires the flash-attn package)
        model = WhisperForConditionalGeneration.from_pretrained(
            WHISPER_FINETUNED,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
        )
        processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
        model.config.forced_decoder_ids = None
        model.generation_config.forced_decoder_ids = None
        model.config.use_cache = None
        model.config.suppress_tokens = []
        if torch.cuda.is_available():
            model = model.to("cuda")
        return model, processor

    @st.cache_resource
    def load_pretrained_model_and_processor():
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_PRETRAINED)
        processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
        model.config.forced_decoder_ids = None
        model.generation_config.forced_decoder_ids = None
        model.config.use_cache = None
        model.config.suppress_tokens = []
        return model, processor

    if model_choice == "Fine-tuned Model":
        model, processor = load_finetuned_model_and_processor()
    else:
        model, processor = load_pretrained_model_and_processor()

# Transcription button
if st.button("📝 Transcribe"):
    if not st.session_state.audio_bytes:
st.error("❌ Please upload or record an audio file first.") else: start_time = time.time() try: if model_type == "Wav2Vec2": audio_input, _ = librosa.load(st.session_state.audio_path, sr=16000) input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values with torch.no_grad(): logits = model(input_values).logits[0].cpu().numpy() if mode == "With N-gram LM": decoded_ngram = decoder.decode_beams(logits, prune_history=True) st.session_state.predicted_text = decoded_ngram[0][0] st.markdown("### 🔊 Transcription with N-gram Language Model") st.success(st.session_state.predicted_text) else: predicted_ids = torch.argmax(torch.tensor(logits), dim=-1) st.session_state.predicted_text = processor.batch_decode(predicted_ids.unsqueeze(0))[0] st.markdown("### 🎤 Transcription without N-gram Language Model") st.info(st.session_state.predicted_text) else: audio_input_data, _ = librosa.load(st.session_state.audio_path, sr=16000) input_features = processor( audio_input_data, sampling_rate=16000, return_tensors="pt" ).input_features predicted_ids = model.generate(input_features) transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] transcription = capitalize_sentences(transcription) st.session_state.predicted_text = transcription st.markdown("### 🔊 Predicted Transcription") st.success(st.session_state.predicted_text) if st.session_state.ground_truth: st.session_state.wer_value = wer( st.session_state.ground_truth.lower(), st.session_state.predicted_text.lower() ) st.markdown("### 🧮 Word Error Rate (WER)") st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`") except Exception as e: st.error(f"❌ Transcription failed: {str(e)}") finally: # Clean up temporary files if st.session_state.audio_path and os.path.exists(st.session_state.audio_path): os.unlink(st.session_state.audio_path) st.session_state.audio_bytes = None st.session_state.audio_path = None st.session_state.audio_path = None st.session_state.predicted_text = "" st.session_state.ground_truth = "" st.session_state.wer_value = None end_time = time.time() duration = end_time - start_time st.caption(f"🕒 Time taken: {duration:.2f}s")