# Page status residue captured together with the source (not code): "Spaces: Runtime error".
import torch | |
import gradio as gr | |
from transformers import pipeline | |
from transformers import T5ForConditionalGeneration, T5Tokenizer | |
import re | |
import os | |
import json | |
import requests | |
import whisper | |
from yt_dlp import YoutubeDL | |
import matplotlib as plt | |
# Checkpoint used to turn spoken-formula text into LaTeX (see the prompt
# built in the *_correction functions).
path = "Hyeonsieun/NTtoGT_7epoch"
tokenizer = T5Tokenizer.from_pretrained(path)
model = T5ForConditionalGeneration.from_pretrained(path)

# Speech-recognition settings.
MODEL_NAME = "openai/whisper-large-v2"
BATCH_SIZE = 8

# Place the ASR pipeline on the first CUDA device when one is available;
# on CPU-only hosts this is identical to the previous default placement.
device = 0 if torch.cuda.is_available() else "cpu"

pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)
def transcribe(inputs):
    """Run the ASR pipeline on ``inputs`` and return the recognised text.

    Args:
        inputs: Audio input forwarded unchanged to ``pipe``. ``None`` means
            nothing was submitted.

    Returns:
        The ``"text"`` entry of the pipeline output.

    Raises:
        gr.Error: If ``inputs`` is ``None``.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    asr_output = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=True,
    )
    return asr_output["text"]
def remove_spaces_within_dollar(text):
    """Delete space characters inside every ``$...$`` span of ``text``.

    Spans are matched non-greedily, so each one runs from a ``$`` to the
    nearest following ``$`` on the same line. Text outside such spans is
    returned untouched.

    Args:
        text: String that may contain dollar-delimited segments.

    Returns:
        ``text`` with the spaces removed from each matched segment.
    """
    def _squeeze(found):
        # Only the matched segment (delimiters included) is rewritten.
        return found.group(0).replace(' ', '')

    return re.sub(r'\$(.*?)\$', _squeeze, text)
def audio_correction(file):
    """Transcribe an audio file and rewrite spoken formulas as LaTeX.

    The transcript is split into sentences, each sentence is passed through
    the T5 model with the formula-to-LaTeX prompt, and the corrected
    sentences are joined with single spaces.

    Args:
        file: Audio input forwarded to ``transcribe``.

    Returns:
        The corrected text, with spaces removed inside ``$...$`` segments.

    Raises:
        gr.Error: Propagated from ``transcribe`` when ``file`` is ``None``.
    """
    ASR_result = transcribe(file)

    # Decoding keeps special tokens, so the pad/eos markers have to be removed
    # from every sentence. Slicing the concatenated string once only removed
    # the first "<pad>" and the last "</s>" and left the others in the output.
    markers = [tok for tok in (tokenizer.pad_token, tokenizer.eos_token) if tok]

    corrected_sentences = []
    for sentence in split_text_complex_rules_with_warning(ASR_result):
        if not sentence.strip():
            # Nothing to correct; avoid running the model on the bare prompt.
            continue
        input_text = f"translate the text pronouncing the formula to a LaTeX equation: {sentence}"
        inputs = tokenizer.encode(
            input_text,
            return_tensors='pt',
            max_length=325,
            padding='max_length',
            truncation=True,
        )
        # Beam search over 5 beams.
        corrected_ids = model.generate(
            inputs,
            max_length=325,
            num_beams=5,
            early_stopping=True,
        )
        # NOTE(review): skip_special_tokens stays False as in the original,
        # presumably so tokens added to this checkpoint are not dropped; confirm.
        decoded = tokenizer.decode(
            corrected_ids[0],
            skip_special_tokens=False,
        )
        for marker in markers:
            decoded = decoded.replace(marker, '')
        decoded = decoded.strip()
        if decoded:
            corrected_sentences.append(decoded)

    return remove_spaces_within_dollar(' '.join(corrected_sentences))
def youtubeASR(link):
    """Download the audio of a YouTube video and return its transcript.

    The audio is saved into a fresh temporary directory that is removed when
    the function returns, so a file left by an earlier request can never be
    transcribed in place of the requested one.

    Args:
        link: URL of the video to transcribe.

    Returns:
        The ``"text"`` entry of the ASR pipeline output.

    Raises:
        gr.Error: If ``link`` is empty or blank.
    """
    import tempfile

    if not link or not link.strip():
        raise gr.Error("No YouTube link submitted! Please paste a video URL before submitting your request.")

    with tempfile.TemporaryDirectory() as work_dir:
        # Temporary file that receives the downloaded audio track.
        out_fn = os.path.join(work_dir, 'temp1.mp3')
        ydl_opts = {
            'format': 'bestaudio/best',  # audio-only stream when available
            'outtmpl': out_fn,           # save under the chosen file name
        }
        with YoutubeDL(ydl_opts) as ydl:
            ydl.download([link])

        # Transcribe while the temporary file still exists.
        result = pipe(
            out_fn,
            batch_size=BATCH_SIZE,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=True,
        )

    # The pipeline output is a mapping; its "text" entry is the transcript.
    # The original indexed ["text"] twice, which fails on the resulting str.
    return result["text"]
def split_text_complex_rules_with_warning(text, max_length=256):
    """Split ``text`` into sentence-sized pieces of at most ``max_length`` characters.

    The text is first split after ``.``, ``?`` or ``!`` followed by
    whitespace. A sentence longer than ``max_length`` is split again at
    commas (the commas themselves are not kept). Fragments that are still
    too long are left out of the result and reported through the module
    logger. Empty fragments are skipped silently.

    Args:
        text: Text to split.
        max_length: Maximum number of characters allowed per returned piece.

    Returns:
        List of stripped, non-empty pieces in their original order.
    """
    import logging

    logger = logging.getLogger(__name__)
    result = []

    # Sentence boundaries: terminal punctuation (commas excluded) plus whitespace.
    for part in re.split(r'(?<=[.?!])\s+', text):
        part = part.strip()
        if not part:
            continue
        if len(part) <= max_length:
            result.append(part)
            continue

        # Over-long sentence: fall back to splitting at commas.
        for subpart in re.split(r',\s*', part):
            subpart = subpart.strip()
            if not subpart:
                continue
            if len(subpart) <= max_length:
                result.append(subpart)
            else:
                # Previously these messages were collected and then thrown away.
                logger.warning(
                    "Dropping fragment longer than %d characters: %s... (length: %d)",
                    max_length,
                    subpart[:50],
                    len(subpart),
                )
    return result
def youtube_correction(link):
    """Transcribe a YouTube video and rewrite spoken formulas as LaTeX.

    The transcript is split into sentences, each sentence is passed through
    the T5 model with the formula-to-LaTeX prompt, and the corrected
    sentences are joined with single spaces.

    Args:
        link: Video URL forwarded to ``youtubeASR``.

    Returns:
        The corrected text, with spaces removed inside ``$...$`` segments.
    """
    ASR_result = youtubeASR(link)

    # Decoding keeps special tokens, so the pad/eos markers have to be removed
    # from every sentence. Slicing the concatenated string once only removed
    # the first "<pad>" and the last "</s>" and left the others in the output.
    markers = [tok for tok in (tokenizer.pad_token, tokenizer.eos_token) if tok]

    corrected_sentences = []
    for sentence in split_text_complex_rules_with_warning(ASR_result):
        if not sentence.strip():
            # Nothing to correct; avoid running the model on the bare prompt.
            continue
        input_text = f"translate the text pronouncing the formula to a LaTeX equation: {sentence}"
        inputs = tokenizer.encode(
            input_text,
            return_tensors='pt',
            max_length=325,
            padding='max_length',
            truncation=True,
        )
        # Beam search over 5 beams.
        corrected_ids = model.generate(
            inputs,
            max_length=325,
            num_beams=5,
            early_stopping=True,
        )
        # NOTE(review): skip_special_tokens stays False as in the original,
        # presumably so tokens added to this checkpoint are not dropped; confirm.
        decoded = tokenizer.decode(
            corrected_ids[0],
            skip_special_tokens=False,
        )
        for marker in markers:
            decoded = decoded.replace(marker, '')
        decoded = decoded.strip()
        if decoded:
            corrected_sentences.append(decoded)

    return remove_spaces_within_dollar(' '.join(corrected_sentences))
# One interface per input source; both return plain text.
file_transcribe = gr.Interface(
    fn=audio_correction,
    inputs=gr.components.Audio(sources="upload", type="filepath"),
    outputs="text",
)
yt_transcribe = gr.Interface(
    fn=youtube_correction,
    inputs="text",
    outputs="text",
)

# Show the two interfaces as tabs of a single app, then start it.
with gr.Blocks() as demo:
    gr.TabbedInterface(
        [file_transcribe, yt_transcribe],
        ["Audio file", "YouTube"],
    )

demo.launch()