ashantharosary committed on
Commit e8d2580 · verified · 1 Parent(s): 2ec3f86

Update app.py

Files changed (1): app.py +180 -120
app.py CHANGED
@@ -3,144 +3,204 @@ import torch
  import librosa
  import numpy as np
  import tempfile
- from transformers import AutoModelForCTC, Wav2Vec2Processor, WhisperForConditionalGeneration, WhisperProcessor
- from pyctcdecode import build_ctcdecoder
- from huggingface_hub import hf_hub_download
  from jiwer import wer
- import json
- import gzip
- import shutil
  import os
  from pydub import AudioSegment
- import torchaudio
- import re
  import time

- st.set_page_config(page_title="Rojak STT", layout="centered")
- st.title("🎙️ Bahasa Rojak Malaysia Speech-to-Text")
-
- # Sidebar: Model selector
- model_choice = st.sidebar.selectbox("Choose Model", ["wav2vec2", "whisper"])
-
- # Session State
- for key in ["audio_bytes", "audio_path", "ground_truth", "wer_value", "predicted_text"]:
-     if key not in st.session_state:
-         st.session_state[key] = None if key in ["audio_bytes", "audio_path", "wer_value"] else ""

- # Tabs
  tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])

- # Tab 1: Upload
  with tab1:
-     uploaded_file = st.file_uploader("Upload .wav or .mp3", type=["wav", "mp3", "flac", "m4a", "ogg"])
      if uploaded_file:
-         st.session_state.audio_bytes = uploaded_file.read()
-         ext = uploaded_file.name.split(".")[-1]
-         with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as tmp:
-             tmp.write(st.session_state.audio_bytes)
-             st.session_state.audio_path = tmp.name
-
-         if ext == "mp3":
-             audio = AudioSegment.from_mp3(st.session_state.audio_path)
-             wav_path = st.session_state.audio_path.replace(".mp3", ".wav")
-             audio.export(wav_path, format="wav")
-             st.session_state.audio_path = wav_path
-
-         librosa.load(st.session_state.audio_path, sr=16000)
-         st.audio(st.session_state.audio_bytes, format="audio/wav")
-
- # Tab 2: Record
  with tab2:
-     audio_input = st.audio_input("🎤 Record your audio")
      if audio_input:
-         st.session_state.audio_bytes = audio_input.getvalue()
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
-             tmp.write(st.session_state.audio_bytes)
-             st.session_state.audio_path = tmp.name
-         librosa.load(st.session_state.audio_path, sr=16000)
-         st.audio(st.session_state.audio_bytes, format="audio/wav")
-
- # Clear state if no audio
- if not st.session_state.audio_bytes:
-     st.session_state["ground_truth"] = ""
-     st.session_state["predicted_text"] = ""
-     st.session_state["wer_value"] = None
-
- # Ground truth input
- st.session_state["ground_truth"] = st.text_input("Optional: Enter ground truth", value=st.session_state["ground_truth"])
-
- # ---- Loaders ----

  @st.cache_resource
- def load_wav2vec2_model():
-     processor = Wav2Vec2Processor.from_pretrained("mesolitica/wav2vec2-xls-r-300m-mixed")
-     model = AutoModelForCTC.from_pretrained("mesolitica/wav2vec2-xls-r-300m-mixed")
-     model.eval()
-     return processor, model
-
- @st.cache_resource
- def load_decoder():
-     vocab_path = hf_hub_download("ashantharosary/wav2vec2-ngram-finetuned", "vocab.json", repo_type="model")
-     with open(vocab_path, "r") as f:
-         vocab_dict = json.load(f)
-     vocab_list = [k.lower() for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])]
-     arpa_gz_path = hf_hub_download("ashantharosary/wav2vec2-ngram-finetuned", "4gram.arpa.gz", repo_type="model")
-     arpa_path = "4gram.arpa"
-     if not os.path.exists(arpa_path):
-         with gzip.open(arpa_gz_path, 'rb') as f_in, open(arpa_path, 'wb') as f_out:
-             shutil.copyfileobj(f_in, f_out)
-     return build_ctcdecoder(vocab_list, kenlm_model_path=arpa_path, alpha=0.2, beta=1.0)

  @st.cache_resource
- def load_whisper_model():
-     model = WhisperForConditionalGeneration.from_pretrained("wy0909/Whisper-MixedLanguageModel")
-     processor = WhisperProcessor.from_pretrained("wy0909/Whisper-MixedLanguageModel")
      model.config.forced_decoder_ids = None
      model.generation_config.forced_decoder_ids = None
      model.config.suppress_tokens = []
      return model, processor

- # ---- Transcription ----
-
- def capitalize_sentences(text):
-     sentences = re.split(r'(?<=[.!?]) +', text)
-     return ' '.join([s.strip().capitalize() for s in sentences])
-
- if st.button("📝 Transcribe", disabled=not st.session_state.audio_bytes):
-     start_time = time.time()
-     try:
-         if model_choice == "wav2vec2":
-             processor, model = load_wav2vec2_model()
-             decoder = load_decoder()
-             audio, _ = librosa.load(st.session_state.audio_path, sr=16000)
-             input_values = processor(audio, return_tensors="pt", sampling_rate=16000).input_values
-             with torch.no_grad():
-                 logits = model(input_values).logits[0].cpu().numpy()
-             decoded_ngram = decoder.decode_beams(logits, prune_history=True)
-             text = decoded_ngram[0][0]
-             st.markdown("### 🧠 Transcription (Wav2Vec2 + LM)")
-             st.success(text)
-
-         else:  # whisper
-             model, processor = load_whisper_model()
-             waveform, sr = torchaudio.load(st.session_state.audio_path)
-             waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
-             inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
-             with torch.no_grad():
-                 predicted_ids = model.generate(inputs["input_features"])
-             text = capitalize_sentences(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
-             st.markdown("### 🎧 Transcription (Whisper)")
-             st.success(text)
-
-         st.session_state["predicted_text"] = text
-
-         if st.session_state["ground_truth"]:
-             error = wer(st.session_state["ground_truth"].lower(), text.lower())
-             st.session_state["wer_value"] = error
-             st.markdown("### 🧮 Word Error Rate (WER)")
-             st.write(f"WER: `{error:.2f}`")
-
-     except Exception as e:
-         st.error(f"❌ Transcription failed: {str(e)}")
-
-     st.caption(f"🕒 Time taken: {time.time() - start_time:.2f}s")

  import librosa
  import numpy as np
  import tempfile
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
  from jiwer import wer
  import os
  from pydub import AudioSegment
  import time
+ import re
+
+ # Constants
+ WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
+ WHISPER_PRETRAINED = "openai/whisper-medium"
+ MAX_RECORDING_SECONDS = 12
+
+ def capitalize_sentences(text):
+     sentences = re.split(r'(?<=[.!?]) +', text)
+     capitalized = [s.strip().capitalize() for s in sentences]
+     return ' '.join(capitalized)
+
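+ # e.g. capitalize_sentences("hello there. how are you?") -> "Hello there. How are you?"
+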
+ # Main title
+ st.title("🎙️ Speech-to-Text with Whisper")
+
+ # Session state initialization
+ if "audio_bytes" not in st.session_state:
+     st.session_state.audio_bytes = None
+ if "audio_path" not in st.session_state:
+     st.session_state.audio_path = None
+ if "ground_truth" not in st.session_state:
+     st.session_state.ground_truth = ""
+ if "predicted_text" not in st.session_state:
+     st.session_state.predicted_text = ""
+ if "wer_value" not in st.session_state:
+     st.session_state.wer_value = None
+ if "selected_tab" not in st.session_state:
+     st.session_state.selected_tab = "📁 Upload Audio"
+ if "previous_tab" not in st.session_state:
+     st.session_state.previous_tab = "📁 Upload Audio"
+
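+ # Streamlit reruns the whole script on every interaction, so anything that
+ # must survive a rerun is kept in st.session_state.
+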
+ # Tab Selection
  tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])

+ # Reset state if tab is changed
+ if st.session_state.selected_tab != st.session_state.previous_tab:
+     st.session_state.audio_bytes = None
+     st.session_state.audio_path = None
+     st.session_state.ground_truth = ""
+     st.session_state.predicted_text = ""
+     st.session_state.wer_value = None
+     st.session_state.previous_tab = st.session_state.selected_tab
+
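+ # NOTE: st.tabs renders both tab bodies on every rerun, and selected_tab is
+ # only written inside tab2 below, so this reset is best-effort rather than a
+ # true "active tab changed" signal.
+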
+ # Tab 1: Upload Audio
  with tab1:
+     uploaded_file = st.file_uploader("Upload a .wav or .mp3 file", type=["wav", "mp3"])
      if uploaded_file:
+         try:
+             st.session_state.audio_bytes = uploaded_file.read()
+             with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
+                 tmp.write(st.session_state.audio_bytes)
+                 st.session_state.audio_path = tmp.name
+
+             if uploaded_file.name.endswith(".mp3"):
+                 audio = AudioSegment.from_mp3(st.session_state.audio_path)
+                 wav_path = st.session_state.audio_path.replace(".mp3", ".wav")
+                 audio.export(wav_path, format="wav")
+                 os.unlink(st.session_state.audio_path)
+                 st.session_state.audio_path = wav_path
+
+             librosa.load(st.session_state.audio_path, sr=16000)
+             st.audio(st.session_state.audio_bytes, format="audio/wav")
+         except Exception as e:
+             st.error(f"❌ Failed to read audio file: {str(e)}")
+             if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
+                 os.unlink(st.session_state.audio_path)
+             st.session_state.audio_bytes = None
+
+ # Tab 2: Record Audio
  with tab2:
+     st.session_state.selected_tab = "🎤 Record Audio"
+     st.caption(f"Click the microphone below to start recording (max {MAX_RECORDING_SECONDS} seconds)")
+
+     audio_input = st.audio_input("🎙️ Record Audio")
+
      if audio_input:
+         try:
+             # Get the audio bytes in the correct format
+             audio_bytes = audio_input.read() if hasattr(audio_input, 'read') else audio_input.getvalue()
+
+             # Save to a temporary file
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+                 tmp.write(audio_bytes)
+                 temp_path = tmp.name
+
+             # Check duration
+             audio_segment = AudioSegment.from_file(temp_path)
+             duration_seconds = len(audio_segment) / 1000
+
+             if duration_seconds > MAX_RECORDING_SECONDS:
+                 st.error(f"❌ Recording too long! Please keep it under {MAX_RECORDING_SECONDS} seconds.")
+                 os.unlink(temp_path)
+             else:
+                 # Store in session state
+                 st.session_state.audio_bytes = audio_bytes
+                 st.session_state.audio_path = temp_path
+
+                 # Validate that the recording is decodable
+                 librosa.load(st.session_state.audio_path, sr=16000)
+
+         except Exception as e:
+             st.error(f"❌ Failed to process recorded audio: {str(e)}")
+             if 'temp_path' in locals() and os.path.exists(temp_path):
+                 os.unlink(temp_path)
+             st.session_state.audio_bytes = None
+             st.session_state.audio_path = None
+
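+ # NOTE: len(audio_segment) is in milliseconds, hence the /1000 in the
+ # duration check above.
+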
+ # Input ground truth for WER
+ st.session_state.ground_truth = st.text_input(
+     "Enter ground truth for WER calculation (Optional)",
+     value=st.session_state.ground_truth,
+     key="ground_truth_input"
+ )
+
+ # Whisper configuration
+ model_choice = st.selectbox(
+     "Select Whisper Model",
+     options=["Fine-tuned Model", "Pretrained Whisper-Medium Model"],
+     help="Choose the Whisper model to transcribe the audio"
+ )

  @st.cache_resource
+ def load_finetuned_model_and_processor():
+     model = WhisperForConditionalGeneration.from_pretrained(
+         WHISPER_FINETUNED,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
+     )
+     processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
+     model.config.forced_decoder_ids = None
+     model.generation_config.forced_decoder_ids = None
+     model.config.use_cache = None
+     model.config.suppress_tokens = []
+     if torch.cuda.is_available():
+         model = model.to("cuda")
+     return model, processor

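+ # NOTE: "flash_attention_2" assumes the flash-attn package is installed in the
+ # GPU environment; on CPU-only hardware the None branch keeps the default
+ # attention implementation.
+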
  @st.cache_resource
+ def load_pretrained_model_and_processor():
+     model = WhisperForConditionalGeneration.from_pretrained(WHISPER_PRETRAINED)
+     processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
      model.config.forced_decoder_ids = None
      model.generation_config.forced_decoder_ids = None
+     model.config.use_cache = None
      model.config.suppress_tokens = []
      return model, processor

+ if model_choice == "Fine-tuned Model":
+     model, processor = load_finetuned_model_and_processor()
+ else:
+     model, processor = load_pretrained_model_and_processor()
+
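+ # Both loaders are wrapped in @st.cache_resource, so each model is downloaded
+ # and instantiated once per process and then reused across reruns.
+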
+ # Transcription Button
+ if st.button("📝 Transcribe"):
+     if not st.session_state.audio_bytes:
+         st.error("❌ Please upload or record an audio file first.")
+     else:
+         start_time = time.time()
+         try:
+             audio_input_data, _ = librosa.load(st.session_state.audio_path, sr=16000)
+             input_features = processor(
+                 audio_input_data, sampling_rate=16000, return_tensors="pt"
+             ).input_features
+             # Match the model's device and dtype (the fine-tuned model may be
+             # fp16 on CUDA, while the processor returns fp32 CPU tensors)
+             input_features = input_features.to(model.device, dtype=model.dtype)
+
+             predicted_ids = model.generate(input_features)
+             transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+             transcription = capitalize_sentences(transcription)
+             st.session_state.predicted_text = transcription
+             st.markdown("### 🔊 Predicted Transcription")
+             st.success(st.session_state.predicted_text)
+
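+             # jiwer computes WER = (substitutions + deletions + insertions) /
+             # words in the reference; lowercasing both strings makes the
+             # comparison case-insensitive.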
+             if st.session_state.ground_truth:
+                 st.session_state.wer_value = wer(
+                     st.session_state.ground_truth.lower(),
+                     st.session_state.predicted_text.lower()
+                 )
+                 st.markdown("### 🧮 Word Error Rate (WER)")
+                 st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`")
+
+         except Exception as e:
+             st.error(f"❌ Transcription failed: {str(e)}")
+
+         finally:
+             # Clean up temporary files and reset state for the next run
+             if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
+                 os.unlink(st.session_state.audio_path)
+             st.session_state.audio_bytes = None
+             st.session_state.audio_path = None
+             st.session_state.predicted_text = ""
+             st.session_state.ground_truth = ""
+             st.session_state.wer_value = None
+
+         end_time = time.time()
+         duration = end_time - start_time
+         st.caption(f"🕒 Time taken: {duration:.2f}s")
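+
+ # Example: wer("makan apa tu", "makan apa itu") == 1/3, shown as "WER: `33.33%`".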