ashantharosary committed on
Commit
d60cb2b
·
verified ·
1 Parent(s): 57f60e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -16,7 +16,7 @@ import time
16
  import re
17
 
18
  # Constants
19
- WHISPER_FINETUNED = "wy0909/Whisper-MixedLanguageModel"
20
  WHISPER_PRETRAINED = "openai/whisper-medium"
21
  WAV2VEC_MODEL = "mesolitica/wav2vec2-xls-r-300m-mixed"
22
  MAX_RECORDING_SECONDS = 12
@@ -193,11 +193,15 @@ else:
193
 
194
  @st.cache_resource
195
  def load_finetuned_model_and_processor():
196
- model = WhisperForConditionalGeneration.from_pretrained(WHISPER_FINETUNED)
 
197
  processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
198
  model.config.forced_decoder_ids = None
199
  model.generation_config.forced_decoder_ids = None
 
200
  model.config.suppress_tokens = []
 
 
201
  return model, processor
202
 
203
  @st.cache_resource
@@ -206,6 +210,7 @@ else:
206
  processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
207
  model.config.forced_decoder_ids = None
208
  model.generation_config.forced_decoder_ids = None
 
209
  model.config.suppress_tokens = []
210
  return model, processor
211
 
@@ -256,7 +261,7 @@ if st.button("📝 Transcribe"):
256
  st.session_state.predicted_text.lower()
257
  )
258
  st.markdown("### 🧮 Word Error Rate (WER)")
259
- st.write(f"WER: `{st.session_state.wer_value:.2f}`")
260
 
261
  except Exception as e:
262
  st.error(f"❌ Transcription failed: {str(e)}")
@@ -274,4 +279,4 @@ if st.button("📝 Transcribe"):
274
 
275
  end_time = time.time()
276
  duration = end_time - start_time
277
- st.caption(f"🕒 Time taken: {duration:.2f}s")
 
16
  import re
17
 
18
  # Constants
19
+ WHISPER_FINETUNED = "wy0909/whisper-medium_mixedLanguageModel"
20
  WHISPER_PRETRAINED = "openai/whisper-medium"
21
  WAV2VEC_MODEL = "mesolitica/wav2vec2-xls-r-300m-mixed"
22
  MAX_RECORDING_SECONDS = 12
 
193
 
194
  @st.cache_resource
195
  def load_finetuned_model_and_processor():
196
+ model = WhisperForConditionalGeneration.from_pretrained(WHISPER_FINETUNED,torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
197
+ attn_implementation="flash_attention_2" if torch.cuda.is_available() else None)
198
  processor = WhisperProcessor.from_pretrained(WHISPER_FINETUNED)
199
  model.config.forced_decoder_ids = None
200
  model.generation_config.forced_decoder_ids = None
201
+ model.config.use_cache = None
202
  model.config.suppress_tokens = []
203
+ if torch.cuda.is_available():
204
+ model = model.to("cuda")
205
  return model, processor
206
 
207
  @st.cache_resource
 
210
  processor = WhisperProcessor.from_pretrained(WHISPER_PRETRAINED)
211
  model.config.forced_decoder_ids = None
212
  model.generation_config.forced_decoder_ids = None
213
+ model.config.use_cache = None
214
  model.config.suppress_tokens = []
215
  return model, processor
216
 
 
261
  st.session_state.predicted_text.lower()
262
  )
263
  st.markdown("### 🧮 Word Error Rate (WER)")
264
+ st.write(f"WER: `{st.session_state.wer_value * 100:.2f}%`")
265
 
266
  except Exception as e:
267
  st.error(f"❌ Transcription failed: {str(e)}")
 
279
 
280
  end_time = time.time()
281
  duration = end_time - start_time
282
+ st.caption(f"🕒 Time taken: {duration:.2f}s")