tanettech committed
Commit 0907799 · 1 Parent(s): 8616c67

Add Gradio app files
.gitignore ADDED
@@ -0,0 +1,3 @@
+ venv_py311/
+ __pycache__/
+ *.pyc
README.md CHANGED
@@ -1,14 +1,12 @@
 ---
- title: Elvis Voice Assistant Demo
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
+ title: Voice Assistant Demo
+ emoji: 📊
+ colorFrom: blue
+ colorTo: pink
  sdk: gradio
- sdk_version: 5.0.1
+ sdk_version: 4.43.0
  app_file: app.py
  pinned: false
- license: mit
- short_description: Espnet Recipe
  ---
 
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,64 +1,986 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
  ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
  )
61
62
 
63
- if __name__ == "__main__":
64
- demo.launch()
 
1
+ try:
2
+ import versa
3
+ except ImportError:
4
+ from subprocess import call
5
+ with open('versa.sh', 'rb') as file:
6
+ script = file.read()
7
+ rc = call(script, shell=True)
8
+
9
+ import os
10
+ import shutil
11
+ import time
12
+ from typing import Generator, Optional, Tuple
13
+
14
  import gradio as gr
15
+ import nltk
16
+ import numpy as np
17
+ import torch
18
+ from huggingface_hub import HfApi
19
+ from pyscripts.utils.dialog_eval.ASR_WER import handle_espnet_ASR_WER
20
+ from pyscripts.utils.dialog_eval.human_feedback import (
21
+ natural_vote1_last_response,
22
+ natural_vote2_last_response,
23
+ natural_vote3_last_response,
24
+ natural_vote4_last_response,
25
+ relevant_vote1_last_response,
26
+ relevant_vote2_last_response,
27
+ relevant_vote3_last_response,
28
+ relevant_vote4_last_response,
29
+ )
30
+ from pyscripts.utils.dialog_eval.LLM_Metrics import (
31
+ DialoGPT_perplexity,
32
+ bert_score,
33
+ perplexity,
34
+ vert,
35
+ )
36
+ from pyscripts.utils.dialog_eval.TTS_intelligibility import (
37
+ handle_espnet_TTS_intelligibility,
38
+ )
39
+ from pyscripts.utils.dialog_eval.TTS_speech_quality import TTS_psuedomos
40
+
41
+ from espnet2.sds.espnet_model import ESPnetSDSModelInterface
42
+
43
+ # ------------------------
44
+ # Hyperparameters
45
+ # ------------------------
46
+
47
+ access_token = os.environ.get("HF_TOKEN")
48
+ ASR_name="pyf98/owsm_ctc_v3.1_1B"
49
+ LLM_name="meta-llama/Llama-3.2-1B-Instruct"
50
+ TTS_name="kan-bayashi/ljspeech_vits"
51
+ ASR_options="pyf98/owsm_ctc_v3.1_1B,espnet/owsm_ctc_v3.2_ft_1B,espnet/owsm_v3.1_ebf,librispeech_asr,whisper-large".split(",")
52
+ LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruct".split(",")
53
+ TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
54
+ Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics".split(",")
55
+ upload_to_hub=None
56
+ dialogue_model = ESPnetSDSModelInterface(
57
+ ASR_name, LLM_name, TTS_name, "Cascaded", access_token
58
+ )
59
+ ASR_curr_name=None
60
+ LLM_curr_name=None
61
+ TTS_curr_name=None
62
+
63
+ latency_ASR = 0.0
64
+ latency_LM = 0.0
65
+ latency_TTS = 0.0
66
 
67
+ text_str = ""
68
+ asr_output_str = ""
69
+ vad_output = None
70
+ audio_output = None
71
+ audio_output1 = None
72
+ LLM_response_arr = []
73
+ total_response_arr = []
74
+ callback = gr.CSVLogger()
75
+ start_record_time = None
76
+ enable_btn = gr.Button(interactive=True, visible=True)
77
 
78
+ # ------------------------
79
+ # Function Definitions
80
+ # ------------------------
81
 
82
+ def handle_eval_selection(
83
+ option: str,
84
+ TTS_audio_output: str,
85
+ LLM_Output: str,
86
+ ASR_audio_output: str,
87
+ ASR_transcript: str,
 
88
  ):
89
+ """
90
+ Handles the evaluation of a selected metric based on
91
+ user input and provided outputs.
92
+
93
+ This function evaluates different aspects of a
94
+ cascaded conversational AI pipeline, such as:
95
+ Latency, TTS intelligibility, TTS speech quality,
96
+ ASR WER, and text dialog metrics.
97
+ It is designed to integrate with Gradio via
98
+ multiple yield statements,
99
+ allowing updates to be displayed in real time.
100
+
101
+ Parameters:
102
+ ----------
103
+ option : str
104
+ The evaluation metric selected by the user.
105
+ Supported options include:
106
+ - "Latency"
107
+ - "TTS Intelligibility"
108
+ - "TTS Speech Quality"
109
+ - "ASR WER"
110
+ - "Text Dialog Metrics"
111
+ TTS_audio_output : np.ndarray
112
+ The audio output generated by the TTS module for evaluation.
113
+ LLM_Output : str
114
+ The text output generated by the LLM module for evaluation.
115
+ ASR_audio_output : np.ndarray
116
+ The audio input/output used for ASR evaluation.
117
+ ASR_transcript : str
118
+ The transcript generated by the ASR module for evaluation.
119
+
120
+ Returns:
121
+ -------
122
+ str
123
+ A string representation of the evaluation results.
124
+ The specific result depends on the selected evaluation metric:
125
+ - "Latency": Latencies of ASR, LLM, and TTS modules.
126
+ - "TTS Intelligibility": A range of scores indicating how intelligible
127
+ the TTS audio output is based on different reference ASR models.
128
+ - "TTS Speech Quality": A range of scores representing the
129
+ speech quality of the TTS audio output.
130
+ - "ASR WER": The Word Error Rate (WER) of the ASR output
131
+ based on different judge ASR models.
132
+ - "Text Dialog Metrics": A combination of perplexity,
133
+ diversity metrics, and relevance scores for the dialog.
134
+
135
+ Raises:
136
+ ------
137
+ ValueError
138
+ If the `option` parameter does not match any supported evaluation metric.
139
+
140
+ Example:
141
+ -------
142
+ >>> result = handle_eval_selection(
143
+ option="Latency",
144
+ TTS_audio_output=audio_array,
145
+ LLM_Output="Generated response",
146
+ ASR_audio_output=audio_input,
147
+ ASR_transcript="Expected transcript"
148
+ )
149
+ >>> print(result)
150
+ "ASR Latency: 0.14
151
+ LLM Latency: 0.42
152
+ TTS Latency: 0.21"
153
+ """
154
+ global LLM_response_arr
155
+ global total_response_arr
156
+ yield (option, gr.Textbox(visible=True))
157
+ if option == "Latency":
158
+ text = (
159
+ f"ASR Latency: {latency_ASR:.2f}\n"
160
+ f"LLM Latency: {latency_LM:.2f}\n"
161
+ f"TTS Latency: {latency_TTS:.2f}"
162
+ )
163
+ yield (None, text)
164
+ elif option == "TTS Intelligibility":
165
+ yield (None, handle_espnet_TTS_intelligibility(TTS_audio_output, LLM_Output))
166
+ elif option == "TTS Speech Quality":
167
+ yield (None, TTS_psuedomos(TTS_audio_output))
168
+ elif option == "ASR WER":
169
+ yield (None, handle_espnet_ASR_WER(ASR_audio_output, ASR_transcript))
170
+ elif option == "Text Dialog Metrics":
171
+ yield (
172
+ None,
173
+ perplexity(LLM_Output.replace("\n", " "))
174
+ + vert(LLM_response_arr)
175
+ + bert_score(total_response_arr)
176
+ + DialoGPT_perplexity(
177
+ ASR_transcript.replace("\n", " "), LLM_Output.replace("\n", " ")
178
+ ),
179
+ )
180
+ elif option is None:
181
+ return
182
+ else:
183
+ raise ValueError(f"Unknown option: {option}")
184
+
185
+
186
+ def handle_eval_selection_E2E(
187
+ option: str,
188
+ TTS_audio_output: str,
189
+ LLM_Output: str,
190
+ ):
191
+ """
192
+ Handles the evaluation of a selected metric based on user input
193
+ and provided outputs.
194
+
195
+ This function evaluates different aspects of an E2E
196
+ conversational AI model, such as:
197
+ Latency, TTS intelligibility, TTS speech quality, and
198
+ text dialog metrics.
199
+ It is designed to integrate with Gradio via
200
+ multiple yield statements,
201
+ allowing updates to be displayed in real time.
202
+
203
+ Parameters:
204
+ ----------
205
+ option : str
206
+ The evaluation metric selected by the user.
207
+ Supported options include:
208
+ - "Latency"
209
+ - "TTS Intelligibility"
210
+ - "TTS Speech Quality"
211
+ - "Text Dialog Metrics"
212
+ TTS_audio_output : np.ndarray
213
+ The audio output generated by the TTS module for evaluation.
214
+ LLM_Output : str
215
+ The text output generated by the LLM module for evaluation.
216
+
217
+ Returns:
218
+ -------
219
+ str
220
+ A string representation of the evaluation results.
221
+ The specific result depends on the selected evaluation metric:
222
+ - "Latency": Latency of the entire system.
223
+ - "TTS Intelligibility": A range of scores indicating how intelligible the
224
+ TTS audio output is based on different reference ASR models.
225
+ - "TTS Speech Quality": A range of scores representing the
226
+ speech quality of the TTS audio output.
227
+ - "Text Dialog Metrics": A combination of perplexity and
228
+ diversity metrics for the dialog.
229
+
230
+ Raises:
231
+ ------
232
+ ValueError
233
+ If the `option` parameter does not match any supported evaluation metric.
234
+
235
+ Example:
236
+ -------
237
+ >>> result = handle_eval_selection(
238
+ option="Latency",
239
+ TTS_audio_output=audio_array,
240
+ LLM_Output="Generated response",
241
+ )
242
+ >>> print(result)
243
+ "Total Latency: 2.34"
244
+ """
245
+ global LLM_response_arr
246
+ global total_response_arr
247
+ yield (option, gr.Textbox(visible=True))
248
+ if option == "Latency":
249
+ text = f"Total Latency: {latency_TTS:.2f}"
250
+ yield (None, text)
251
+ elif option == "TTS Intelligibility":
252
+ yield (None, handle_espnet_TTS_intelligibility(TTS_audio_output, LLM_Output))
253
+ elif option == "TTS Speech Quality":
254
+ yield (None, TTS_psuedomos(TTS_audio_output))
255
+ elif option == "Text Dialog Metrics":
256
+ yield (None, perplexity(LLM_Output.replace("\n", " ")) + vert(LLM_response_arr))
257
+ elif option is None:
258
+ return
259
+ else:
260
+ raise ValueError(f"Unknown option: {option}")
261
+
262
+
263
+ def start_warmup():
264
+ """
265
+ Initializes and warms up the dialogue and evaluation model.
266
+
267
+ This function is designed to ensure that all
268
+ components of the dialogue model are pre-loaded
269
+ and ready for execution, avoiding delays during runtime.
270
+ """
271
+ global dialogue_model
272
+ global ASR_options
273
+ global LLM_options
274
+ global TTS_options
275
+ global ASR_name
276
+ global LLM_name
277
+ global TTS_name
278
+ remove=0
279
+ for opt_count in range(len(ASR_options)):
280
+ opt_count-=remove
281
+ if opt_count>=len(ASR_options):
282
+ break
283
+ print(opt_count)
284
+ print(ASR_options)
285
+ opt = ASR_options[opt_count]
286
+ try:
287
+ for _ in dialogue_model.handle_ASR_selection(opt):
288
+ continue
289
+ except Exception:
290
+ print("Removing " + opt + " from ASR options since it cannot be loaded.")
291
+ ASR_options = ASR_options[:opt_count] + ASR_options[(opt_count + 1) :]
292
+ remove+=1
293
+ if opt == ASR_name:
294
+ ASR_name = ASR_options[0]
295
+ for opt_count in range(len(LLM_options)):
296
+ opt = LLM_options[opt_count]
297
+ try:
298
+ for _ in dialogue_model.handle_LLM_selection(opt):
299
+ continue
300
+ except Exception:
301
+ print("Removing " + opt + " from LLM options since it cannot be loaded.")
302
+ LLM_options = LLM_options[:opt_count] + LLM_options[(opt_count + 1) :]
303
+ if opt == LLM_name:
304
+ LLM_name = LLM_options[0]
305
+ for opt_count in range(len(TTS_options)):
306
+ opt = TTS_options[opt_count]
307
+ try:
308
+ for _ in dialogue_model.handle_TTS_selection(opt):
309
+ continue
310
+ except Exception:
311
+ print("Removing " + opt + " from TTS options since it cannot be loaded.")
312
+ TTS_options = TTS_options[:opt_count] + TTS_options[(opt_count + 1) :]
313
+ if opt == TTS_name:
314
+ TTS_name = TTS_options[0]
315
+ dialogue_model.handle_E2E_selection()
316
+ dialogue_model.client = None
317
+ for _ in dialogue_model.handle_TTS_selection(TTS_name):
318
+ continue
319
+ for _ in dialogue_model.handle_ASR_selection(ASR_name):
320
+ continue
321
+ for _ in dialogue_model.handle_LLM_selection(LLM_name):
322
+ continue
323
+ dummy_input = (
324
+ torch.randn(
325
+ (3000),
326
+ dtype=getattr(torch, "float16"),
327
+ device="cpu",
328
+ )
329
+ .cpu()
330
+ .numpy()
331
+ )
332
+ dummy_text = "This is dummy text"
333
+ for opt in Eval_options:
334
+ handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
335
+
336
+
337
+ def flash_buttons():
338
+ """
339
+ Enables human feedback buttons after displaying system output.
340
+ """
341
+ btn_updates = (enable_btn,) * 8
342
+ yield (
343
+ "",
344
+ "",
345
+ ) + btn_updates
346
+
347
+
348
+ def transcribe(
349
+ stream: np.ndarray,
350
+ new_chunk: Tuple[int, np.ndarray],
351
+ TTS_option: str,
352
+ ASR_option: str,
353
+ LLM_option: str,
354
+ type_option: str,
355
+ input_text: str,
356
+ ):
357
+ """
358
+ Processes and transcribes an audio stream in real-time.
359
+
360
+ This function handles the transcription of audio input
361
+ and its transformation through a cascaded
362
+ or E2E conversational AI system.
363
+ It dynamically updates the transcription, text generation,
364
+ and synthesized speech output, while managing global states and latencies.
365
+
366
+ Args:
367
+ stream: The current audio stream buffer.
368
+ `None` if the stream is being reset (e.g., after user refresh).
369
+ new_chunk: A tuple containing:
370
+ - `sr`: Sample rate of the new audio chunk.
371
+ - `y`: New audio data chunk.
372
+ TTS_option: Selected TTS model option.
373
+ ASR_option: Selected ASR model option.
374
+ LLM_option: Selected LLM model option.
375
+ type_option: Type of system ("Cascaded" or "E2E").
376
+
377
+ Yields:
378
+ Tuple[Optional[np.ndarray], Optional[str], Optional[str],
379
+ Optional[Tuple[int, np.ndarray]], Optional[Tuple[int, np.ndarray]]]:
380
+ A tuple containing:
381
+ - Updated stream buffer.
382
+ - ASR output text.
383
+ - Generated LLM output text.
384
+ - Audio output as a tuple of sample rate and audio waveform.
385
+ - User input audio as a tuple of sample rate and audio waveform.
386
+
387
+ Notes:
388
+ - Resets the session if the transcription exceeds 5 minutes.
389
+ - Updates the Gradio interface elements dynamically.
390
+ - Manages latencies.
391
+ """
392
+ sr, y = new_chunk
393
+ global text_str
394
+ global chat
395
+ global user_role
396
+ global audio_output
397
+ global audio_output1
398
+ global vad_output
399
+ global asr_output_str
400
+ global start_record_time
401
+ global sids
402
+ global spembs
403
+ global latency_ASR
404
+ global latency_LM
405
+ global latency_TTS
406
+ global LLM_response_arr
407
+ global total_response_arr
408
+ if stream is None:
409
+ # Handle user refresh
410
+ for (
411
+ _,
412
+ _,
413
+ _,
414
+ _,
415
+ asr_output_box,
416
+ text_box,
417
+ audio_box,
418
+ _,
419
+ _,
420
+ ) in dialogue_model.handle_type_selection(
421
+ type_option, TTS_option, ASR_option, LLM_option
422
+ ):
423
+ gr.Info("The models are being reloaded due to a browser refresh.")
424
+ yield (stream, asr_output_box, text_box, audio_box, gr.Audio(visible=False))
425
+ stream = y
426
+ text_str = ""
427
+ audio_output = None
428
+ audio_output1 = None
429
+ else:
430
+ stream = np.concatenate((stream, y))
431
+ # import pdb;pdb.set_trace()
432
+ dialogue_model.chat.init_chat(
433
+ {
434
+ "role": "system",
435
+ "content": (
436
+ input_text
437
+ ),
438
+ }
439
+ )
440
+ (
441
+ asr_output_str,
442
+ text_str,
443
+ audio_output,
444
+ audio_output1,
445
+ latency_ASR,
446
+ latency_LM,
447
+ latency_TTS,
448
+ stream,
449
+ change,
450
+ ) = dialogue_model(
451
+ y,
452
+ sr,
453
+ stream,
454
+ asr_output_str,
455
+ text_str,
456
+ audio_output,
457
+ audio_output1,
458
+ latency_ASR,
459
+ latency_LM,
460
+ latency_TTS,
461
+ )
462
+ text_str1 = text_str
463
+ if change:
464
+ print("Output changed")
465
+ if asr_output_str != "":
466
+ total_response_arr.append(asr_output_str.replace("\n", " "))
467
+ LLM_response_arr.append(text_str.replace("\n", " "))
468
+ total_response_arr.append(text_str.replace("\n", " "))
469
+ if (text_str != "") and (start_record_time is None):
470
+ start_record_time = time.time()
471
+ elif start_record_time is not None:
472
+ current_record_time = time.time()
473
+ if current_record_time - start_record_time > 300:
474
+ gr.Info(
475
+ "Conversations are limited to 5 minutes. "
476
+ "The session will restart in approximately 60 seconds. "
477
+ "Please wait for the demo to reset. "
478
+ "Close this message once you have read it.",
479
+ duration=None,
480
+ )
481
+ yield stream, gr.Textbox(visible=False), gr.Textbox(
482
+ visible=False
483
+ ), gr.Audio(visible=False), gr.Audio(visible=False)
484
+ if upload_to_hub is not None:
485
+ api.upload_folder(
486
+ folder_path="flagged_data_points",
487
+ path_in_repo="checkpoint_" + str(start_record_time),
488
+ repo_id=upload_to_hub,
489
+ repo_type="dataset",
490
+ token=access_token,
491
+ )
492
+ dialogue_model.chat.buffer = []
493
+ text_str = ""
494
+ audio_output = None
495
+ audio_output1 = None
496
+ asr_output_str = ""
497
+ start_record_time = None
498
+ LLM_response_arr = []
499
+ total_response_arr = []
500
+ shutil.rmtree("flagged_data_points")
501
+ os.mkdir("flagged_data_points")
502
+ yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
503
+ yield stream, gr.Textbox(visible=True), gr.Textbox(visible=True), gr.Audio(
504
+ visible=True
505
+ ), gr.Audio(visible=False)
506
+
507
+ yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
508
+
509
+
510
+ # ------------------------
511
+ # Executable Script
512
+ # ------------------------
513
+ api = HfApi()
514
+ nltk.download("averaged_perceptron_tagger_eng")
515
+ start_warmup()
516
+ default_instruct=(
517
+ "You are a helpful and friendly AI "
518
+ "assistant. "
519
+ "You are polite, respectful, and aim to "
520
+ "provide concise and complete responses of "
521
+ "less than 15 words."
522
  )
523
+ import pandas as pd
524
+ examples = pd.DataFrame([
525
+ ["General Purpose Conversation", default_instruct],
526
+ ["Translation", "You are a translator. Translate user text into English."],
527
+ ["General Purpose Conversation with Disfluencies", "Please reply to the user with a lot of filler words like ummm, so"],
528
+ ["Summarization", "You are a summarizer. Summarize the user's utterance."]
529
+ ], columns=["Task", "LLM Prompt"])
530
+ with gr.Blocks(
531
+ title="E2E Spoken Dialog System",
532
+ ) as demo:
533
+ with gr.Row():
534
+ gr.Markdown(
535
+ """
536
+ ## ESPnet-SDS
537
+ Welcome to our unified web interface for various cascaded and
538
+ E2E spoken dialogue systems built using ESPnet-SDS toolkit,
539
+ supporting real-time automated evaluation metrics, and
540
+ human-in-the-loop feedback collection.
541
+
542
+ For more details on how to use the app, refer to the [README]
543
+ (https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
544
+ """
545
+ )
546
+ with gr.Row():
547
+ with gr.Column(scale=1):
548
+ user_audio = gr.Audio(
549
+ sources=["microphone"],
550
+ streaming=True,
551
+ waveform_options=gr.WaveformOptions(sample_rate=16000),
552
+ )
553
+ input_text=gr.Textbox(
554
+ label="LLM prompt",
555
+ visible=True,
556
+ interactive=True,
557
+ value=default_instruct
558
+ )
559
+ with gr.Row():
560
+ type_radio = gr.Radio(
561
+ choices=["Cascaded", "E2E"],
562
+ label="Choose type of Spoken Dialog:",
563
+ value="Cascaded",
564
+ )
565
+ with gr.Row():
566
+ ASR_radio = gr.Radio(
567
+ choices=ASR_options,
568
+ label="Choose ASR:",
569
+ value=ASR_name,
570
+ )
571
+ with gr.Row():
572
+ LLM_radio = gr.Radio(
573
+ choices=LLM_options,
574
+ label="Choose LLM:",
575
+ value=LLM_name,
576
+ )
577
+ with gr.Row():
578
+ radio = gr.Radio(
579
+ choices=TTS_options,
580
+ label="Choose TTS:",
581
+ value=TTS_name,
582
+ )
583
+ with gr.Row():
584
+ E2Eradio = gr.Radio(
585
+ choices=["mini-omni"],
586
+ label="Choose E2E model:",
587
+ value="mini-omni",
588
+ visible=False,
589
+ )
590
+ with gr.Row():
591
+ feedback_btn = gr.Button(
592
+ value=(
593
+ "Please provide your feedback "
594
+ "after each system response below."
595
+ ),
596
+ visible=True,
597
+ interactive=False,
598
+ elem_id="button",
599
+ )
600
+ with gr.Row():
601
+ natural_btn1 = gr.Button(
602
+ value="Very Natural", visible=False, interactive=False, scale=1
603
+ )
604
+ natural_btn2 = gr.Button(
605
+ value="Somewhat Awkward", visible=False, interactive=False, scale=1
606
+ )
607
+ natural_btn3 = gr.Button(
608
+ value="Very Awkward", visible=False, interactive=False, scale=1
609
+ )
610
+ natural_btn4 = gr.Button(
611
+ value="Unnatural", visible=False, interactive=False, scale=1
612
+ )
613
+ with gr.Row():
614
+ relevant_btn1 = gr.Button(
615
+ value="Highly Relevant", visible=False, interactive=False, scale=1
616
+ )
617
+ relevant_btn2 = gr.Button(
618
+ value="Partially Relevant",
619
+ visible=False,
620
+ interactive=False,
621
+ scale=1,
622
+ )
623
+ relevant_btn3 = gr.Button(
624
+ value="Slightly Irrelevant",
625
+ visible=False,
626
+ interactive=False,
627
+ scale=1,
628
+ )
629
+ relevant_btn4 = gr.Button(
630
+ value="Completely Irrelevant",
631
+ visible=False,
632
+ interactive=False,
633
+ scale=1,
634
+ )
635
+ with gr.Column(scale=1):
636
+ output_audio = gr.Audio(label="Output", autoplay=True, visible=True, interactive=False)
637
+ output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False, interactive=False)
638
+ output_asr_text = gr.Textbox(label="ASR output", interactive=False)
639
+ output_text = gr.Textbox(label="LLM output", interactive=False)
640
+ eval_radio = gr.Radio(
641
+ choices=[
642
+ "Latency",
643
+ "TTS Intelligibility",
644
+ "TTS Speech Quality",
645
+ "ASR WER",
646
+ "Text Dialog Metrics",
647
+ ],
648
+ label="Choose Evaluation metrics:",
649
+ )
650
+ eval_radio_E2E = gr.Radio(
651
+ choices=[
652
+ "Latency",
653
+ "TTS Intelligibility",
654
+ "TTS Speech Quality",
655
+ "Text Dialog Metrics",
656
+ ],
657
+ label="Choose Evaluation metrics:",
658
+ visible=False,
659
+ )
660
+ output_eval_text = gr.Textbox(label="Evaluation Results")
661
+ state = gr.State()
662
+ gr.Markdown("### Example Prompts & Responses")
663
+ gr.DataFrame(value=examples, headers=["Task", "LLM Prompt"], interactive=False)
664
+ with gr.Row():
665
+ privacy_text = gr.Textbox(
666
+ label="Privacy Notice",
667
+ interactive=False,
668
+ value=(
669
+ "By using this demo, you acknowledge that "
670
+ "interactions with this dialog system are collected "
671
+ "for research and improvement purposes. The data "
672
+ "will only be used to enhance the performance and "
673
+ "understanding of the system. If you have any "
674
+ "concerns about data collection, please discontinue "
675
+ "use."
676
+ ),
677
+ )
678
 
679
+ btn_list = [
680
+ natural_btn1,
681
+ natural_btn2,
682
+ natural_btn3,
683
+ natural_btn4,
684
+ relevant_btn1,
685
+ relevant_btn2,
686
+ relevant_btn3,
687
+ relevant_btn4,
688
+ ]
689
+ natural_btn_list = [
690
+ natural_btn1,
691
+ natural_btn2,
692
+ natural_btn3,
693
+ natural_btn4,
694
+ ]
695
+ relevant_btn_list = [
696
+ relevant_btn1,
697
+ relevant_btn2,
698
+ relevant_btn3,
699
+ relevant_btn4,
700
+ ]
701
+ natural_response = gr.Textbox(
702
+ label="natural_response", visible=False, interactive=False
703
+ )
704
+ diversity_response = gr.Textbox(
705
+ label="diversity_response", visible=False, interactive=False
706
+ )
707
+ ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)
708
+ callback.setup(
709
+ [
710
+ user_audio,
711
+ output_asr_text,
712
+ output_text,
713
+ output_audio,
714
+ output_audio1,
715
+ type_radio,
716
+ ASR_radio,
717
+ LLM_radio,
718
+ radio,
719
+ E2Eradio,
720
+ natural_response,
721
+ diversity_response,
722
+ ip_address,
723
+ ],
724
+ "flagged_data_points",
725
+ )
726
+ user_audio.stream(
727
+ transcribe,
728
+ inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio, input_text],
729
+ outputs=[state, output_asr_text, output_text, output_audio, output_audio1],
730
+ ).then(
731
+ lambda *args: callback.flag(list(args)), [user_audio], None, preprocess=False
732
+ )
733
+ radio.change(
734
+ fn=dialogue_model.handle_TTS_selection,
735
+ inputs=[radio],
736
+ outputs=[output_asr_text, output_text, output_audio],
737
+ )
738
+ LLM_radio.change(
739
+ fn=dialogue_model.handle_LLM_selection,
740
+ inputs=[LLM_radio],
741
+ outputs=[output_asr_text, output_text, output_audio],
742
+ )
743
+ ASR_radio.change(
744
+ fn=dialogue_model.handle_ASR_selection,
745
+ inputs=[ASR_radio],
746
+ outputs=[output_asr_text, output_text, output_audio],
747
+ )
748
+ eval_radio.change(
749
+ fn=handle_eval_selection,
750
+ inputs=[eval_radio, output_audio, output_text, output_audio1, output_asr_text],
751
+ outputs=[eval_radio, output_eval_text],
752
+ )
753
+ eval_radio_E2E.change(
754
+ fn=handle_eval_selection_E2E,
755
+ inputs=[eval_radio_E2E, output_audio, output_text],
756
+ outputs=[eval_radio_E2E, output_eval_text],
757
+ )
758
+ type_radio.change(
759
+ fn=dialogue_model.handle_type_selection,
760
+ inputs=[type_radio, radio, ASR_radio, LLM_radio],
761
+ outputs=[
762
+ radio,
763
+ ASR_radio,
764
+ LLM_radio,
765
+ E2Eradio,
766
+ output_asr_text,
767
+ output_text,
768
+ output_audio,
769
+ eval_radio,
770
+ eval_radio_E2E,
771
+ ],
772
+ )
773
+ output_audio.play(
774
+ flash_buttons, [], [natural_response, diversity_response] + btn_list
775
+ ).then(
776
+ lambda *args: callback.flag(list(args)),
777
+ [
778
+ user_audio,
779
+ output_asr_text,
780
+ output_text,
781
+ output_audio,
782
+ output_audio1,
783
+ type_radio,
784
+ ASR_radio,
785
+ LLM_radio,
786
+ radio,
787
+ E2Eradio,
788
+ ],
789
+ None,
790
+ preprocess=False,
791
+ )
792
+ natural_btn1.click(
793
+ natural_vote1_last_response,
794
+ [],
795
+ [natural_response, ip_address] + natural_btn_list,
796
+ ).then(
797
+ lambda *args: callback.flag(list(args)),
798
+ [
799
+ user_audio,
800
+ output_asr_text,
801
+ output_text,
802
+ output_audio,
803
+ output_audio1,
804
+ type_radio,
805
+ ASR_radio,
806
+ LLM_radio,
807
+ radio,
808
+ E2Eradio,
809
+ natural_response,
810
+ diversity_response,
811
+ ip_address,
812
+ ],
813
+ None,
814
+ preprocess=False,
815
+ )
816
+ natural_btn2.click(
817
+ natural_vote2_last_response,
818
+ [],
819
+ [natural_response, ip_address] + natural_btn_list,
820
+ ).then(
821
+ lambda *args: callback.flag(list(args)),
822
+ [
823
+ user_audio,
824
+ output_asr_text,
825
+ output_text,
826
+ output_audio,
827
+ output_audio1,
828
+ type_radio,
829
+ ASR_radio,
830
+ LLM_radio,
831
+ radio,
832
+ E2Eradio,
833
+ natural_response,
834
+ diversity_response,
835
+ ip_address,
836
+ ],
837
+ None,
838
+ preprocess=False,
839
+ )
840
+ natural_btn3.click(
841
+ natural_vote3_last_response,
842
+ [],
843
+ [natural_response, ip_address] + natural_btn_list,
844
+ ).then(
845
+ lambda *args: callback.flag(list(args)),
846
+ [
847
+ user_audio,
848
+ output_asr_text,
849
+ output_text,
850
+ output_audio,
851
+ output_audio1,
852
+ type_radio,
853
+ ASR_radio,
854
+ LLM_radio,
855
+ radio,
856
+ E2Eradio,
857
+ natural_response,
858
+ diversity_response,
859
+ ip_address,
860
+ ],
861
+ None,
862
+ preprocess=False,
863
+ )
864
+ natural_btn4.click(
865
+ natural_vote4_last_response,
866
+ [],
867
+ [natural_response, ip_address] + natural_btn_list,
868
+ ).then(
869
+ lambda *args: callback.flag(list(args)),
870
+ [
871
+ user_audio,
872
+ output_asr_text,
873
+ output_text,
874
+ output_audio,
875
+ output_audio1,
876
+ type_radio,
877
+ ASR_radio,
878
+ LLM_radio,
879
+ radio,
880
+ E2Eradio,
881
+ natural_response,
882
+ diversity_response,
883
+ ip_address,
884
+ ],
885
+ None,
886
+ preprocess=False,
887
+ )
888
+ relevant_btn1.click(
889
+ relevant_vote1_last_response,
890
+ [],
891
+ [diversity_response, ip_address] + relevant_btn_list,
892
+ ).then(
893
+ lambda *args: callback.flag(list(args)),
894
+ [
895
+ user_audio,
896
+ output_asr_text,
897
+ output_text,
898
+ output_audio,
899
+ output_audio1,
900
+ type_radio,
901
+ ASR_radio,
902
+ LLM_radio,
903
+ radio,
904
+ E2Eradio,
905
+ natural_response,
906
+ diversity_response,
907
+ ip_address,
908
+ ],
909
+ None,
910
+ preprocess=False,
911
+ )
912
+ relevant_btn2.click(
913
+ relevant_vote2_last_response,
914
+ [],
915
+ [diversity_response, ip_address] + relevant_btn_list,
916
+ ).then(
917
+ lambda *args: callback.flag(list(args)),
918
+ [
919
+ user_audio,
920
+ output_asr_text,
921
+ output_text,
922
+ output_audio,
923
+ output_audio1,
924
+ type_radio,
925
+ ASR_radio,
926
+ LLM_radio,
927
+ radio,
928
+ E2Eradio,
929
+ natural_response,
930
+ diversity_response,
931
+ ip_address,
932
+ ],
933
+ None,
934
+ preprocess=False,
935
+ )
936
+ relevant_btn3.click(
937
+ relevant_vote3_last_response,
938
+ [],
939
+ [diversity_response, ip_address] + relevant_btn_list,
940
+ ).then(
941
+ lambda *args: callback.flag(list(args)),
942
+ [
943
+ user_audio,
944
+ output_asr_text,
945
+ output_text,
946
+ output_audio,
947
+ output_audio1,
948
+ type_radio,
949
+ ASR_radio,
950
+ LLM_radio,
951
+ radio,
952
+ E2Eradio,
953
+ natural_response,
954
+ diversity_response,
955
+ ip_address,
956
+ ],
957
+ None,
958
+ preprocess=False,
959
+ )
960
+ relevant_btn4.click(
961
+ relevant_vote4_last_response,
962
+ [],
963
+ [diversity_response, ip_address] + relevant_btn_list,
964
+ ).then(
965
+ lambda *args: callback.flag(list(args)),
966
+ [
967
+ user_audio,
968
+ output_asr_text,
969
+ output_text,
970
+ output_audio,
971
+ output_audio1,
972
+ type_radio,
973
+ ASR_radio,
974
+ LLM_radio,
975
+ radio,
976
+ E2Eradio,
977
+ natural_response,
978
+ diversity_response,
979
+ ip_address,
980
+ ],
981
+ None,
982
+ preprocess=False,
983
+ )
984
+ demo.queue(max_size=10, default_concurrency_limit=1)
985
+ demo.launch(share=True)
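The last two lines of the new app.py are what keep the Space responsive under load: `demo.queue(max_size=10, default_concurrency_limit=1)` caps the request queue at ten entries and lets each event process only one request at a time, so the single set of ASR/LLM/TTS models is never hit concurrently, and `launch(share=True)` exposes a public link. Below is a minimal sketch of the same queue/launch pattern on a toy Blocks app; `toy_demo` and the echo handler are illustrative only and not part of this commit.

import gradio as gr

# Toy app reusing the queue/launch pattern from app.py above.
with gr.Blocks() as toy_demo:
    box = gr.Textbox(label="echo")
    # Echo the submitted text back into the same textbox.
    box.submit(lambda s: s, inputs=box, outputs=box)

# max_size caps queued requests; default_concurrency_limit=1 processes one event at a time.
toy_demo.queue(max_size=10, default_concurrency_limit=1)
toy_demo.launch()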
pyscripts/utils/dialog_eval/ASR_WER.py ADDED
@@ -0,0 +1,165 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+
5
+ from espnet2.sds.utils.utils import int2float
6
+
7
+
8
+ def handle_espnet_ASR_WER(
9
+ ASR_audio_output: Tuple[int, np.ndarray], ASR_transcript: str
10
+ ) -> str:
11
+ """
12
+ Compute and return Word Error Rate (WER) and Character Error Rate (CER) metrics
13
+ for multiple judge ASR systems (ESPnet, OWSM, Whisper) using the Versa library.
14
+
15
+ This function performs the following:
16
+ 1. Imports necessary metrics and setup functions from Versa.
17
+ 2. Prepares configuration arguments for each ASR system (ESPnet, OWSM, Whisper).
18
+ 3. Runs the Levenshtein-based WER/CER calculations.
19
+ 4. Returns a formatted string summarizing WER and CER
20
+ results for reference produced by each ASR system.
21
+
22
+ Args:
23
+ ASR_audio_output (tuple):
24
+ A tuple where:
25
+ - The first element is the frame rate.
26
+ - The second element is the audio signal (NumPy array).
27
+ ASR_transcript (str):
28
+ The transcript produced by the ASR model in the cascaded
29
+ conversational AI pipeline.
30
+
31
+ Returns:
32
+ str:
33
+ A formatted string showing the WER and CER percentages
34
+ for ESPnet, OWSM, and Whisper. Example output:
35
+
36
+ "ESPnet WER: 10.50
37
+ ESPnet CER: 7.20
38
+ OWSM WER: 11.30
39
+ OWSM CER: 8.00
40
+ Whisper WER: 9.25
41
+ Whisper CER: 6.50"
42
+
43
+ Raises:
44
+ ImportError:
45
+ If Versa is not installed or cannot be imported.
46
+
47
+ Example:
48
+ >>> asr_audio_output = (16000, audio_array)
49
+ >>> asr_transcript = "This is the ASR transcript."
50
+ >>> result = handle_espnet_ASR_WER(asr_audio_output, asr_transcript)
51
+ >>> print(result)
52
+ "ESPnet WER: 10.50
53
+ ESPnet CER: 7.20
54
+ OWSM WER: 11.30
55
+ OWSM CER: 8.00
56
+ Whisper WER: 9.25
57
+ Whisper CER: 6.50"
58
+ """
59
+ try:
60
+ from versa import (
61
+ espnet_levenshtein_metric,
62
+ espnet_wer_setup,
63
+ owsm_levenshtein_metric,
64
+ owsm_wer_setup,
65
+ whisper_levenshtein_metric,
66
+ whisper_wer_setup,
67
+ )
68
+ except Exception as e:
69
+ print("Error: Versa is not properly installed.")
70
+ raise e
71
+ score_modules_espnet = {
72
+ "module": espnet_levenshtein_metric,
73
+ "args": espnet_wer_setup(
74
+ model_tag="default",
75
+ beam_size=1,
76
+ text_cleaner="whisper_en",
77
+ use_gpu=True,
78
+ ),
79
+ }
80
+ dict1 = score_modules_espnet["module"](
81
+ score_modules_espnet["args"],
82
+ int2float(ASR_audio_output[1]),
83
+ ASR_transcript,
84
+ ASR_audio_output[0],
85
+ )
86
+ espnet_wer = (
87
+ dict1["espnet_wer_delete"]
88
+ + dict1["espnet_wer_insert"]
89
+ + dict1["espnet_wer_replace"]
90
+ ) / (
91
+ dict1["espnet_wer_insert"]
92
+ + dict1["espnet_wer_replace"]
93
+ + dict1["espnet_wer_equal"]
94
+ )
95
+ espnet_cer = (
96
+ dict1["espnet_cer_delete"]
97
+ + dict1["espnet_cer_insert"]
98
+ + dict1["espnet_cer_replace"]
99
+ ) / (
100
+ dict1["espnet_cer_insert"]
101
+ + dict1["espnet_cer_replace"]
102
+ + dict1["espnet_cer_equal"]
103
+ )
104
+ score_modules_owsm = {
105
+ "module": owsm_levenshtein_metric,
106
+ "args": owsm_wer_setup(
107
+ model_tag="default",
108
+ beam_size=1,
109
+ text_cleaner="whisper_en",
110
+ use_gpu=True,
111
+ ),
112
+ }
113
+ dict1 = score_modules_owsm["module"](
114
+ score_modules_owsm["args"],
115
+ int2float(ASR_audio_output[1]),
116
+ ASR_transcript,
117
+ ASR_audio_output[0],
118
+ )
119
+ owsm_wer = (
120
+ dict1["owsm_wer_delete"] + dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"]
121
+ ) / (dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"] + dict1["owsm_wer_equal"])
122
+ owsm_cer = (
123
+ dict1["owsm_cer_delete"] + dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"]
124
+ ) / (dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"] + dict1["owsm_cer_equal"])
125
+ score_modules_whisper = {
126
+ "module": whisper_levenshtein_metric,
127
+ "args": whisper_wer_setup(
128
+ model_tag="default",
129
+ beam_size=1,
130
+ text_cleaner="whisper_en",
131
+ use_gpu=True,
132
+ ),
133
+ }
134
+ dict1 = score_modules_whisper["module"](
135
+ score_modules_whisper["args"],
136
+ int2float(ASR_audio_output[1]),
137
+ ASR_transcript,
138
+ ASR_audio_output[0],
139
+ )
140
+ whisper_wer = (
141
+ dict1["whisper_wer_delete"]
142
+ + dict1["whisper_wer_insert"]
143
+ + dict1["whisper_wer_replace"]
144
+ ) / (
145
+ dict1["whisper_wer_insert"]
146
+ + dict1["whisper_wer_replace"]
147
+ + dict1["whisper_wer_equal"]
148
+ )
149
+ whisper_cer = (
150
+ dict1["whisper_cer_delete"]
151
+ + dict1["whisper_cer_insert"]
152
+ + dict1["whisper_cer_replace"]
153
+ ) / (
154
+ dict1["whisper_cer_insert"]
155
+ + dict1["whisper_cer_replace"]
156
+ + dict1["whisper_cer_equal"]
157
+ )
158
+ return (
159
+ f"ESPnet WER: {espnet_wer*100:.2f}\n"
160
+ f"ESPnet CER: {espnet_cer*100:.2f}\n"
161
+ f"OWSM WER: {owsm_wer*100:.2f}\n"
162
+ f"OWSM CER: {owsm_cer*100:.2f}\n"
163
+ f"Whisper WER: {whisper_wer*100:.2f}\n"
164
+ f"Whisper CER: {whisper_cer*100:.2f}"
165
+ )
pyscripts/utils/dialog_eval/LLM_Metrics.py ADDED
@@ -0,0 +1,245 @@
1
+ from multiprocessing import Pool
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch
6
+ from pyscripts.utils.dialog_eval.vert import (
7
+ get_auto_bleu2_geometric,
8
+ get_self_bleu2_geometric,
9
+ run_f,
10
+ )
11
+ from scipy.stats import gmean
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
14
+
15
+
16
+ def perplexity(LLM_Output: str, model_id: str = "gpt2") -> str:
17
+ """
18
+ Compute the perplexity of the given text using a specified model from the
19
+ `evaluate` library (default: GPT-2).
20
+
21
+ Args:
22
+ LLM_Output (str):
23
+ The text (string) for which perplexity is to be computed.
24
+ model_id (str, optional):
25
+ The identifier of the model to use for computing
26
+ perplexity. Defaults to "gpt2".
27
+
28
+ Returns:
29
+ str:
30
+ A formatted string showing the perplexity of the
31
+ provided text(s), for example:
32
+ "Perplexity: 45.23\n"
33
+
34
+ Raises:
35
+ ImportError:
36
+ If the `evaluate` library is not installed or cannot be imported.
37
+
38
+ Example:
39
+ >>> text = "Hello world, this is a test."
40
+ >>> result = perplexity(text, model_id="gpt2")
41
+ >>> print(result)
42
+ "Perplexity: 27.34\n"
43
+ """
44
+ try:
45
+ import evaluate
46
+ except Exception as e:
47
+ print("Error: evaluate is not properly installed.")
48
+ raise e
49
+ perplexity = evaluate.load("perplexity", module_type="metric")
50
+ results = perplexity.compute(model_id=model_id, predictions=[LLM_Output])
51
+ return f"Perplexity: {results['mean_perplexity']:.2f}\n"
52
+
53
+
54
+ def vert(LLM_response_arr: List[str]) -> str:
55
+ """
56
+ Calculate and return Self BLEU-2, Auto BLEU-2 and VERT-2
57
+ metrics for a list of LLM responses.
58
+
59
+ Args:
60
+ LLM_response_arr (List[str]):
61
+ A list of responses (strings) generated by the language
62
+ model acting as text dialog response generator.
63
+
64
+ Returns:
65
+ str:
66
+ A formatted string that includes each computed metric and the final
67
+ VERT value, for example:
68
+
69
+ "Self-BLEU2-geometric: 42.13
70
+ Auto-BLEU2-geometric: 38.94
71
+ VERT: 40.5
72
+ "
73
+
74
+ Example:
75
+ >>> # Suppose we have the following LLM responses:
76
+ >>> responses = ["Hello world", "Foo bar", "Lorem ipsum dolor sit amet"]
77
+ >>> result = vert(responses)
78
+ >>> print(result)
79
+ "Self-BLEU2-geometric: 42.13
80
+ Auto-BLEU2-geometric: 38.94
81
+ VERT: 40.5
82
+ "
83
+ """
84
+ terms = [x.strip().split() for x in LLM_response_arr]
85
+
86
+ tasks = [
87
+ ("Self-BLEU2-geometric", get_self_bleu2_geometric),
88
+ ("Auto-BLEU2-geometric", get_auto_bleu2_geometric),
89
+ ]
90
+ n_processes = min(16, len(tasks))
91
+ with Pool(n_processes) as pool:
92
+ metrics = pool.map(run_f, [(t[1], terms) for t in tasks])
93
+ metric_arr = []
94
+ str1 = ""
95
+ for (metric_name, _), metric in zip(tasks, metrics):
96
+ metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))
97
+
98
+ metric, sem = [round(100 * x, 2) for x in [metric, sem]]
99
+ metric_arr.append(metric)
100
+
101
+ str1 += f"{metric_name}: {metric}\n"
102
+ str1 += f"VERT: {round(gmean(metric_arr), 2)}\n"
103
+ return str1
104
+
105
+
106
+ def bert_score(
107
+ total_response_arr: List[str], bert_model_name: str = "bert-base-uncased"
108
+ ) -> str:
109
+ """
110
+ Compute a cosine similarity score between the concatenated
111
+ context (all but the last element)
112
+ and the final response (last element) using a BERT-based model.
113
+ This serves as a simplified
114
+ measure of how closely the response aligns with the preceding context semantically.
115
+
116
+ Args:
117
+ total_response_arr (List[str]):
118
+ A list of strings. The last element represents the response,
119
+ while all other elements
120
+ are treated as the context.
121
+ bert_model_name (str, optional):
122
+ The name or path of the BERT model to use (from the Hugging Face Model Hub).
123
+ Defaults to "bert-base-uncased".
124
+
125
+ Returns:
126
+ str:
127
+ A string containing the cosine similarity
128
+ (as a percentage) followed by a newline.
129
+ For example:
130
+ "Cosine Similarity: 85.67\n"
131
+
132
+ Example:
133
+ >>> total_responses = [
134
+ ... "User: Hi, how are you?",
135
+ ... "Assistant: I'm good! How can I help you today?",
136
+ ... "User: Can you tell me a joke?",
137
+ ... "Assistant: Sure! Here's one: Why did the chicken join a band?"
138
+ ... ]
139
+ >>> result = bert_score(total_responses, bert_model_name="bert-base-uncased")
140
+ >>> print(result)
141
+ "Cosine Similarity: 75.89\n"
142
+ """
143
+
144
+ def cosine_similarity_context_response(context, response, model, tokenizer):
145
+ # Tokenize and encode both context and response
146
+ context_inputs = tokenizer(context, return_tensors="pt", truncation=True)
147
+ response_inputs = tokenizer(response, return_tensors="pt", truncation=True)
148
+ for k in context_inputs:
149
+ context_inputs[k] = context_inputs[k].cuda()
150
+ for k in response_inputs:
151
+ response_inputs[k] = response_inputs[k].cuda()
152
+
153
+ # Get embeddings from the model
154
+ with torch.no_grad():
155
+ context_embedding = model(**context_inputs).last_hidden_state.mean(dim=1)
156
+ response_embedding = model(**response_inputs).last_hidden_state.mean(dim=1)
157
+
158
+ # Compute cosine similarity
159
+ similarity = cosine_similarity(
160
+ context_embedding.cpu().numpy(), response_embedding.cpu().numpy()
161
+ )
162
+ return similarity[0][0]
163
+
164
+ bert_model = AutoModel.from_pretrained(bert_model_name).cuda()
165
+ bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
166
+ similarity = cosine_similarity_context_response(
167
+ " ".join(total_response_arr[:-1]),
168
+ total_response_arr[-1],
169
+ bert_model,
170
+ bert_tokenizer,
171
+ )
172
+ return f"Cosine Similarity: {similarity*100:.2f}" + "\n"
173
+
174
+
175
+ def DialoGPT_perplexity(
176
+ user_utterance: str,
177
+ response: str,
178
+ dialog_model_name: str = "microsoft/DialoGPT-medium",
179
+ ) -> str:
180
+ """
181
+ Compute the perplexity of a response given a user utterance using a pre-trained
182
+ DialoGPT model. The function loads DialoGPT (medium by default)
183
+ from the Hugging Face Model Hub, then calculates the perplexity
184
+ for the
185
+ (context + response) sequence.
186
+
187
+ Args:
188
+ user_utterance (str):
189
+ The user utterance preceding the model's response.
190
+ response (str):
191
+ The generated response whose perplexity needs to be evaluated.
192
+
193
+ Returns:
194
+ str:
195
+ A formatted string containing the DialoGPT perplexity score. For example:
196
+ "DialoGPT Perplexity: 25.67\n"
197
+
198
+ Example:
199
+ >>> user_text = "Hi, how are you today?"
200
+ >>> system_response = "I'm good, thank you! How can I help you?"
201
+ >>> result = DialoGPT_perplexity(user_text, system_response)
202
+ >>> print(result)
203
+ "DialoGPT Perplexity: 31.45\n"
204
+ """
205
+
206
+ def evaluate_response_with_dialoGPT(context, response, model, tokenizer):
207
+ """
208
+ Evaluate the appropriateness of a response based on the
209
+ given context using DialoGPT.
210
+
211
+ Args:
212
+ context (str): The dialogue context (previous conversation).
213
+ response (str): The generated response to evaluate.
214
+ model: Pre-trained DialoGPT model.
215
+ tokenizer: Corresponding tokenizer for the DialoGPT model.
216
+
217
+ Returns:
218
+ float: Perplexity score of the response given the context.
219
+ """
220
+ model.eval()
221
+
222
+ # Combine context and response as input
223
+ input_text = context + tokenizer.eos_token + response + tokenizer.eos_token
224
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
225
+ inputs["input_ids"] = inputs["input_ids"].cuda()
226
+ inputs["attention_mask"] = inputs["attention_mask"].cuda()
227
+ # import pdb;pdb.set_trace()
228
+
229
+ # Compute model outputs and loss
230
+ with torch.no_grad():
231
+ outputs = model(**inputs, labels=inputs["input_ids"].cuda())
232
+ loss = outputs.loss
233
+
234
+ # Calculate perplexity
235
+ perplexity = torch.exp(loss)
236
+ return perplexity.cpu().item()
237
+
238
+ # Load DialoGPT model and tokenizer
239
+ model_name = dialog_model_name
240
+ model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
241
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
242
+ perplexity = evaluate_response_with_dialoGPT(
243
+ user_utterance, response, model, tokenizer
244
+ )
245
+ return f"DialoGPT Perplexity: {perplexity:.2f}" + "\n"
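For reference, the VERT value reported by `vert()` in this file is simply the geometric mean of the Self-BLEU2 and Auto-BLEU2 scores after they have been scaled to percentages. A minimal sketch of that final step, using made-up metric values consistent with the docstring example:

from scipy.stats import gmean

# Illustrative numbers only: Self-BLEU2-geometric and Auto-BLEU2-geometric,
# already rounded and scaled to percentages as done inside vert().
metric_arr = [42.13, 38.94]
print(f"VERT: {round(gmean(metric_arr), 2)}")  # prints "VERT: 40.5"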
pyscripts/utils/dialog_eval/TTS_intelligibility.py ADDED
@@ -0,0 +1,169 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+
5
+ from espnet2.sds.utils.utils import int2float
6
+
7
+
8
+ def handle_espnet_TTS_intelligibility(
9
+ TTS_audio_output: Tuple[int, np.ndarray], LLM_Output: str
10
+ ) -> str:
11
+ """
12
+ Compute and return Word Error Rate (WER) and Character Error Rate (CER) metrics
13
+ for multiple ASR systems (ESPnet, OWSM, Whisper) using the Versa library.
14
+
15
+ This function:
16
+ 1. Imports the necessary metrics and setup functions from Versa.
17
+ 2. Prepares configuration arguments for each ASR system (ESPnet, OWSM, Whisper).
18
+ 3. Runs the Levenshtein-based WER/CER calculations on the provided TTS audio.
19
+ 4. Returns a formatted string summarizing WER and CER results
20
+ for hypotheses produced
21
+ by each ASR system when transcribing the TTS audio, using
22
+ the LLM output as the reference text.
23
+
24
+ Args:
25
+ TTS_audio_output (Tuple[int, np.ndarray]):
26
+ A tuple consisting of:
27
+ - The first element (int): the frame rate of the audio.
28
+ - The second element (np.ndarray):
29
+ the audio signal (e.g., a NumPy array).
30
+ LLM_Output (str):
31
+ The reference text generated by the LLM, which serves as the ground truth
32
+ for evaluating the TTS audio.
33
+
34
+ Returns:
35
+ str:
36
+ A formatted string showing the WER and CER percentages
37
+ for ESPnet, OWSM, and Whisper.
38
+ Example:
39
+
40
+ ESPnet WER: 10.50
41
+ ESPnet CER: 7.20
42
+ OWSM WER: 11.30
43
+ OWSM CER: 8.00
44
+ Whisper WER: 9.25
45
+ Whisper CER: 6.50
46
+
47
+ Raises:
48
+ ImportError:
49
+ If the Versa library is not installed or cannot be imported.
50
+
51
+ Example:
52
+ >>> tts_audio_output = (16000, audio_array)
53
+ >>> llm_output = "This is the reference text for evaluation."
54
+ >>> result = handle_espnet_TTS_intelligibility(tts_audio_output, llm_output)
55
+ >>> print(result)
56
+ ESPnet WER: 10.50
57
+ ESPnet CER: 7.20
58
+ OWSM WER: 11.30
59
+ OWSM CER: 8.00
60
+ Whisper WER: 9.25
61
+ Whisper CER: 6.50
62
+ """
63
+ try:
64
+ from versa import (
65
+ espnet_levenshtein_metric,
66
+ espnet_wer_setup,
67
+ owsm_levenshtein_metric,
68
+ owsm_wer_setup,
69
+ whisper_levenshtein_metric,
70
+ whisper_wer_setup,
71
+ )
72
+ except Exception as e:
73
+ print("Error: Versa is not properly installed.")
74
+ raise e
75
+ score_modules_espnet = {
76
+ "module": espnet_levenshtein_metric,
77
+ "args": espnet_wer_setup(
78
+ model_tag="default",
79
+ beam_size=1,
80
+ text_cleaner="whisper_en",
81
+ use_gpu=True,
82
+ ),
83
+ }
84
+ dict1 = score_modules_espnet["module"](
85
+ score_modules_espnet["args"],
86
+ int2float(TTS_audio_output[1]),
87
+ LLM_Output,
88
+ TTS_audio_output[0],
89
+ )
90
+ espnet_wer = (
91
+ dict1["espnet_wer_delete"]
92
+ + dict1["espnet_wer_insert"]
93
+ + dict1["espnet_wer_replace"]
94
+ ) / (
95
+ dict1["espnet_wer_delete"]
96
+ + dict1["espnet_wer_replace"]
97
+ + dict1["espnet_wer_equal"]
98
+ )
99
+ espnet_cer = (
100
+ dict1["espnet_cer_delete"]
101
+ + dict1["espnet_cer_insert"]
102
+ + dict1["espnet_cer_replace"]
103
+ ) / (
104
+ dict1["espnet_cer_delete"]
105
+ + dict1["espnet_cer_replace"]
106
+ + dict1["espnet_cer_equal"]
107
+ )
108
+ score_modules_owsm = {
109
+ "module": owsm_levenshtein_metric,
110
+ "args": owsm_wer_setup(
111
+ model_tag="default",
112
+ beam_size=1,
113
+ text_cleaner="whisper_en",
114
+ use_gpu=True,
115
+ ),
116
+ }
117
+ dict1 = score_modules_owsm["module"](
118
+ score_modules_owsm["args"],
119
+ int2float(TTS_audio_output[1]),
120
+ LLM_Output,
121
+ TTS_audio_output[0],
122
+ )
123
+ owsm_wer = (
124
+ dict1["owsm_wer_delete"] + dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"]
125
+ ) / (dict1["owsm_wer_delete"] + dict1["owsm_wer_replace"] + dict1["owsm_wer_equal"])
126
+ owsm_cer = (
127
+ dict1["owsm_cer_delete"] + dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"]
128
+ ) / (dict1["owsm_cer_delete"] + dict1["owsm_cer_replace"] + dict1["owsm_cer_equal"])
129
+ score_modules_whisper = {
130
+ "module": whisper_levenshtein_metric,
131
+ "args": whisper_wer_setup(
132
+ model_tag="default",
133
+ beam_size=1,
134
+ text_cleaner="whisper_en",
135
+ use_gpu=True,
136
+ ),
137
+ }
138
+ dict1 = score_modules_whisper["module"](
139
+ score_modules_whisper["args"],
140
+ int2float(TTS_audio_output[1]),
141
+ LLM_Output,
142
+ TTS_audio_output[0],
143
+ )
144
+ whisper_wer = (
145
+ dict1["whisper_wer_delete"]
146
+ + dict1["whisper_wer_insert"]
147
+ + dict1["whisper_wer_replace"]
148
+ ) / (
149
+ dict1["whisper_wer_delete"]
150
+ + dict1["whisper_wer_replace"]
151
+ + dict1["whisper_wer_equal"]
152
+ )
153
+ whisper_cer = (
154
+ dict1["whisper_cer_delete"]
155
+ + dict1["whisper_cer_insert"]
156
+ + dict1["whisper_cer_replace"]
157
+ ) / (
158
+ dict1["whisper_cer_delete"]
159
+ + dict1["whisper_cer_replace"]
160
+ + dict1["whisper_cer_equal"]
161
+ )
162
+ return (
163
+ f"ESPnet WER: {espnet_wer*100:.2f}\n"
164
+ f"ESPnet CER: {espnet_cer*100:.2f}\n"
165
+ f"OWSM WER: {owsm_wer*100:.2f}\n"
166
+ f"OWSM CER: {owsm_cer*100:.2f}\n"
167
+ f"Whisper WER: {whisper_wer*100:.2f}\n"
168
+ f"Whisper CER: {whisper_cer*100:.2f}"
169
+ )
pyscripts/utils/dialog_eval/TTS_speech_quality.py ADDED
@@ -0,0 +1,98 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+
5
+ from espnet2.sds.utils.utils import int2float
6
+
7
+
8
+ def TTS_psuedomos(TTS_audio_output: Tuple[int, np.ndarray]) -> str:
9
+ """
10
+ Compute and return speech quality metrics
11
+ for the given synthesized audio output
12
+ using the Versa library.
13
+
14
+ Args:
15
+ TTS_audio_output (Tuple[int, np.ndarray]):
16
+ A tuple containing:
17
+ - The first element (int): The frame rate of the audio.
18
+ - The second element (np.ndarray): The audio signal,
19
+ typically a NumPy array.
20
+
21
+ Returns:
22
+ str:
23
+ A formatted string containing each metric name
24
+ and its corresponding score, for example:
25
+
26
+ utmos: 3.54
27
+ dnsmos: 3.47
28
+ plcmos: 3.62
29
+ sheet_ssqa: 4.03
30
+
31
+ Raises:
32
+ ImportError:
33
+ If the Versa library is not installed or cannot be imported.
34
+
35
+ Example:
36
+ >>> tts_audio_output = (16000, audio_array)
37
+ >>> result = TTS_psuedomos(tts_audio_output)
38
+ >>> print(result)
39
+ utmos: 3.54
40
+ dnsmos: 3.47
41
+ plcmos: 3.62
42
+ sheet_ssqa: 4.03
43
+ """
44
+ try:
45
+ from versa import (
46
+ pseudo_mos_metric,
47
+ pseudo_mos_setup,
48
+ sheet_ssqa,
49
+ sheet_ssqa_setup,
50
+ )
51
+ except Exception as e:
52
+ print("Error: Versa is not properly installed.")
53
+ raise e
54
+
55
+ predictor_dict, predictor_fs = pseudo_mos_setup(
56
+ use_gpu=True,
57
+ predictor_types=["utmos", "dnsmos", "plcmos"],
58
+ predictor_args={
59
+ "utmos": {"fs": 16000},
60
+ "dnsmos": {"fs": 16000},
61
+ "plcmos": {"fs": 16000},
62
+ },
63
+ )
64
+ score_modules = {
65
+ "module": pseudo_mos_metric,
66
+ "args": {
67
+ "predictor_dict": predictor_dict,
68
+ "predictor_fs": predictor_fs,
69
+ "use_gpu": True,
70
+ },
71
+ }
72
+ dict1 = score_modules["module"](
73
+ int2float(TTS_audio_output[1]),
74
+ TTS_audio_output[0],
75
+ **score_modules["args"],
76
+ )
77
+ str1 = ""
78
+ for k in dict1:
79
+ str1 = str1 + f"{k}: {dict1[k]:.2f}\n"
80
+ sheet_model = sheet_ssqa_setup(
81
+ model_tag="default",
82
+ model_path=None,
83
+ model_config=None,
84
+ use_gpu=True,
85
+ )
86
+ score_modules = {
87
+ "module": sheet_ssqa,
88
+ "args": {"model": sheet_model, "use_gpu": True},
89
+ }
90
+ dict1 = score_modules["module"](
91
+ score_modules["args"]["model"],
92
+ int2float(TTS_audio_output[1]),
93
+ TTS_audio_output[0],
94
+ use_gpu=score_modules["args"]["use_gpu"],
95
+ )
96
+ for k in dict1:
97
+ str1 = str1 + f"{k}: {dict1[k]:.2f}\n"
98
+ return str1
pyscripts/utils/dialog_eval/human_feedback.py ADDED
@@ -0,0 +1,242 @@
1
+ import gradio as gr
2
+
3
+ disable_btn = gr.Button(interactive=False, visible=False)
4
+
5
+
6
+ def get_ip(request: gr.Request) -> str:
7
+ """
8
+ Retrieve the IP address from an incoming HTTP request.
9
+
10
+ Args:
11
+ request (gr.Request):
12
+ The incoming HTTP request from which the IP address will be extracted.
13
+
14
+ Returns:
15
+ str:
16
+ The IP address as a string.
17
+ """
18
+ if "cf-connecting-ip" in request.headers:
19
+ ip = request.headers["cf-connecting-ip"]
20
+ elif "x-forwarded-for" in request.headers:
21
+ ip = request.headers["x-forwarded-for"]
22
+ if "," in ip:
23
+ ip = ip.split(",")[0]
24
+ else:
25
+ ip = request.client.host
26
+ return ip
27
+
28
+
29
+ def natural_vote1_last_response(request: gr.Request):
30
+ """
31
+ Handle a user vote for naturalness as "Very Natural".
32
+
33
+
34
+ Args:
35
+ request (gr.Request):
36
+ The Gradio request object providing access to HTTP headers and metadata.
37
+
38
+ Returns:
39
+ tuple:
40
+ A tuple containing:
41
+ ("Very Natural", <ip_address>, (disable_btn,) * 4)
42
+
43
+ - "Very Natural": The selected vote or label.
44
+ - <ip_address>: The IP address of the client retrieved from the request.
45
+ - disable_btn: An object repeated four times,
46
+ to disable natural vote buttons.
47
+ """
48
+ ip_address1 = get_ip(request)
49
+ print(f"Very Natural (voted). ip: {ip_address1}")
50
+ return (
51
+ "Very Natural",
52
+ ip_address1,
53
+ ) + (disable_btn,) * 4
54
+
55
+
56
+ def natural_vote2_last_response(request: gr.Request):
57
+ """
58
+ Handle a user vote for naturalness as "Somewhat Awkward".
59
+
60
+
61
+ Args:
62
+ request (gr.Request):
63
+ The Gradio request object providing access to HTTP headers and metadata.
64
+
65
+ Returns:
66
+ tuple:
67
+ A tuple containing:
68
+ ("Somewhat Awkward", <ip_address>, (disable_btn,) * 4)
69
+
70
+ - "Somewhat Awkward": The selected vote or label.
71
+ - <ip_address>: The IP address of the client retrieved from the request.
72
+ - disable_btn: An object repeated four times,
73
+ to disable natural vote buttons.
74
+ """
75
+ ip_address1 = get_ip(request)
76
+ print(f"Somewhat Awkward (voted). ip: {ip_address1}")
77
+ return (
78
+ "Somewhat Awkward",
79
+ ip_address1,
80
+ ) + (disable_btn,) * 4
81
+
82
+
83
+ def natural_vote3_last_response(request: gr.Request):
84
+ """
85
+ Handle a user vote for naturalness as "Very Awkward".
86
+
87
+
88
+ Args:
89
+ request (gr.Request):
90
+ The Gradio request object providing access to HTTP headers and metadata.
91
+
92
+ Returns:
93
+ tuple:
94
+ A tuple containing:
95
+ ("Very Awkward", <ip_address>, (disable_btn,) * 4)
96
+
97
+ - "Very Awkward": The selected vote or label.
98
+ - <ip_address>: The IP address of the client retrieved from the request.
99
+ - disable_btn: An object repeated four times,
100
+ to disable natural vote buttons.
101
+ """
102
+ ip_address1 = get_ip(request)
103
+ print(f"Very Awkward (voted). ip: {ip_address1}")
104
+ return (
105
+ "Very Awkward",
106
+ ip_address1,
107
+ ) + (disable_btn,) * 4
108
+
109
+
110
+ def natural_vote4_last_response(request: gr.Request):
111
+ """
112
+ Handle a user vote for naturalness as "Unnatural".
113
+
114
+
115
+ Args:
116
+ request (gr.Request):
117
+ The Gradio request object providing access to HTTP headers and metadata.
118
+
119
+ Returns:
120
+ tuple:
121
+ A tuple containing:
122
+ ("Unnatural", <ip_address>, (disable_btn,) * 4)
123
+
124
+ - "Unnatural": The selected vote or label.
125
+ - <ip_address>: The IP address of the client retrieved from the request.
126
+ - disable_btn: An object repeated four times,
127
+ to disable natural vote buttons.
128
+ """
129
+ ip_address1 = get_ip(request)
130
+ print(f"Unnatural (voted). ip: {ip_address1}")
131
+ return (
132
+ "Unnatural",
133
+ ip_address1,
134
+ ) + (disable_btn,) * 4
135
+
136
+
137
+ def relevant_vote1_last_response(request: gr.Request):
138
+ """
139
+ Handle a user vote for relevance as "Highly Relevant".
140
+
141
+
142
+ Args:
143
+ request (gr.Request):
144
+ The Gradio request object providing access to HTTP headers and metadata.
145
+
146
+ Returns:
147
+ tuple:
148
+ A tuple containing:
149
+ ("Highly Relevant", <ip_address>, (disable_btn,) * 4)
150
+
151
+ - "Highly Relevant": The selected vote or label.
152
+ - <ip_address>: The IP address of the client retrieved from the request.
153
+ - disable_btn: An object repeated four times,
154
+ to disable relevance vote buttons.
155
+ """
156
+ ip_address1 = get_ip(request)
157
+ print(f"Highly Relevant (voted). ip: {ip_address1}")
158
+ return (
159
+ "Highly Relevant",
160
+ ip_address1,
161
+ ) + (disable_btn,) * 4
162
+
163
+
164
+ def relevant_vote2_last_response(request: gr.Request):
165
+ """
166
+ Handle a user vote for relevance as "Partially Relevant".
167
+
168
+
169
+ Args:
170
+ request (gr.Request):
171
+ The Gradio request object providing access to HTTP headers and metadata.
172
+
173
+ Returns:
174
+ tuple:
175
+ A tuple containing:
176
+ ("Partially Relevant", <ip_address>, (disable_btn,) * 4)
177
+
178
+ - "Partially Relevant": The selected vote or label.
179
+ - <ip_address>: The IP address of the client retrieved from the request.
180
+ - disable_btn: An object repeated four times,
181
+ to disable relevance vote buttons.
182
+ """
183
+ ip_address1 = get_ip(request)
184
+ print(f"Partially Relevant (voted). ip: {ip_address1}")
185
+ return (
186
+ "Partially Relevant",
187
+ ip_address1,
188
+ ) + (disable_btn,) * 4
189
+
190
+
191
+ def relevant_vote3_last_response(request: gr.Request):
192
+ """
193
+ Handle a user vote for relevance as "Slightly Irrelevant".
194
+
195
+
196
+ Args:
197
+ request (gr.Request):
198
+ The Gradio request object providing access to HTTP headers and metadata.
199
+
200
+ Returns:
201
+ tuple:
202
+ A tuple containing:
203
+ ("Slightly Irrelevant", <ip_address>, (disable_btn,) * 4)
204
+
205
+ - "Slightly Irrelevant": The selected vote or label.
206
+ - <ip_address>: The IP address of the client retrieved from the request.
207
+ - disable_btn: An object repeated four times,
208
+ to disable relevance vote buttons.
209
+ """
210
+ ip_address1 = get_ip(request)
211
+ print(f"Slightly Irrelevant (voted). ip: {ip_address1}")
212
+ return (
213
+ "Slightly Irrelevant",
214
+ ip_address1,
215
+ ) + (disable_btn,) * 4
216
+
217
+
218
+ def relevant_vote4_last_response(request: gr.Request):
219
+ """
220
+ Handle a user vote for relevance as "Completely Irrelevant".
221
+
222
+
223
+ Args:
224
+ request (gr.Request):
225
+ The Gradio request object providing access to HTTP headers and metadata.
226
+
227
+ Returns:
228
+ tuple:
229
+ A tuple containing:
230
+ ("Completely Irrelevant", <ip_address>, (disable_btn,) * 4)
231
+
232
+ - "Completely Irrelevant": The selected vote or label.
233
+ - <ip_address>: The IP address of the client retrieved from the request.
234
+ - disable_btn: An object repeated four times,
235
+ to disable relevance vote buttons.
236
+ """
237
+ ip_address1 = get_ip(request)
238
+ print(f"Completely Irrelevant (voted). ip: {ip_address1}")
239
+ return (
240
+ "Completely Irrelevant",
241
+ ip_address1,
242
+ ) + (disable_btn,) * 4
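A hedged wiring sketch (not part of the commit) showing how one of these callbacks could be attached to Gradio buttons; the component names are illustrative assumptions, not taken from app.py. Each callback returns the vote label, the voter IP, and four disabled-button updates, so the outputs list must hold exactly those six components; the gr.Request argument is injected automatically by Gradio.

import gradio as gr
from pyscripts.utils.dialog_eval.human_feedback import natural_vote1_last_response

with gr.Blocks() as demo:
    natural_label = gr.Textbox(label="Naturalness vote")    # hypothetical component
    voter_ip = gr.Textbox(label="Voter IP", visible=False)  # hypothetical component
    natural_btns = [
        gr.Button(text)
        for text in ["Very Natural", "Somewhat Awkward", "Very Awkward", "Unnatural"]
    ]
    # Clicking the first button records a "Very Natural" vote and disables all four buttons.
    natural_btns[0].click(
        natural_vote1_last_response,
        outputs=[natural_label, voter_ip] + natural_btns,
    )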
pyscripts/utils/dialog_eval/vert.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import sys
8
+ import warnings
9
+ from collections import Counter
10
+ from fractions import Fraction
11
+
12
+ import nltk
13
+ import numpy as np
14
+ from nltk.translate.bleu_score import (
15
+ SmoothingFunction,
16
+ brevity_penalty,
17
+ closest_ref_length,
18
+ modified_precision,
19
+ )
20
+
21
+
22
+ def corpus_bleu(
23
+ list_of_references,
24
+ hypotheses,
25
+ weights=(0.25, 0.25, 0.25, 0.25),
26
+ smoothing_function=None,
27
+ auto_reweigh=False,
28
+ averaging_mode="geometric",
29
+ no_length_penalty=False,
30
+ ):
31
+ """
32
+ Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
33
+ the hypotheses and their respective references.
34
+
35
+ Instead of averaging the sentence-level BLEU scores (i.e. macro-average
36
+ precision), the original BLEU metric (Papineni et al. 2002) accounts for
37
+ the micro-average precision (i.e. summing the numerators and denominators
38
+ for each hypothesis-reference(s) pair before the division).
39
+
40
+ >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
41
+ ... 'ensures', 'that', 'the', 'military', 'always',
42
+ ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
43
+ >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
44
+ ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
45
+ ... 'heed', 'Party', 'commands']
46
+ >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
47
+ ... 'guarantees', 'the', 'military', 'forces', 'always',
48
+ ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party']
49
+ >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
50
+ ... 'army', 'always', 'to', 'heed', 'the', 'directions',
51
+ ... 'of', 'the', 'party']
52
+
53
+ >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
54
+ ... 'interested', 'in', 'world', 'history']
55
+ >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
56
+ ... 'because', 'he', 'read', 'the', 'book']
57
+
58
+ >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
59
+ >>> hypotheses = [hyp1, hyp2]
60
+ >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS
61
+ 0.5920...
62
+
63
+ The example below shows that corpus_bleu() is different from averaging
64
+ sentence_bleu() for hypotheses
65
+
66
+ >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
67
+ >>> score2 = sentence_bleu([ref2a], hyp2)
68
+ >>> (score1 + score2) / 2 # doctest: +ELLIPSIS
69
+ 0.6223...
70
+
71
+ :param list_of_references: a corpus of lists of reference
72
+ sentences, w.r.t. hypotheses
73
+ :type list_of_references: list(list(list(str)))
74
+ :param hypotheses: a list of hypothesis sentences
75
+ :type hypotheses: list(list(str))
76
+ :param weights: weights for unigrams, bigrams, trigrams and so on
77
+ :type weights: list(float)
78
+ :param smoothing_function:
79
+ :type smoothing_function: SmoothingFunction
80
+ :param auto_reweigh: Option to re-normalize the weights uniformly.
81
+ :type auto_reweigh: bool
82
+ :return: The corpus-level BLEU score.
83
+ :rtype: float
84
+ """
85
+ # Before proceeding to compute BLEU, perform sanity checks.
86
+
87
+ p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches.
88
+ p_denominators = Counter()  # Key = ngram order, and value = no. of ngrams in the hypothesis.
89
+ hyp_lengths, ref_lengths = 0, 0
90
+
91
+ assert len(list_of_references) == len(hypotheses), (
92
+ "The number of hypotheses and their reference(s) should be the " "same "
93
+ )
94
+
95
+ # Iterate through each hypothesis and their corresponding references.
96
+ for references, hypothesis in zip(list_of_references, hypotheses):
97
+ # For each order of ngram, calculate the numerator and
98
+ # denominator for the corpus-level modified precision.
99
+ for i, _ in enumerate(weights, start=1):
100
+ p_i = modified_precision(references, hypothesis, i)
101
+ p_numerators[i] += p_i.numerator
102
+ p_denominators[i] += p_i.denominator
103
+
104
+ # Calculate the hypothesis length and the closest reference length.
105
+ # Adds them to the corpus-level hypothesis and reference counts.
106
+ hyp_len = len(hypothesis)
107
+ hyp_lengths += hyp_len
108
+ ref_lengths += closest_ref_length(references, hyp_len)
109
+
110
+ # Calculate corpus-level brevity penalty.
111
+ if no_length_penalty and averaging_mode == "geometric":
112
+ bp = 1.0
113
+ elif no_length_penalty and averaging_mode == "arithmetic":
114
+ bp = 0.0
115
+ else:
116
+ assert not no_length_penalty
117
+ assert (
118
+ averaging_mode != "arithmetic"
119
+ ), "Not sure how to apply length penalty in arithmetic mode"
120
+ bp = brevity_penalty(ref_lengths, hyp_lengths)
121
+
122
+ # Uniformly re-weighting based on maximum hypothesis lengths if largest
123
+ # order of n-grams < 4 and weights is set at default.
124
+ if auto_reweigh:
125
+ if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
126
+ weights = (1 / hyp_lengths,) * hyp_lengths
127
+
128
+ # Collects the various precision values for the different ngram orders.
129
+ p_n = [
130
+ Fraction(p_numerators[i], p_denominators[i], _normalize=False)
131
+ for i, _ in enumerate(weights, start=1)
132
+ ]
133
+
134
+ # Returns 0 if there's no matching n-grams
135
+ # We only need to check for p_numerators[1] == 0, since if there's
136
+ # no unigrams, there won't be any higher order ngrams.
137
+ if p_numerators[1] == 0:
138
+ return 0
139
+
140
+ # If there's no smoothing, use method0 from the SmoothingFunction class.
141
+ if not smoothing_function:
142
+ smoothing_function = SmoothingFunction().method0
143
+ # Smoothen the modified precision.
144
+ # Note: smoothing_function() may convert values into floats;
145
+ # it tries to retain the Fraction object as much as the
146
+ # smoothing method allows.
147
+ p_n = smoothing_function(
148
+ p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
149
+ )
150
+
151
+ if averaging_mode == "geometric":
152
+ s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
153
+ s = bp * math.exp(math.fsum(s))
154
+ elif averaging_mode == "arithmetic":
155
+ s = (w_i * p_i for w_i, p_i in zip(weights, p_n))
156
+ s = math.fsum(s)
157
+
158
+ return s
159
+
160
+
161
+ def sentence_bleu(
162
+ references,
163
+ hypothesis,
164
+ weights=(0.25, 0.25, 0.25, 0.25),
165
+ smoothing_function=None,
166
+ auto_reweigh=False,
167
+ averaging_mode="geometric",
168
+ no_length_penalty=False,
169
+ ):
170
+ return corpus_bleu(
171
+ [references],
172
+ [hypothesis],
173
+ weights,
174
+ smoothing_function,
175
+ auto_reweigh,
176
+ averaging_mode,
177
+ no_length_penalty,
178
+ )
179
+
180
+
181
+ def get_target_sequences(manifest, ground_truth, to_take=1000):
182
+ import json
183
+ import pathlib
184
+
185
+ with open(ground_truth, "r") as fin:
186
+ original_continuations = json.loads(fin.read())
187
+
188
+ sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
189
+ assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds
190
+
191
+ sequence2length.sort(key=lambda x: x[1])
192
+ to_take_sequences = set(v[0] for v in sequence2length[:to_take])
193
+ to_take_ids = []
194
+
195
+ with open(manifest, "r") as f:
196
+ f.readline()
197
+
198
+ for i, line in enumerate(f.readlines()):
199
+ seq_id = line.split()[0]
200
+ seq_id = pathlib.Path(seq_id).name.split("__")[0]
201
+
202
+ if seq_id in to_take_sequences:
203
+ to_take_ids.append(i)
204
+
205
+ print(f"Took {len(to_take_ids)} ids")
206
+ return set(to_take_ids)
207
+
208
+
209
+ def get_self_bleu(utterances, averaging_mode, weights):
210
+ self_bleu = []
211
+
212
+ for i in range(len(utterances)):
213
+ hypo = utterances[i]
214
+ rest = utterances[:i] + utterances[i + 1 :]
215
+
216
+ self_bleu.append(
217
+ sentence_bleu(
218
+ rest,
219
+ hypo,
220
+ weights,
221
+ no_length_penalty=True,
222
+ averaging_mode=averaging_mode,
223
+ )
224
+ )
225
+
226
+ return self_bleu
227
+
228
+
229
+ def get_self_bleu2_arithmetic(utterances):
230
+ weights = (0.5, 0.5) # equal weight for unigrams and bigrams
231
+ return get_self_bleu(utterances, averaging_mode="arithmetic", weights=weights)
232
+
233
+
234
+ def get_self_bleu2_geometric(utterances):
235
+ weights = (0.5, 0.5)
236
+ return get_self_bleu(utterances, averaging_mode="geometric", weights=weights)
237
+
238
+
239
+ def get_auto_bleu2_arithmetic(utterances):
240
+ weights = (0.5, 0.5)
241
+ return [auto_bleu(u, mean_mode="arithmetic", weights=weights) for u in utterances]
242
+
243
+
244
+ def get_auto_bleu2_geometric(utterances):
245
+ weights = (0.5, 0.5)
246
+ return [auto_bleu(u, mean_mode="geometric", weights=weights) for u in utterances]
247
+
248
+
249
+ def get_auto_bleu3_geometric(utterances):
250
+ weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
251
+ return [auto_bleu(u, mean_mode="geometric", weights=weights) for u in utterances]
252
+
253
+
254
+ def get_auto_bleu3_arithmetic(utterances):
255
+ weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
256
+ return [auto_bleu(u, mean_mode="arithmetic", weights=weights) for u in utterances]
257
+
258
+
259
+ def get_self_bleu3_arithmetic(utterances):
260
+ weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
261
+ return get_self_bleu(utterances, averaging_mode="arithmetic", weights=weights)
262
+
263
+
264
+ def get_self_bleu3_geometric(utterances):
265
+ weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
266
+ return get_self_bleu(utterances, averaging_mode="geometric", weights=weights)
267
+
268
+
269
+ def auto_bleu(sentence, weights, mean_mode="arithmetic"):
270
+ if len(sentence) <= 1:
271
+ return 0
272
+
273
+ N = len(weights)
274
+
275
+ bleu_n = np.zeros([N])
276
+ for n in range(N):
277
+ targ_ngrams = list(nltk.ngrams(sentence, n + 1))
278
+ for p in range(len(targ_ngrams)):
279
+ left = sentence[:p]
280
+ right = sentence[(p + n + 1) :]
281
+ rest_ngrams = list(nltk.ngrams(left, n + 1)) + list(
282
+ nltk.ngrams(right, n + 1)
283
+ )
284
+ # compute the nb of matching ngrams
285
+ bleu_n[n] += targ_ngrams[p] in rest_ngrams
286
+ bleu_n[n] /= len(targ_ngrams) # average them to get a proportion
287
+
288
+ weights = np.array(weights)
289
+ if mean_mode == "arithmetic":
290
+ return (bleu_n * weights).sum()
291
+ elif mean_mode == "geometric":
292
+ return (bleu_n**weights).prod()
293
+ else:
294
+ raise ValueError(f"Unknown aggregation mode {mean_mode}")
295
+
296
+
297
+ def run_f(task_params):
298
+ f, terms = task_params
299
+ return f(terms)
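A short usage sketch (not part of the commit): scoring the lexical diversity of a toy set of tokenized utterances with the self-BLEU and auto-BLEU helpers above; both return per-utterance lists, so averaging is left to the caller. Note that corpus_bleu passes the private _normalize=False keyword to fractions.Fraction, which appears to have been removed in Python 3.12, so this module likely needs an older interpreter.

from pyscripts.utils.dialog_eval.vert import (
    get_auto_bleu2_geometric,
    get_self_bleu2_geometric,
)

utterances = [
    "the cat sat on the mat".split(),
    "a dog sat on the rug".split(),
    "the cat sat on the mat again".split(),
]
# Lower self-BLEU / auto-BLEU indicates more diverse, less repetitive outputs.
self_bleu2 = sum(get_self_bleu2_geometric(utterances)) / len(utterances)
auto_bleu2 = sum(get_auto_bleu2_geometric(utterances)) / len(utterances)
print(f"self-BLEU-2: {self_bleu2:.3f}  auto-BLEU-2: {auto_bleu2:.3f}")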
requirements.txt CHANGED
@@ -1 +1,18 @@
1
- huggingface_hub==0.25.2
1
+ typeguard==2.13.3
2
+ espnet @ git+https://github.com/siddhu001/espnet@sds_demo_recipe
3
+ espnet_model_zoo
4
+ huggingface_hub==0.23.2
5
+ transformers[sentencepiece]
6
+ sentencepiece
7
+ datasets
8
+ torch==2.5.1
9
+ torchaudio==2.5.1
10
+ librosa
11
+ sounddevice==0.5.0
12
+ webrtcvad-wheels
13
+ webrtcvad==2.0.10
14
+ ChatTTS
15
+ evaluate
16
+ snac==1.2.0
17
+ litgpt==0.4.3
18
+ gradio==4.43.0
temp_repo DELETED
@@ -1 +0,0 @@
1
- Subproject commit 8f631305c9bcacda1d4cfd24a9fcef74f518081e
versa.sh ADDED
@@ -0,0 +1,5 @@
1
+ git clone https://github.com/shinjiwlab/versa.git
2
+ cd versa
3
+ git checkout 64bf6fe22fbc8d43068afdf4e715864a18577735
4
+ pip install .
5
+ cd ..
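A hedged note (not part of the commit): this script pins Versa to a specific commit and presumably needs to be run once in the Space's environment before launch, since TTS_speech_quality.py imports versa lazily and raises if it is missing.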