Update app.py
Browse files
app.py
CHANGED
@@ -16,7 +16,7 @@ import time
|
|
16 |
import re
|
17 |
|
18 |
# Constants
|
19 |
-
WHISPER_FINETUNED = "wy0909/
|
20 |
WHISPER_PRETRAINED = "openai/whisper-medium"
|
21 |
WAV2VEC_MODEL = "mesolitica/wav2vec2-xls-r-300m-mixed"
|
22 |
MAX_RECORDING_SECONDS = 12
|
@@ -193,11 +193,15 @@ else:
|
|
193 |
|
194 |
@st.cache_resource
|
195 |
def load_finetuned_model_and_processor():
|
196 |
-
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_FINETUNED)
|
|
|
197 |
processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
|
198 |
model.config.forced_decoder_ids = None
|
199 |
model.generation_config.forced_decoder_ids = None
|
|
|
200 |
model.config.suppress_tokens = []
|
|
|
|
|
201 |
return model, processor
|
202 |
|
203 |
@st.cache_resource
|
@@ -206,6 +210,7 @@ else:
|
|
206 |
processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
|
207 |
model.config.forced_decoder_ids = None
|
208 |
model.generation_config.forced_decoder_ids = None
|
|
|
209 |
model.config.suppress_tokens = []
|
210 |
return model, processor
|
211 |
|
@@ -256,7 +261,7 @@ if st.button("๐ Transcribe"):
|
|
256 |
st.session_state.predicted_text.lower()
|
257 |
)
|
258 |
st.markdown("### ๐งฎ Word Error Rate (WER)")
|
259 |
-
st.write(f"WER: `{st.session_state.wer_value:.2f}
|
260 |
|
261 |
except Exception as e:
|
262 |
st.error(f"โ Transcription failed: {str(e)}")
|
@@ -274,4 +279,4 @@ if st.button("๐ Transcribe"):
|
|
274 |
|
275 |
end_time = time.time()
|
276 |
duration = end_time - start_time
|
277 |
-
st.caption(f"๐ Time taken: {duration:.2f}s")
|
|
|
16 |
import re
|
17 |
|
18 |
# Constants
|
19 |
+
WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
|
20 |
WHISPER_PRETRAINED = "openai/whisper-medium"
|
21 |
WAV2VEC_MODEL = "mesolitica/wav2vec2-xls-r-300m-mixed"
|
22 |
MAX_RECORDING_SECONDS = 12
|
|
|
193 |
|
194 |
@st.cache_resource
def load_finetuned_model_and_processor():
    """Load the fine-tuned Whisper model and processor (cached by Streamlit).

    Returns:
        tuple: ``(model, processor)`` — a ``WhisperForConditionalGeneration``
        and its ``WhisperProcessor`` for ``WHISPER_FINETUNED``. On a CUDA
        machine the model is loaded in float16 and moved to the GPU;
        otherwise it stays in float32 on CPU.
    """
    import importlib.util  # local import: only needed for the flash-attn probe

    use_cuda = torch.cuda.is_available()
    # flash_attention_2 requires the optional `flash-attn` package; selecting
    # it on CUDA availability alone raises ImportError at load time on GPU
    # hosts without that wheel. Probe for the module and fall back to the
    # default attention implementation when it is absent.
    flash_ok = use_cuda and importlib.util.find_spec("flash_attn") is not None
    model = WhisperForConditionalGeneration.from_pretrained(
        WHISPER_FINETUNED,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        attn_implementation="flash_attention_2" if flash_ok else None,
    )
    processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
    # Clear forced decoder ids so generation is not pinned to a fixed
    # language/task pair baked into the checkpoint.
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    # NOTE(review): None reverts use_cache to the library default rather than
    # disabling the KV cache — confirm whether False was intended here.
    model.config.use_cache = None
    model.config.suppress_tokens = []
    if use_cuda:
        model = model.to("cuda")
    return model, processor
|
206 |
|
207 |
@st.cache_resource
|
|
|
210 |
processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
|
211 |
model.config.forced_decoder_ids = None
|
212 |
model.generation_config.forced_decoder_ids = None
|
213 |
+
model.config.use_cache = None
|
214 |
model.config.suppress_tokens = []
|
215 |
return model, processor
|
216 |
|
|
|
261 |
st.session_state.predicted_text.lower()
|
262 |
)
|
263 |
st.markdown("### ๐งฎ Word Error Rate (WER)")
|
264 |
+
st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`")
|
265 |
|
266 |
except Exception as e:
|
267 |
st.error(f"โ Transcription failed: {str(e)}")
|
|
|
279 |
|
280 |
end_time = time.time()
|
281 |
duration = end_time - start_time
|
282 |
+
st.caption(f"๐ Time taken: {duration:.2f}s")
|