Nithin Rao Koluguri commited on
Commit
575a6d7
·
1 Parent(s): f4054f1

add parakeet-v2

Browse files

Signed-off-by: Nithin Rao Koluguri <nithinraok>

Files changed (5) hide show
  1. .gitattributes +1 -0
  2. README.md +4 -4
  3. app.py +315 -117
  4. pre-requirements.txt +0 -1
  5. requirements.txt +2 -2
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Parakeet Rnnt 1.1b
3
- emoji: 🦀
4
  colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.12.0
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-4.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Parakeet TDT 1.1b
3
+ emoji: "\_🦜"
4
  colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.27.1
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-4.0
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,139 +1,337 @@
1
  from nemo.collections.asr.models import ASRModel
2
- import yt_dlp as youtube_dl
3
- import os
4
- import tempfile
5
  import torch
6
  import gradio as gr
 
 
 
7
  from pydub import AudioSegment
 
 
 
 
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- MODEL_NAME="nvidia/parakeet-tdt-1.1b"
11
- YT_LENGTH_LIMIT_S=3600
12
 
13
- model = ASRModel.from_pretrained(model_name=MODEL_NAME).to(device)
14
  model.eval()
15
 
16
- def get_transcripts(audio_path):
17
- text = model.transcribe([audio_path])[0][0]
18
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  article = (
21
- "<p style='text-align: center'>"
22
- "<a href='https://huggingface.co/nvidia/parakeet-tdt-1.1b' target='_blank'>🎙️ Learn more about Parakeet TDT model</a> | "
23
- "<a href='https://arxiv.org/abs/2304.06795' target='_blank'>📚 TDT ICML paper</a> | "
24
- "<a href='https://github.com/NVIDIA/NeMo' target='_blank'>🧑‍💻 Repository</a>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  "</p>"
26
  )
 
27
  examples = [
28
- ["data/conversation.wav"],
29
- ["data/id10270_5r0dWxy17C8-00001.wav"],
30
  ]
31
 
32
- def _return_yt_html_embed(yt_url):
33
- video_id = yt_url.split("?v=")[-1]
34
- HTML_str = (
35
- f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
36
- " </center>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  )
38
- return HTML_str
39
 
40
- def download_yt_audio(yt_url, filename):
41
- info_loader = youtube_dl.YoutubeDL()
42
-
43
- try:
44
- info = info_loader.extract_info(yt_url, download=False)
45
- except youtube_dl.utils.DownloadError as err:
46
- raise gr.Error(str(err))
47
-
48
- file_length = info["duration_string"]
49
- file_h_m_s = file_length.split(":")
50
- file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
51
-
52
- if len(file_h_m_s) == 1:
53
- file_h_m_s.insert(0, 0)
54
- if len(file_h_m_s) == 2:
55
- file_h_m_s.insert(0, 0)
56
- file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
57
-
58
- if file_length_s > YT_LENGTH_LIMIT_S:
59
- yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
60
- file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
61
- raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
62
-
63
- ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
64
-
65
- with youtube_dl.YoutubeDL(ydl_opts) as ydl:
66
- try:
67
- ydl.download([yt_url])
68
- except youtube_dl.utils.ExtractorError as err:
69
- raise gr.Error(str(err))
70
-
71
-
72
- def yt_transcribe(yt_url, max_filesize=75.0):
73
- html_embed_str = _return_yt_html_embed(yt_url)
74
-
75
- with tempfile.TemporaryDirectory() as tmpdirname:
76
- filepath = os.path.join(tmpdirname, "video.mp4")
77
- download_yt_audio(yt_url, filepath)
78
- audio = AudioSegment.from_file(filepath)
79
- wav_filepath = os.path.join(tmpdirname, "audio.wav")
80
- audio.export(wav_filepath, format="wav")
81
-
82
- text = get_transcripts(wav_filepath)
83
- return html_embed_str, text
84
-
85
-
86
- demo = gr.Blocks()
87
-
88
- mf_transcribe = gr.Interface(
89
- fn=get_transcripts,
90
- inputs=[
91
- gr.Audio(sources="microphone", type="filepath")
92
- ],
93
- outputs="text",
94
- theme="huggingface",
95
- title="Parakeet TDT 1.1B: Transcribe Audio",
96
- description=(
97
- "Transcribe microphone or audio inputs with the click of a button! Demo uses the"
98
- f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) to transcribe audio files"
99
- " of arbitrary length. TDT models are 75% more efficient than similar size RNNT model"
100
- ),
101
- allow_flagging="never",
102
- )
103
 
