msekoyan committed on
Commit
7ca74f4
·
verified ·
1 Parent(s): ae5ab93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -16
app.py CHANGED
@@ -3,11 +3,11 @@ import torch
3
  import gradio as gr
4
  import spaces
5
  import gc
 
6
  from pathlib import Path
7
  from pydub import AudioSegment
8
  import numpy as np
9
  import os
10
- import tempfile
11
  import gradio.themes as gr_themes
12
  import csv
13
 
@@ -17,6 +17,24 @@ MODEL_NAME="nvidia/parakeet-tdt-0.6b-v2"
17
  model = ASRModel.from_pretrained(model_name=MODEL_NAME)
18
  model.eval()
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def get_audio_segment(audio_path, start_second, end_second):
21
  if not audio_path or not Path(audio_path).exists():
22
  print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
@@ -55,7 +73,7 @@ def get_audio_segment(audio_path, start_second, end_second):
55
  return None
56
 
57
  @spaces.GPU
58
- def get_transcripts_and_raw_times(audio_path):
59
  if not audio_path:
60
  gr.Error("No audio file path provided for transcription.", duration=None)
61
  # Return an update to hide the button
@@ -64,9 +82,9 @@ def get_transcripts_and_raw_times(audio_path):
64
  vis_data = [["N/A", "N/A", "Processing failed"]]
65
  raw_times_data = [[0.0, 0.0]]
66
  processed_audio_path = None
67
- temp_file = None
68
  csv_file_path = None
69
  original_path_name = Path(audio_path).name
 
70
 
71
  try:
72
  try:
@@ -105,16 +123,14 @@ def get_transcripts_and_raw_times(audio_path):
105
 
106
  if resampled or mono:
107
  try:
108
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
109
- audio.export(temp_file.name, format="wav")
110
- processed_audio_path = temp_file.name
111
- temp_file.close()
112
- transcribe_path = processed_audio_path
113
  info_path_name = f"{original_path_name} (processed)"
114
  except Exception as export_e:
115
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
116
- if temp_file and hasattr(temp_file, 'name') and os.path.exists(temp_file.name): # Check temp_file has 'name' attribute
117
- os.remove(temp_file.name)
118
  # Return an update to hide the button
119
  return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
120
  else:
@@ -139,12 +155,10 @@ def get_transcripts_and_raw_times(audio_path):
139
  # Default button update (hidden) in case CSV writing fails
140
  button_update = gr.DownloadButton(visible=False)
141
  try:
142
- temp_csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='', encoding='utf-8')
143
- writer = csv.writer(temp_csv_file)
144
  writer.writerow(csv_headers)
145
  writer.writerows(vis_data)
146
- csv_file_path = temp_csv_file.name
147
- temp_csv_file.close()
148
  print(f"CSV transcript saved to temporary file: {csv_file_path}")
149
  # If CSV is saved, create update to show button with path
150
  button_update = gr.DownloadButton(value=csv_file_path, visible=True)
@@ -285,6 +299,9 @@ with gr.Blocks(theme=nvidia_theme) as demo:
285
  current_audio_path_state = gr.State(None)
286
  raw_timestamps_list_state = gr.State([])
287
 
 
 
 
288
  with gr.Tabs():
289
  with gr.TabItem("Audio File"):
290
  file_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio File")
@@ -313,14 +330,14 @@ with gr.Blocks(theme=nvidia_theme) as demo:
313
 
314
  mic_transcribe_btn.click(
315
  fn=get_transcripts_and_raw_times,
316
- inputs=[mic_input],
317
  outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
318
  api_name="transcribe_mic"
319
  )
320
 
321
  file_transcribe_btn.click(
322
  fn=get_transcripts_and_raw_times,
323
- inputs=[file_input],
324
  outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
325
  api_name="transcribe_file"
326
  )
@@ -331,6 +348,8 @@ with gr.Blocks(theme=nvidia_theme) as demo:
331
  outputs=[selected_segment_player],
332
  )
333
 
 
 
334
  if __name__ == "__main__":
335
  print("Launching Gradio Demo...")
336
  demo.queue()
 
3
  import gradio as gr
4
  import spaces
5
  import gc
6
+ import shutil
7
  from pathlib import Path
8
  from pydub import AudioSegment
9
  import numpy as np
10
  import os
 
11
  import gradio.themes as gr_themes
12
  import csv
13
 
 
17
  model = ASRModel.from_pretrained(model_name=MODEL_NAME)
18
  model.eval()
19
 
