import streamlit as st
import torch
import librosa
import numpy as np
import tempfile
from transformers import AutoModelForCTC, Wav2Vec2Processor, WhisperProcessor, WhisperForConditionalGeneration
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download
from jiwer import wer
import json
import gzip
import shutil
import os
from pydub import AudioSegment
import time
import re
# Constants
WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
WHISPER_PRETRAINED = "openai/whisper-medium"
WAV2VEC_MODEL = "mesolitica/wav2vec2-xls-r-300m-mixed"
MAX_RECORDING_SECONDS = 12
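# Both models expect 16 kHz mono audio, which is why every librosa.load() call below
# resamples input to sr=16000 before feature extraction.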
def capitalize_sentences(text):
sentences = re.split(r'(?<=[.!?]) +', text)
capitalized = [s.strip().capitalize() for s in sentences]
return ' '.join(capitalized)
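# Example (hypothetical input): capitalize_sentences("saya suka makan. best gila!")
# returns "Saya suka makan. Best gila!" - the lookbehind regex splits after ., ! or ?
# followed by spaces, and each piece is stripped and capitalized.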
# Main title
st.title("๐ŸŽ™๏ธ Bahasa Rojak Speech-to-Text")
# Sidebar configuration
st.sidebar.title("Model Configuration")
model_type = st.sidebar.selectbox(
"Select Model Type",
["Whisper", "Wav2Vec2"],
index=0,
help="Choose between Whisper or Wav2Vec2 models"
)
# Session state initialization
if "audio_bytes" not in st.session_state:
st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
st.session_state.selected_tab = "๐Ÿ“ Upload Audio"
if "previous_tab" not in st.session_state:
st.session_state.previous_tab = "๐Ÿ“ Upload Audio"
# Tab Selection
tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])
# Reset state if tab is changed
if st.session_state.selected_tab != st.session_state.previous_tab:
st.session_state.audio_bytes = None
st.session_state.audio_path = None
st.session_state.ground_truth = ""
st.session_state.predicted_text = ""
st.session_state.wer_value = None
st.session_state.previous_tab = st.session_state.selected_tab
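# Note: st.tabs() renders both tab bodies on every rerun and does not report which tab
# is currently visible, so selected_tab is only updated by the recorder tab below and
# this reset is best-effort rather than a true tab-change event.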
# Tab 1: Upload Audio
with tab1:
uploaded_file = st.file_uploader("Upload a .wav or .mp3 file", type=["wav", "mp3"])
if uploaded_file:
try:
st.session_state.audio_bytes = uploaded_file.read()
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
tmp.write(st.session_state.audio_bytes)
st.session_state.audio_path = tmp.name
if uploaded_file.name.endswith(".mp3"):
audio = AudioSegment.from_mp3(st.session_state.audio_path)
wav_path = st.session_state.audio_path.replace(".mp3", ".wav")
audio.export(wav_path, format="wav")
os.unlink(st.session_state.audio_path)
st.session_state.audio_path = wav_path
            # Validate that the file is decodable at the 16 kHz rate used for inference
            librosa.load(st.session_state.audio_path, sr=16000)
            audio_format = "audio/mpeg" if uploaded_file.name.endswith(".mp3") else "audio/wav"
            st.audio(st.session_state.audio_bytes, format=audio_format)
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {str(e)}")
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None
# Tab 2: Record Audio
with tab2:
    # Mark the recorder tab as active so the reset logic above can clear stale state
    st.session_state.selected_tab = "🎤 Record Audio"
    st.caption(f"Click the microphone below to start recording (max {MAX_RECORDING_SECONDS} seconds)")
    audio_input = st.audio_input("🎙️ Record Audio")
if audio_input:
try:
# Get the audio bytes in the correct format
audio_bytes = audio_input.read() if hasattr(audio_input, 'read') else audio_input.getvalue()
# Save to temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(audio_bytes)
temp_path = tmp.name
# Check duration
audio_segment = AudioSegment.from_file(temp_path)
duration_seconds = len(audio_segment) / 1000
if duration_seconds > MAX_RECORDING_SECONDS:
st.error(f"โŒ Recording too long! Please keep it under {MAX_RECORDING_SECONDS} seconds.")
os.unlink(temp_path)
else:
# Store in session state
st.session_state.audio_bytes = audio_bytes
st.session_state.audio_path = temp_path
                # Validate that the recording is decodable at the 16 kHz rate used for inference
                librosa.load(st.session_state.audio_path, sr=16000)
except Exception as e:
st.error(f"โŒ Failed to process recorded audio: {str(e)}")
if 'temp_path' in locals() and os.path.exists(temp_path):
os.unlink(temp_path)
st.session_state.audio_bytes = None
st.session_state.audio_path = None
# Input ground truth for WER
st.session_state.ground_truth = st.text_input(
"Enter ground truth for WER calculation (Optional)",
value=st.session_state.ground_truth,
key="ground_truth_input"
)
# Model-specific configurations
if model_type == "Wav2Vec2":
# Wav2Vec2 configuration
@st.cache_resource
def load_model():
processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_MODEL)
model = AutoModelForCTC.from_pretrained(WAV2VEC_MODEL)
model.eval()
return processor, model
@st.cache_resource
def load_decoder():
vocab_path = hf_hub_download(
repo_id="ashantharosary/wav2vec2-ngram-finetuned",
filename="vocab.json",
repo_type="model"
)
with open(vocab_path, "r") as f:
vocab_dict = json.load(f)
sorted_vocab = sorted(vocab_dict.items(), key=lambda item: item[1])
vocab_list = [k.lower() for k, v in sorted_vocab]
arpa_gz_path = hf_hub_download(
repo_id="ashantharosary/wav2vec2-ngram-finetuned",
filename="4gram.arpa.gz",
repo_type="model"
)
arpa_path = "4gram.arpa"
if not os.path.exists(arpa_path):
with gzip.open(arpa_gz_path, 'rb') as f_in:
with open(arpa_path, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
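        # In pyctcdecode, alpha weights the KenLM language model during shallow fusion and
        # beta is the word-insertion bonus; the 0.2 / 1.0 values below are this app's
        # settings and could be retuned on a held-out set.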
decoder = build_ctcdecoder(vocab_list, kenlm_model_path=arpa_path, alpha=0.2, beta=1.0)
return decoder
processor, model = load_model()
decoder = load_decoder()
mode = st.selectbox(
"Choose transcription method:",
["Without N-gram LM", "With N-gram LM"],
help="Select whether to use language model decoding"
)
else:
# Whisper configuration
model_choice = st.selectbox(
"Select Whisper Model",
options=["Fine-tuned Model", "Pretrained Whisper-Medium Model"],
help="Choose the Whisper model to transcribe the audio"
)
@st.cache_resource
def load_finetuned_model_and_processor():
        model = WhisperForConditionalGeneration.from_pretrained(
            WHISPER_FINETUNED,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None)
processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None
model.config.use_cache = None
model.config.suppress_tokens = []
if torch.cuda.is_available():
model = model.to("cuda")
return model, processor
@st.cache_resource
def load_pretrained_model_and_processor():
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_PRETRAINED)
processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None
model.config.use_cache = None
model.config.suppress_tokens = []
return model, processor
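    # Both loaders clear forced_decoder_ids and suppress_tokens so that generate() is free
    # to predict the language and task tokens itself, which suits code-switched Bahasa Rojak
    # input, rather than being forced into a single fixed language.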
if model_choice == "Fine-tuned Model":
model, processor = load_finetuned_model_and_processor()
else:
model, processor = load_pretrained_model_and_processor()
# Transcription Button
if st.button("๐Ÿ“ Transcribe"):
if not st.session_state.audio_bytes:
st.error("โŒ Please upload or record an audio file first.")
else:
start_time = time.time()
try:
if model_type == "Wav2Vec2":
audio_input, _ = librosa.load(st.session_state.audio_path, sr=16000)
input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values
with torch.no_grad():
logits = model(input_values).logits[0].cpu().numpy()
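                # logits are per-frame CTC scores over the vocabulary; the branches below decode
                # them either with KenLM-weighted beam search or greedily (argmax per frame,
                # with repeats and blank tokens collapsed by processor.batch_decode).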
if mode == "With N-gram LM":
decoded_ngram = decoder.decode_beams(logits, prune_history=True)
st.session_state.predicted_text = decoded_ngram[0][0]
st.markdown("### ๐Ÿ”Š Transcription with N-gram Language Model")
st.success(st.session_state.predicted_text)
else:
predicted_ids = torch.argmax(torch.tensor(logits), dim=-1)
st.session_state.predicted_text = processor.batch_decode(predicted_ids.unsqueeze(0))[0]
st.markdown("### ๐ŸŽค Transcription without N-gram Language Model")
st.info(st.session_state.predicted_text)
else:
                audio_input_data, _ = librosa.load(st.session_state.audio_path, sr=16000)
                input_features = processor(
                    audio_input_data, sampling_rate=16000, return_tensors="pt"
                ).input_features
                # Match the model's device and dtype (the fine-tuned model may be fp16 on GPU)
                input_features = input_features.to(model.device, dtype=model.dtype)
                predicted_ids = model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
transcription = capitalize_sentences(transcription)
st.session_state.predicted_text = transcription
st.markdown("### ๐Ÿ”Š Predicted Transcription")
st.success(st.session_state.predicted_text)
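            # WER = (substitutions + deletions + insertions) / number of reference words;
            # both strings are lowercased first so capitalization differences are not penalized.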
if st.session_state.ground_truth:
st.session_state.wer_value = wer(
st.session_state.ground_truth.lower(),
st.session_state.predicted_text.lower()
)
st.markdown("### ๐Ÿงฎ Word Error Rate (WER)")
st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`")
except Exception as e:
st.error(f"โŒ Transcription failed: {str(e)}")
finally:
# Clean up temporary files
if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
os.unlink(st.session_state.audio_path)
st.session_state.audio_bytes = None
            st.session_state.audio_path = None
st.session_state.predicted_text = ""
st.session_state.ground_truth = ""
st.session_state.wer_value = None
end_time = time.time()
duration = end_time - start_time
st.caption(f"๐Ÿ•’ Time taken: {duration:.2f}s")