104
- file_transcribe = gr.Interface(
105
- fn=get_transcripts,
106
- inputs=[
107
- gr.Audio(sources="upload", type="filepath", label="Audio file"),
108
- ],
109
- outputs="text",
110
- theme="huggingface",
111
- title="Parakeet TDT 1.1B: Transcribe Audio",
112
- description=(
113
- "Transcribe microphone or audio inputs with the click of a button! Demo uses the"
114
- f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) to transcribe audio files"
115
- " of arbitrary length. TDT models are 75% more efficient than similar size RNNT model"
116
- ),
117
- allow_flagging="never",
118
- )
119
 
120
- youtube_transcribe = gr.Interface(
121
- fn=yt_transcribe,
122
- inputs=[
123
- gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
124
- ],
125
- outputs=["html", "text"],
126
- theme="huggingface",
127
- title="Parakeet TDT 1.1B: Transcribe Audio",
128
- description=(
129
- "Transcribe microphone or audio inputs with the click of a button! Demo uses the"
130
- f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) to transcribe audio files"
131
- " of arbitrary length. TDT models are 75% more efficient than similar size RNNT model"
132
- ),
133
- allow_flagging="never",
134
- )
135
 
136
- with demo:
137
- gr.TabbedInterface([mf_transcribe, file_transcribe], ["Microphone", "Audio file"])
 
 
 
138
 
139
- demo.launch()
 
 
 
 
1
  from nemo.collections.asr.models import ASRModel
 
 
 
2
  import torch
3
  import gradio as gr
4
+ import spaces
5
+ import gc
6
+ from pathlib import Path
7
  from pydub import AudioSegment
8
+ import numpy as np
9
+ import os
10
+ import tempfile
11
+ import gradio.themes as gr_themes
12
+ import csv
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ MODEL_NAME="nvidia/parakeet-tdt-0.6b-v2"
 
16
 
17
+ model = ASRModel.from_pretrained(model_name=MODEL_NAME)
18
  model.eval()
19
 