20
+
21
def start_session(request: gr.Request):
    """Create the scratch directory for the current session and return its path.

    Args:
        request: Gradio request object; only ``request.session_hash`` is read.

    Returns:
        The directory ``/tmp/<session_hash>`` as a POSIX-style string.

    Raises:
        ValueError: If the session hash is not a single plain path component
            (empty, ``.``, ``..``, or containing a path separator).
    """
    session_hash = request.session_hash

    # The hash is interpolated straight into a filesystem path, so anything
    # other than a single path component could point outside the intended
    # per-session directory.
    # NOTE(review): presumably the hash originates from the client - confirm.
    hash_text = str(session_hash)
    if hash_text in ("", ".", "..") or Path(hash_text).name != hash_text:
        raise ValueError(f"Invalid session hash: {session_hash!r}")

    session_dir = Path(f'/tmp/{session_hash}')
    session_dir.mkdir(parents=True, exist_ok=True)

    print(f"Session with hash {session_hash} started.")
    return session_dir.as_posix()
28
+
29
def end_session(request: gr.Request):
    """Remove the scratch directory belonging to the current session.

    Args:
        request: Gradio request object; only ``request.session_hash`` is read.

    Returns:
        None. If the session hash is not a single plain path component,
        nothing is deleted and a warning is printed instead.
    """
    session_hash = request.session_hash

    # Guard the recursive delete: an empty, "." or ".."-style hash (or one
    # containing a path separator) would make the path below resolve to /tmp
    # itself or to a directory outside it.
    hash_text = str(session_hash)
    if hash_text in ("", ".", "..") or Path(hash_text).name != hash_text:
        print(f"Skipping cleanup for invalid session hash {session_hash!r}.")
        return

    session_dir = Path(f'/tmp/{session_hash}')

    if session_dir.exists():
        shutil.rmtree(session_dir)

    print(f"Session with hash {session_hash} ended.")
37
+
38
  def get_audio_segment(audio_path, start_second, end_second):
39
  if not audio_path or not Path(audio_path).exists():
40
  print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
 
73
  return None
74
 
75
  @spaces.GPU
76
+ def get_transcripts_and_raw_times(audio_path, session_dir):
77
  if not audio_path:
78
  gr.Error("No audio file path provided for transcription.", duration=None)
79
  # Return an update to hide the button
 
82
  vis_data = [["N/A", "N/A", "Processing failed"]]
83
  raw_times_data = [[0.0, 0.0]]
84
  processed_audio_path = None
 
85
  csv_file_path = None
86
  original_path_name = Path(audio_path).name
87
+ audio_name = Path(audio_path).stem
88
 
89
  try:
90
  try:
 
123
 
124
  if resampled or mono:
125
  try:
126
+ processed_audio_path = Path(session_dir, f"{audio_name}_resampled.wav")
127
+ audio.export(processed_audio_path, format="wav")
128
+ transcribe_path = processed_audio_path.as_posix()
 
 
129
  info_path_name = f"{original_path_name} (processed)"
130
  except Exception as export_e:
131
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
132
+ if processed_audio_path and os.path.exists(processed_audio_path):
133
+ os.remove(processed_audio_path)
134
  # Return an update to hide the button
135
  return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
136
  else:
 
155
  # Default button update (hidden) in case CSV writing fails
156
  button_update = gr.DownloadButton(visible=False)
157
  try:
158
+ csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
159
+ writer = csv.writer(open(csv_file_path, 'w'))
160
  writer.writerow(csv_headers)
161
  writer.writerows(vis_data)
 
 
162
  print(f"CSV transcript saved to temporary file: {csv_file_path}")
163
  # If CSV is saved, create update to show button with path
164
  button_update = gr.DownloadButton(value=csv_file_path, visible=True)
 
299
  current_audio_path_state = gr.State(None)
300
  raw_timestamps_list_state = gr.State([])
301
 
302
+ session_dir = gr.State()
303
+ demo.load(start_session, outputs=[session_dir])
304
+
305
  with gr.Tabs():
306
  with gr.TabItem("Audio File"):
307
  file_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio File")
 
330
 
331
  mic_transcribe_btn.click(
332
  fn=get_transcripts_and_raw_times,
333
+ inputs=[mic_input, session_dir],
334
  outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
335
  api_name="transcribe_mic"
336
  )
337
 
338
  file_transcribe_btn.click(
339
  fn=get_transcripts_and_raw_times,
340
+ inputs=[file_input, session_dir],
341
  outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
342
  api_name="transcribe_file"
343
  )
 
348
  outputs=[selected_segment_player],
349
  )
350
 
351
+ demo.unload(end_session)
352
+
353
  if __name__ == "__main__":
354
  print("Launching Gradio Demo...")
355
  demo.queue()