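# Streamlit demo: Bahasa Rojak (mixed Malay-English) speech-to-text with
# Whisper and Wav2Vec2 backends, optional KenLM n-gram rescoring, and WER
# evaluation against an optional ground-truth transcript.
# To run locally (assuming this file is saved as app.py): streamlit run app.py
# Note: pydub additionally needs ffmpeg on the PATH for mp3 handling.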
import streamlit as st
import torch
import librosa
import numpy as np
import tempfile
from transformers import AutoModelForCTC, Wav2Vec2Processor, WhisperProcessor, WhisperForConditionalGeneration
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download
from jiwer import wer
import json
import gzip
import shutil
import os
from pydub import AudioSegment
import time
import re

# Constants
WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
WHISPER_PRETRAINED = "openai/whisper-medium"
WAV2VEC_MODEL = "mesolitica/wav2vec2-xls-r-300m-mixed"
MAX_RECORDING_SECONDS = 12

def capitalize_sentences(text):
    sentences = re.split(r'(?<=[.!?]) +', text)
    capitalized = [s.strip().capitalize() for s in sentences]
    return ' '.join(capitalized)
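
# Illustrative behavior: capitalize_sentences("saya okay. macam mana?")
# returns "Saya okay. Macam mana?"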

# Main title
st.title("๐ŸŽ™๏ธ Bahasa Rojak Speech-to-Text")

# Sidebar configuration
st.sidebar.title("Model Configuration")
model_type = st.sidebar.selectbox(
    "Select Model Type",
    ["Whisper", "Wav2Vec2"],
    index=0,
    help="Choose between Whisper or Wav2Vec2 models"
)

# Session state initialization
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "๐Ÿ“ Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "๐Ÿ“ Upload Audio"

# Tab Selection
tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])

# Reset state if tab is changed
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab

# Tab 1: Upload Audio
with tab1:
    uploaded_file = st.file_uploader("Upload a .wav or .mp3 file", type=["wav", "mp3"])
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name

            if uploaded_file.name.endswith(".mp3"):
                audio = AudioSegment.from_mp3(st.session_state.audio_path)
                wav_path = st.session_state.audio_path.replace(".mp3", ".wav")
                audio.export(wav_path, format="wav")
                os.unlink(st.session_state.audio_path)
                st.session_state.audio_path = wav_path

            # Validate that the file decodes at 16 kHz before accepting it
            librosa.load(st.session_state.audio_path, sr=16000)
            audio_format = "audio/mp3" if uploaded_file.name.endswith(".mp3") else "audio/wav"
            st.audio(st.session_state.audio_bytes, format=audio_format)
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {str(e)}")
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None

# Tab 2: Record Audio
with tab2:
    # NOTE: st.tabs runs every tab body on each rerun and does not expose the
    # active tab, so selected_tab must not be overwritten unconditionally here;
    # doing so would spuriously trigger the reset logic above.
    st.caption(f"Click the microphone below to start recording (max {MAX_RECORDING_SECONDS} seconds)")
    
    audio_input = st.audio_input("🎙️ Record Audio")
    
    if audio_input:
        try:
            # Get the audio bytes in the correct format
            audio_bytes = audio_input.read() if hasattr(audio_input, 'read') else audio_input.getvalue()
            
            # Save to temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(audio_bytes)
                temp_path = tmp.name
            
            # Check duration
            audio_segment = AudioSegment.from_file(temp_path)
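            # len(AudioSegment) is the clip duration in milliseconds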
            duration_seconds = len(audio_segment) / 1000
            
            if duration_seconds > MAX_RECORDING_SECONDS:
                st.error(f"โŒ Recording too long! Please keep it under {MAX_RECORDING_SECONDS} seconds.")
                os.unlink(temp_path)
            else:
                # Store in session state
                st.session_state.audio_bytes = audio_bytes
                st.session_state.audio_path = temp_path
                
                # Validate that the recording decodes at 16 kHz, then play it back
                librosa.load(st.session_state.audio_path, sr=16000)
                st.audio(st.session_state.audio_bytes, format="audio/wav")
                
        except Exception as e:
            st.error(f"โŒ Failed to process recorded audio: {str(e)}")
            if 'temp_path' in locals() and os.path.exists(temp_path):
                os.unlink(temp_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None

# Input ground truth for WER
st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input"
)

# Model-specific configurations
if model_type == "Wav2Vec2":
    # Wav2Vec2 configuration
    @st.cache_resource
    def load_model():
        processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_MODEL)
        model = AutoModelForCTC.from_pretrained(WAV2VEC_MODEL)
        model.eval()
        return processor, model

    @st.cache_resource
    def load_decoder():
        vocab_path = hf_hub_download(
            repo_id="ashantharosary/wav2vec2-ngram-finetuned",
            filename="vocab.json",
            repo_type="model"
        )

        with open(vocab_path, "r") as f:
            vocab_dict = json.load(f)
        sorted_vocab = sorted(vocab_dict.items(), key=lambda item: item[1])
        vocab_list = [k.lower() for k, v in sorted_vocab]
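        # The vocab is lowercased, presumably to match the casing of the 4-gram
        # ARPA LM; decoder labels and LM tokens need consistent casing to score.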

        arpa_gz_path = hf_hub_download(
            repo_id="ashantharosary/wav2vec2-ngram-finetuned",
            filename="4gram.arpa.gz",
            repo_type="model"
        )

        arpa_path = "4gram.arpa"
        if not os.path.exists(arpa_path):
            with gzip.open(arpa_gz_path, 'rb') as f_in:
                with open(arpa_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)

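        # alpha scales the KenLM language-model score and beta is the
        # word-insertion bonus in pyctcdecode's shallow fusion; 0.2/1.0 are
        # this app's chosen weights, not the library defaults.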
        decoder = build_ctcdecoder(vocab_list, kenlm_model_path=arpa_path, alpha=0.2, beta=1.0)
        return decoder

    processor, model = load_model()
    decoder = load_decoder()

    mode = st.selectbox(
        "Choose transcription method:",
        ["Without N-gram LM", "With N-gram LM"],
        help="Select whether to use language model decoding"
    )
else:
    # Whisper configuration
    model_choice = st.selectbox(
        "Select Whisper Model",
        options=["Fine-tuned Model", "Pretrained Whisper-Medium Model"],
        help="Choose the Whisper model to transcribe the audio"
    )

    @st.cache_resource
    def load_finetuned_model_and_processor():
        # Use fp16 and FlashAttention only when a GPU is available; the
        # "flash_attention_2" path also requires the flash-attn package.
        model = WhisperForConditionalGeneration.from_pretrained(
            WHISPER_FINETUNED,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
        )
        processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
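        # Clearing forced_decoder_ids lets generate() auto-detect the language
        # instead of forcing a fixed language/task token pair.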
        model.config.forced_decoder_ids = None
        model.generation_config.forced_decoder_ids = None
        model.config.use_cache = None
        model.config.suppress_tokens = []
        if torch.cuda.is_available():
            model = model.to("cuda")
        return model, processor

    @st.cache_resource
    def load_pretrained_model_and_processor():
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_PRETRAINED)
        processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
        model.config.forced_decoder_ids = None
        model.generation_config.forced_decoder_ids = None
        model.config.use_cache = None
        model.config.suppress_tokens = []
        return model, processor

    if model_choice == "Fine-tuned Model":
        model, processor = load_finetuned_model_and_processor()
    else:
        model, processor = load_pretrained_model_and_processor()

# Transcription Button
if st.button("๐Ÿ“ Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("โŒ Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            if model_type == "Wav2Vec2":
                audio_input, _ = librosa.load(st.session_state.audio_path, sr=16000)
                input_values = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values
                with torch.no_grad():
                    logits = model(input_values).logits[0].cpu().numpy()

                if mode == "With N-gram LM":
                    decoded_ngram = decoder.decode_beams(logits, prune_history=True)
                    st.session_state.predicted_text = decoded_ngram[0][0]
                    st.markdown("### ๐Ÿ”Š Transcription with N-gram Language Model")
                    st.success(st.session_state.predicted_text)
                else:
                    # Greedy CTC decoding: pick the argmax token at each frame
                    predicted_ids = torch.from_numpy(logits).argmax(dim=-1)
                    st.session_state.predicted_text = processor.batch_decode(predicted_ids.unsqueeze(0))[0]
                    st.markdown("### 🎤 Transcription without N-gram Language Model")
                    st.info(st.session_state.predicted_text)
            else:
                audio_input_data, _ = librosa.load(st.session_state.audio_path, sr=16000)
                input_features = processor(
                    audio_input_data, sampling_rate=16000, return_tensors="pt"
                ).input_features

                # Move inputs onto the model's device and dtype (the fine-tuned
                # model may be fp16 on CUDA)
                input_features = input_features.to(model.device, model.dtype)
                predicted_ids = model.generate(input_features)
                transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                transcription = capitalize_sentences(transcription)
                st.session_state.predicted_text = transcription
                st.markdown("### ๐Ÿ”Š Predicted Transcription")
                st.success(st.session_state.predicted_text)

            if st.session_state.ground_truth:
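                # jiwer.wer = (substitutions + insertions + deletions) over the
                # reference word count, e.g. wer("a b c", "a x c") == 1/3.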
                st.session_state.wer_value = wer(
                    st.session_state.ground_truth.lower(),
                    st.session_state.predicted_text.lower()
                )
                st.markdown("### ๐Ÿงฎ Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`")

        except Exception as e:
            st.error(f"โŒ Transcription failed: {str(e)}")

        finally:
            # Clean up temporary files
            if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
                os.unlink(st.session_state.audio_path)
            st.session_state.audio_bytes = None
            st.session_state.audio_path = None
            st.session_state.predicted_text = ""
            st.session_state.ground_truth = ""
            st.session_state.wer_value = None
 
            end_time = time.time()
            duration = end_time - start_time
            st.caption(f"๐Ÿ•’ Time taken: {duration:.2f}s")