20
+ def get_audio_segment(audio_path, start_second, end_second):
21
+ if not audio_path or not Path(audio_path).exists():
22
+ print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
23
+ return None
24
+ try:
25
+ start_ms = int(start_second * 1000)
26
+ end_ms = int(end_second * 1000)
27
+
28
+ start_ms = max(0, start_ms)
29
+ if end_ms <= start_ms:
30
+ print(f"Warning: End time ({end_second}s) is not after start time ({start_second}s). Adjusting end time.")
31
+ end_ms = start_ms + 100
32
+
33
+ audio = AudioSegment.from_file(audio_path)
34
+ clipped_audio = audio[start_ms:end_ms]
35
+
36
+ samples = np.array(clipped_audio.get_array_of_samples())
37
+ if clipped_audio.channels == 2:
38
+ samples = samples.reshape((-1, 2)).mean(axis=1).astype(samples.dtype)
39
+
40
+ frame_rate = clipped_audio.frame_rate
41
+ if frame_rate <= 0:
42
+ print(f"Warning: Invalid frame rate ({frame_rate}) detected for clipped audio.")
43
+ frame_rate = audio.frame_rate
44
+
45
+ if samples.size == 0:
46
+ print(f"Warning: Clipped audio resulted in empty samples array ({start_second}s to {end_second}s).")
47
+ return None
48
+
49
+ return (frame_rate, samples)
50
+ except FileNotFoundError:
51
+ print(f"Error: Audio file not found at path: {audio_path}")
52
+ return None
53
+ except Exception as e:
54
+ print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
55
+ return None
56
+
57
+ @spaces.GPU
58
+ def get_transcripts_and_raw_times(audio_path):
59
+ if not audio_path:
60
+ gr.Error("No audio file path provided for transcription.", duration=None)
61
+ # Return an update to hide the button
62
+ return [], [], None, gr.DownloadButton(visible=False)
63
+
64
+ vis_data = [["N/A", "N/A", "Processing failed"]]
65
+ raw_times_data = [[0.0, 0.0]]
66
+ processed_audio_path = None
67
+ temp_file = None
68
+ csv_file_path = None
69
+ original_path_name = Path(audio_path).name
70
+
71
+ try:
72
+ try:
73
+ gr.Info(f"Loading audio: {original_path_name}", duration=2)
74
+ audio = AudioSegment.from_file(audio_path)
75
+ except Exception as load_e:
76
+ gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
77
+ # Return an update to hide the button
78
+ return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
79
+
80
+ resampled = False
81
+ mono = False
82
+
83
+ target_sr = 16000
84
+ if audio.frame_rate != target_sr:
85
+ try:
86
+ audio = audio.set_frame_rate(target_sr)
87
+ resampled = True
88
+ except Exception as resample_e:
89
+ gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
90
+ # Return an update to hide the button
91
+ return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
92
+
93
+ if audio.channels == 2:
94
+ try:
95
+ audio = audio.set_channels(1)
96
+ mono = True
97
+ except Exception as mono_e:
98
+ gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
99
+ # Return an update to hide the button
100
+ return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
101
+ elif audio.channels > 2:
102
+ gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
103
+ # Return an update to hide the button
104
+ return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
105
+
106
+ if resampled or mono:
107
+ try:
108
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
109
+ audio.export(temp_file.name, format="wav")
110
+ processed_audio_path = temp_file.name
111
+ temp_file.close()
112
+ transcribe_path = processed_audio_path
113
+ info_path_name = f"{original_path_name} (processed)"
114
+ except Exception as export_e:
115
+ gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
116
+ if temp_file and hasattr(temp_file, 'name') and os.path.exists(temp_file.name): # Check temp_file has 'name' attribute
117
+ os.remove(temp_file.name)
118
+ # Return an update to hide the button
119
+ return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
120
+ else:
121
+ transcribe_path = audio_path
122
+ info_path_name = original_path_name
123
+
124
+ try:
125
+ model.to(device)
126
+ gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)
127
+ output = model.transcribe([transcribe_path], timestamps=True)
128
+
129
+ if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
130
+ gr.Error("Transcription failed or produced unexpected output format.", duration=None)
131
+ # Return an update to hide the button
132
+ return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
133
+
134
+ segment_timestamps = output[0].timestamp['segment']
135
+ csv_headers = ["Start (s)", "End (s)", "Segment"]
136
+ vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
137
+ raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
138
+
139
+ # Default button update (hidden) in case CSV writing fails
140
+ button_update = gr.DownloadButton(visible=False)
141
+ try:
142
+ temp_csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='', encoding='utf-8')
143
+ writer = csv.writer(temp_csv_file)
144
+ writer.writerow(csv_headers)
145
+ writer.writerows(vis_data)
146
+ csv_file_path = temp_csv_file.name
147
+ temp_csv_file.close()
148
+ print(f"CSV transcript saved to temporary file: {csv_file_path}")
149
+ # If CSV is saved, create update to show button with path
150
+ button_update = gr.DownloadButton(value=csv_file_path, visible=True)
151
+ except Exception as csv_e:
152
+ gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
153
+ print(f"Error writing CSV: {csv_e}")
154
+ # csv_file_path remains None, button_update remains hidden
155
+
156
+ gr.Info("Transcription complete.", duration=2)
157
+ # Return the data and the button update dictionary
158
+ return vis_data, raw_times_data, audio_path, button_update
159
+
160
+ except torch.cuda.OutOfMemoryError as e:
161
+ error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
162
+ print(f"CUDA OutOfMemoryError: {e}")
163
+ gr.Error(error_msg, duration=None)
164
+ # Return an update to hide the button
165
+ return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
166
+
167
+ except FileNotFoundError:
168
+ error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
169
+ print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
170
+ gr.Error(error_msg, duration=None)
171
+ # Return an update to hide the button
172
+ return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
173
+
174
+ except Exception as e:
175
+ error_msg = f"Transcription failed: {e}"
176
+ print(f"Error during transcription processing: {e}")
177
+ gr.Error(error_msg, duration=None)
178
+ vis_data = [["Error", "Error", error_msg]]
179
+ raw_times_data = [[0.0, 0.0]]
180
+ # Return an update to hide the button
181
+ return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
182
+ finally:
183
+ try:
184
+ if 'model' in locals() and hasattr(model, 'cpu'):
185
+ if device == 'cuda':
186
+ model.cpu()
187
+ gc.collect()
188
+ if device == 'cuda':
189
+ torch.cuda.empty_cache()
190
+ except Exception as cleanup_e:
191
+ print(f"Error during model cleanup: {cleanup_e}")
192
+ gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
193
+
194
+ finally:
195
+ if processed_audio_path and os.path.exists(processed_audio_path):
196
+ try:
197
+ os.remove(processed_audio_path)
198
+ print(f"Temporary audio file {processed_audio_path} removed.")
199
+ except Exception as e:
200
+ print(f"Error removing temporary audio file {processed_audio_path}: {e}")
201
+
202
+ def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
203
+ if not isinstance(raw_ts_list, list):
204
+ print(f"Warning: raw_ts_list is not a list ({type(raw_ts_list)}). Cannot play segment.")
205
+ return gr.Audio(value=None, label="Selected Segment")
206
+
207
+ if not current_audio_path:
208
+ print("No audio path available to play segment from.")
209
+ return gr.Audio(value=None, label="Selected Segment")
210
+
211
+ selected_index = evt.index[0]
212
+
213
+ if selected_index < 0 or selected_index >= len(raw_ts_list):
214
+ print(f"Invalid index {selected_index} selected for list of length {len(raw_ts_list)}.")
215
+ return gr.Audio(value=None, label="Selected Segment")
216
+
217
+ if not isinstance(raw_ts_list[selected_index], (list, tuple)) or len(raw_ts_list[selected_index]) != 2:
218
+ print(f"Warning: Data at index {selected_index} is not in the expected format [start, end].")
219
+ return gr.Audio(value=None, label="Selected Segment")
220
+
221
+ start_time_s, end_time_s = raw_ts_list[selected_index]
222
+
223
+ print(f"Attempting to play segment: {current_audio_path} from {start_time_s:.2f}s to {end_time_s:.2f}s")
224
+
225
+ segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s)
226
+
227
+ if segment_data:
228
+ print("Segment data retrieved successfully.")
229
+ return gr.Audio(value=segment_data, autoplay=True, label=f"Segment: {start_time_s:.2f}s - {end_time_s:.2f}s", interactive=False)
230
+ else:
231
+ print("Failed to get audio segment data.")
232
+ return gr.Audio(value=None, label="Selected Segment")
233
 
