"""Gradio app that validates a paired tw/zh SRT file and reports a parsing log."""

import io
import sys
from dataclasses import dataclass
from datetime import timedelta

import gradio as gr
import jiwer
import srt
from dataclasses_json import dataclass_json


@dataclass_json
@dataclass
class ZHTW_Sub:
    """One cue holding the two text halves extracted from a subtitle line."""

    start: timedelta
    end: timedelta
    zh: str  # second half of the cue text (after '|' or the midpoint)
    tw: str  # first half of the cue text (before '|' or the midpoint)


def read_srt(p):
    """Parse the SRT file at path ``p`` and return its cues as a list.

    The file is decoded as UTF-8 (a leading BOM is tolerated) instead of the
    platform default encoding, so results do not depend on the host locale.
    """
    with open(p, encoding="utf-8-sig") as f:
        return list(srt.parse(f.read()))


def _merge_adjacent(subs, should_merge):
    """Fold each cue into its predecessor whenever ``should_merge(prev, cur)``.

    The list is rewritten in place (and returned) so callers that rely on the
    mutation keep working; merged cues extend the predecessor's ``end`` and
    append their ``zh`` / ``tw`` text separated by a space.
    """
    merged = []
    for cur in subs:
        if merged and should_merge(merged[-1], cur):
            prev = merged[-1]
            prev.end = cur.end
            prev.zh += f" {cur.zh}"
            prev.tw += f" {cur.tw}"
        else:
            merged.append(cur)
    subs[:] = merged
    return subs


def merge_sub(subs):
    """Merge cues that start exactly when the previous cue ends.

    Args:
        subs: list of ``ZHTW_Sub``; modified in place.

    Returns:
        The same list, with back-to-back cues merged.
    """
    return _merge_adjacent(subs, lambda prev, cur: prev.end == cur.start)


def merge_sub2(subs, delta):
    """Merge cues whose gap to the previous cue is at most ``delta``.

    Args:
        subs: list of ``ZHTW_Sub``; modified in place.
        delta: largest ``cur.start - prev.end`` (a ``timedelta``) that still
            merges.

    Returns:
        The same list, with close cues merged.
    """
    return _merge_adjacent(subs, lambda prev, cur: cur.start - prev.end <= delta)


def filter_sub(subs):
    """Split every cue into its tw / zh halves and log the ones that need review.

    A cue is expected to hold ``tw | zh``, or an even number of
    whitespace-separated tokens whose first half is tw and second half is zh.
    Cues are skipped (and logged) when they contain ``#``, contain a line
    break, have a number of ``|``-separated fields other than two, or have an
    odd token count. After a cue with a line break is skipped, the next
    accepted cue is appended to the previously accepted one.

    Args:
        subs: iterable of parsed SRT cues exposing ``start``, ``end`` and
            ``content``.

    Returns:
        ``(new_subs, buffer)`` where ``new_subs`` is a list of ``ZHTW_Sub``
        and ``buffer`` is an ``io.StringIO`` holding the log text.
    """
    buffer = io.StringIO()

    def log(*args):
        # Write straight into the buffer instead of swapping sys.stdout, so an
        # exception cannot leave stdout redirected and concurrent requests do
        # not see each other's output.
        print(*args, file=buffer)

    new_subs = []
    carry_next = False
    for s in subs:
        content = s.content
        if '#' in content:
            log('註:標記', s.start, s.end, content)
            continue
        if '\n' in content:
            log('修:分行', '\\n', s.start, content)
            carry_next = True
            continue

        if '|' in content:
            parts = content.split('|')
            # Exactly two fields are required; any other count cannot be
            # unpacked into a (tw, zh) pair.
            if len(parts) != 2:
                log('修:多槓', parts)
                continue
            tw, zh = (part.strip() for part in parts)
        else:
            words = content.split()
            if len(words) % 2 != 0:
                log('修:不均', s.start, s.end, words)
                continue
            mid = len(words) // 2
            tw, zh = ' '.join(words[:mid]), ' '.join(words[mid:])

        # Informational only: the pair is still kept.
        if jiwer.cer(tw, zh) > 1:
            log('註:差距', s.start, s.end, 'tw:', tw, 'zh:', zh)

        if carry_next and new_subs:
            prev = new_subs[-1]
            prev.zh += f" {zh}"
            prev.tw += f" {tw}"
            prev.end = s.end
        else:
            # Also reached when a carry was requested but nothing has been
            # accepted yet, so there is no predecessor to extend.
            new_subs.append(ZHTW_Sub(s.start, s.end, zh, tw))
        carry_next = False

    return new_subs, buffer


def update_yield():
    """Return a function that accumulates log lines and returns the full log.

    Each call of the returned function appends one line, built from the
    space-joined ``str()`` of its arguments, and returns all lines so far
    joined by newlines.
    """
    lines = []

    def update_print(*inp):
        lines.append(' '.join(str(item) for item in inp))
        return '\n'.join(lines)

    return update_print


def parse_srt(file):
    """Validate an uploaded SRT file, yielding the growing log after each step.

    Args:
        file: the value delivered by the ``gr.File`` input, or ``None`` when
            nothing is uploaded. Either an object with a ``name`` attribute
            holding the path, or the path itself.

    Yields:
        The cumulative log text: file path, cue counts before and after
        filtering and merging, the filter log, every merged cue longer than
        30 seconds, and finally the total usable duration.
    """
    if file is None:
        # This function is a generator, so the message must be yielded to be
        # delivered; a returned value would be discarded.
        yield "No file uploaded."
        return

    path = getattr(file, "name", file)
    upd = update_yield()
    yield upd(path)

    subs = read_srt(path)
    yield upd(len(subs))

    new_subs, logs = filter_sub(subs)
    yield upd(logs.getvalue())
    yield upd(len(new_subs))

    new_subs = merge_sub(new_subs)
    yield upd(len(new_subs))

    total_dur = 0
    for i, it in enumerate(new_subs):
        duration = (it.end - it.start).total_seconds()
        if duration > 30:
            yield upd(i)
            yield upd(it.end.total_seconds(), duration, it.tw)
        total_dur += duration
    yield upd("可用時長 " + str(timedelta(seconds=int(total_dur))))


with gr.Blocks() as demo:
    gr.Markdown("## SRT File Validator")
    with gr.Column():
        file_input = gr.File(label="Upload .srt File", file_types=[".srt"])
        output_log = gr.Textbox(label="Parsing Log", lines=10, max_lines=120)
    file_input.change(fn=parse_srt, inputs=file_input, outputs=output_log)

demo.launch()