import streamlit as st
import torch
import librosa
import numpy as np
import tempfile
from transformers import AutoModelForCTC, Wav2Vec2Processor, WhisperProcessor, WhisperForConditionalGeneration
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download
from jiwer import wer
import json
import gzip
import shutil
import os
from pydub import AudioSegment
import time
import re
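
# Hugging Face model checkpoints and the recording-length limit for the recorder tab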
WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
WHISPER_PRETRAINED = "openai/whisper-medium"
WAV2VEC_MODEL = "mesolitica/wav2vec2-xls-r-300m-mixed"
MAX_RECORDING_SECONDS = 12


def capitalize_sentences(text):
    """Capitalize the first letter of every sentence in the transcription."""
    sentences = re.split(r'(?<=[.!?]) +', text)
    capitalized = [s.strip().capitalize() for s in sentences]
    return ' '.join(capitalized)


st.title("🎙️ Bahasa Rojak Speech-to-Text")
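
# Sidebar control for choosing between the Whisper and Wav2Vec2 pipelines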
st.sidebar.title("Model Configuration")
model_type = st.sidebar.selectbox(
    "Select Model Type",
    ["Whisper", "Wav2Vec2"],
    index=0,
    help="Choose between Whisper or Wav2Vec2 models"
)
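
# Initialise session state so audio, transcription results and WER persist across reruns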
if "audio_bytes" not in st.session_state: |
|
st.session_state.audio_bytes = None |
|
if "audio_path" not in st.session_state: |
|
st.session_state.audio_path = None |
|
if "ground_truth" not in st.session_state: |
|
st.session_state.ground_truth = "" |
|
if "predicted_text" not in st.session_state: |
|
st.session_state.predicted_text = "" |
|
if "wer_value" not in st.session_state: |
|
st.session_state.wer_value = None |
|
if "selected_tab" not in st.session_state: |
|
st.session_state.selected_tab = "๐ Upload Audio" |
|
if "previous_tab" not in st.session_state: |
|
st.session_state.previous_tab = "๐ Upload Audio" |
|
|
|
|
|

tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])
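
# Reset stored audio and results when the selected tab changes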
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab
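
# Tab 1: upload an existing .wav or .mp3 file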
with tab1:
    uploaded_file = st.file_uploader("Upload a .wav or .mp3 file", type=["wav", "mp3"])
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name

            # Convert MP3 uploads to WAV so both models receive the same format
            if uploaded_file.name.endswith(".mp3"):
                audio = AudioSegment.from_mp3(st.session_state.audio_path)
                wav_path = st.session_state.audio_path.replace(".mp3", ".wav")
                audio.export(wav_path, format="wav")
                os.unlink(st.session_state.audio_path)
                st.session_state.audio_path = wav_path

            # Validate that the file is decodable before offering playback
            librosa.load(st.session_state.audio_path, sr=16000)
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {str(e)}")
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None
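
# Tab 2: record audio in the browser, capped at MAX_RECORDING_SECONDS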
with tab2:
    st.session_state.selected_tab = "🎤 Record Audio"
    st.caption(f"Click the microphone below to start recording (max {MAX_RECORDING_SECONDS} seconds)")

    audio_input = st.audio_input("🎙️ Record Audio")

    if audio_input:
        try:
            audio_bytes = audio_input.read() if hasattr(audio_input, 'read') else audio_input.getvalue()

            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(audio_bytes)
                temp_path = tmp.name

            # Enforce the recording-length limit before accepting the clip
            audio_segment = AudioSegment.from_file(temp_path)
            duration_seconds = len(audio_segment) / 1000
            if duration_seconds > MAX_RECORDING_SECONDS:
                st.error(f"❌ Recording too long! Please keep it under {MAX_RECORDING_SECONDS} seconds.")
                os.unlink(temp_path)
            else:
                st.session_state.audio_bytes = audio_bytes
                st.session_state.audio_path = temp_path

                # Validate that the recording is decodable
                librosa.load(st.session_state.audio_path, sr=16000)
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {str(e)}")
            if 'temp_path' in locals() and os.path.exists(temp_path):
                os.unlink(temp_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None

st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input"
)
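
# Build the selected pipeline; st.cache_resource keeps the models and decoder loaded across reruns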
if model_type == "Wav2Vec2": |
|
|
|
@st.cache_resource |
|
def load_model(): |
|
processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_MODEL) |
|
model = AutoModelForCTC.from_pretrained(WAV2VEC_MODEL) |
|
model.eval() |
|
return processor, model |
|
|
|
@st.cache_resource |
|
def load_decoder(): |
|
vocab_path = hf_hub_download( |
|
repo_id="ashantharosary/wav2vec2-ngram-finetuned", |
|
filename="vocab.json", |
|
repo_type="model" |
|
) |
|
|
|
with open(vocab_path, "r") as f: |
|
vocab_dict = json.load(f) |
|
sorted_vocab = sorted(vocab_dict.items(), key=lambda item: item[1]) |
|
vocab_list = [k.lower() for k, v in sorted_vocab] |
|
|
|
arpa_gz_path = hf_hub_download( |
|
repo_id="ashantharosary/wav2vec2-ngram-finetuned", |
|
filename="4gram.arpa.gz", |
|
repo_type="model" |
|
) |
|
|
|
arpa_path = "4gram.arpa" |
|
if not os.path.exists(arpa_path): |
|
with gzip.open(arpa_gz_path, 'rb') as f_in: |
|
with open(arpa_path, 'wb') as f_out: |
|
shutil.copyfileobj(f_in, f_out) |
|
|
|
decoder = build_ctcdecoder(vocab_list, kenlm_model_path=arpa_path, alpha=0.2, beta=1.0) |
|
return decoder |
|
|
|
processor, model = load_model() |
|
decoder = load_decoder() |
|
|
|

    mode = st.selectbox(
        "Choose transcription method:",
        ["Without N-gram LM", "With N-gram LM"],
        help="Select whether to use language model decoding"
    )
else:
    model_choice = st.selectbox(
        "Select Whisper Model",
        options=["Fine-tuned Model", "Pretrained Whisper-Medium Model"],
        help="Choose the Whisper model to transcribe the audio"
    )

    @st.cache_resource
    def load_finetuned_model_and_processor():
        model = WhisperForConditionalGeneration.from_pretrained(
            WHISPER_FINETUNED,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
        )
        processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
        model.config.forced_decoder_ids = None
        model.generation_config.forced_decoder_ids = None
        model.config.use_cache = None
        model.config.suppress_tokens = []
        if torch.cuda.is_available():
            model = model.to("cuda")
        return model, processor

    @st.cache_resource
    def load_pretrained_model_and_processor():
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_PRETRAINED)
        processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
        model.config.forced_decoder_ids = None
        model.generation_config.forced_decoder_ids = None
        model.config.use_cache = None
        model.config.suppress_tokens = []
        return model, processor

    if model_choice == "Fine-tuned Model":
        model, processor = load_finetuned_model_and_processor()
    else:
        model, processor = load_pretrained_model_and_processor()
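
# Transcribe the stored audio and, if a ground truth was provided, report WER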
if st.button("📝 Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("❌ Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            if model_type == "Wav2Vec2":
                audio_input, _ = librosa.load(st.session_state.audio_path, sr=16000)
                input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values
                with torch.no_grad():
                    logits = model(input_values).logits[0].cpu().numpy()

                if mode == "With N-gram LM":
                    # Beam-search decoding of the CTC logits with the 4-gram KenLM model
                    decoded_ngram = decoder.decode_beams(logits, prune_history=True)
                    st.session_state.predicted_text = decoded_ngram[0][0]
                    st.markdown("### 📜 Transcription with N-gram Language Model")
                    st.success(st.session_state.predicted_text)
                else:
                    # Greedy (argmax) CTC decoding without a language model
                    predicted_ids = torch.argmax(torch.tensor(logits), dim=-1)
                    st.session_state.predicted_text = processor.batch_decode(predicted_ids.unsqueeze(0))[0]
                    st.markdown("### 🤖 Transcription without N-gram Language Model")
                    st.info(st.session_state.predicted_text)
            else:
                audio_input_data, _ = librosa.load(st.session_state.audio_path, sr=16000)
                input_features = processor(
                    audio_input_data, sampling_rate=16000, return_tensors="pt"
                ).input_features
                # Match the model's device and dtype (the fine-tuned model may run in fp16 on GPU)
                input_features = input_features.to(model.device, dtype=model.dtype)

                predicted_ids = model.generate(input_features)
                transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                transcription = capitalize_sentences(transcription)
                st.session_state.predicted_text = transcription
                st.markdown("### 📝 Predicted Transcription")
                st.success(st.session_state.predicted_text)

            if st.session_state.ground_truth:
                st.session_state.wer_value = wer(
                    st.session_state.ground_truth.lower(),
                    st.session_state.predicted_text.lower()
                )
                st.markdown("### 🧮 Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`")
        except Exception as e:
            st.error(f"❌ Transcription failed: {str(e)}")
        finally:
            # Delete the temporary audio file and reset state for the next run
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None
            st.session_state.predicted_text = ""
            st.session_state.ground_truth = ""
            st.session_state.wer_value = None

        end_time = time.time()
        duration = end_time - start_time
        st.caption(f"🕒 Time taken: {duration:.2f}s")