234
  article = (
235
+ "<p style='font-size: 1.1em;'>"
236
+ "This demo showcases <code><a href='https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2'>parakeet-tdt-0.6b-v2</a></code>, a 600-million-parameter model designed for high-quality English speech recognition."
237
+ "</p>"
238
+ "<p><strong style='color: red; font-size: 1.2em;'>Key Features:</strong></p>"
239
+ "<ul style='font-size: 1.1em;'>"
240
+ " <li>Automatic punctuation and capitalization</li>"
241
+ " <li>Accurate word-level timestamps (click on a segment in the table below to play it!)</li>"
242
+ " <li>Efficiently transcribes long audio segments (up to 20 minutes) <small>(For even longer audios, see <a href='https://github.com/NVIDIA/NeMo/blob/main/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py' target='_blank'>this script</a>)</small></li>"
243
+ " <li>Robust performance on spoken numbers, music, and songs</li>"
244
+ "</ul>"
245
+ "<p style='font-size: 1.1em;'>"
246
+ "This model is <strong>available for commercial and non-commercial use</strong>."
247
+ "</p>"
248
+ "<p style='text-align: center;'>"
249
+ "<a href='https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2' target='_blank'>🎙️ Learn more about the Model</a> | "
250
+ "<a href='https://arxiv.org/abs/2305.05084' target='_blank'>📄 Fast Conformer paper</a> | "
251
+ "<a href='https://arxiv.org/abs/2304.06795' target='_blank'>📚 TDT paper</a> | "
252
+ "<a href='https://github.com/NVIDIA/NeMo' target='_blank'>🧑‍💻 NeMo Repository</a>"
253
  "</p>"
254
  )
255
+
256
  examples = [
257
+ ["data/example-yt_saTD1u8PorI.mp3"],
 
258
  ]
259
 
260
+ # Define an NVIDIA-inspired theme
261
+ nvidia_theme = gr_themes.Default(
262
+ primary_hue=gr_themes.Color(
263
+ c50="#E6F1D9", # Lightest green
264
+ c100="#CEE3B3",
265
+ c200="#B5D58C",
266
+ c300="#9CC766",
267
+ c400="#84B940",
268
+ c500="#76B900", # NVIDIA Green
269
+ c600="#68A600",
270
+ c700="#5A9200",
271
+ c800="#4C7E00",
272
+ c900="#3E6A00", # Darkest green
273
+ c950="#2F5600"
274
+ ),
275
+ neutral_hue="gray", # Use gray for neutral elements
276
+ font=[gr_themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
277
+ ).set()
278
+
279
+ # Apply the custom theme
280
+ with gr.Blocks(theme=nvidia_theme) as demo:
281
+ model_display_name = MODEL_NAME.split('/')[-1] if '/' in MODEL_NAME else MODEL_NAME
282
+ gr.Markdown(f"<h1 style='text-align: center; margin: 0 auto;'>Speech Transcription with {model_display_name}</h1>")
283
+ gr.HTML(article)
284
+
285
+ current_audio_path_state = gr.State(None)
286
+ raw_timestamps_list_state = gr.State([])
287
+
288
+ with gr.Tabs():
289
+ with gr.TabItem("Audio File"):
290
+ file_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio File")
291
+ gr.Examples(examples=examples, inputs=[file_input], label="Example Audio Files (Click to Load)")
292
+ file_transcribe_btn = gr.Button("Transcribe Uploaded File", variant="primary")
293
+
294
+ with gr.TabItem("Microphone"):
295
+ mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
296
+ mic_transcribe_btn = gr.Button("Transcribe Microphone Input", variant="primary")
297
+
298
+ gr.Markdown("---")
299
+ gr.Markdown("<p><strong style='color: #FF0000; font-size: 1.2em;'>Transcription Results (Click row to play segment)</strong></p>")
300
+
301
+ # Define the DownloadButton *before* the DataFrame
302
+ download_btn = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
303
+
304
+ vis_timestamps_df = gr.DataFrame(
305
+ headers=["Start (s)", "End (s)", "Segment"],
306
+ datatype=["number", "number", "str"],
307
+ wrap=True,
308
+ label="Transcription Segments"
309
  )
 
310
 
311
+ # selected_segment_player was defined after download_btn previously, keep it after df for layout
312
+ selected_segment_player = gr.Audio(label="Selected Segment", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
+ mic_transcribe_btn.click(
315
+ fn=get_transcripts_and_raw_times,
316
+ inputs=[mic_input],
317
+ outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
318
+ api_name="transcribe_mic"
319
+ )
 
 
 
 
 
 
 
 
 
320
 
321
+ file_transcribe_btn.click(
322
+ fn=get_transcripts_and_raw_times,
323
+ inputs=[file_input],
324
+ outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
325
+ api_name="transcribe_file"
326
+ )
 
 
 
 
 
 
 
 
 
327
 
328
+ vis_timestamps_df.select(
329
+ fn=play_segment,
330
+ inputs=[raw_timestamps_list_state, current_audio_path_state],
331
+ outputs=[selected_segment_player],
332
+ )
333
 
334
+ if __name__ == "__main__":
335
+ print("Launching Gradio Demo...")
336
+ demo.queue()
337
+ demo.launch()
pre-requirements.txt DELETED
@@ -1 +0,0 @@
1
- Cython
 
 
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  Cython
2
- nemo-toolkit[all]==1.22.0
3
- yt_dlp
 
1
  Cython
2
+ git+https://github.com/NVIDIA/NeMo.[email protected].0#egg=nemo_toolkit[asr]
3
+ numpy<2.0