import os
import gc
import gradio as gr
import torch
import random
import logging
import subprocess
import sys
import uuid
from huggingface_hub import login, HfApi, snapshot_download
import spacy
import pkg_resources

login(token=os.environ.get("LA_NAME"))
os.environ["LASER"] = "laser"


def check_and_install(package, required_version):
    """Install (or force-reinstall) a package if its pinned version is not present."""
    try:
        dist = pkg_resources.get_distribution(package)
        installed_version = dist.version
        if installed_version != required_version:
            print(f"[{package}] installed version {installed_version} does not match required {required_version}; re-installing...")
            subprocess.check_call([sys.executable, "-m", "pip", "install",
                                   f"{package}=={required_version}", "--force-reinstall"])
        else:
            print(f"[{package}] required version {required_version} already satisfied")
    except pkg_resources.DistributionNotFound:
        print(f"[{package}] not found, installing {required_version}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install",
                               f"{package}=={required_version}"])


packages = {
    "pip": "24.0",
    "fairseq": "0.12.2",
    "torch": "2.6.0",
    "transformers": "4.51.3"
}

for package, version in packages.items():
    check_and_install(package, version)

# These imports must come after check_and_install so the pinned versions are in place.
from transformers import AutoTokenizer, AutoModelForCausalLM
from vecalign.plan2align import translate_text, external_find_best_translation
from trl import AutoModelForCausalLMWithValueHead

# Download any missing spaCy models used for sentence segmentation.
spacy_models = ["en_core_web_sm", "ru_core_news_sm", "de_core_news_sm",
                "ja_core_news_sm", "ko_core_news_sm", "es_core_news_sm",
                "zh_core_web_sm"]
for spacy_model in spacy_models:
    try:
        spacy.load(spacy_model)
    except OSError:
        from spacy.cli import download
        download(spacy_model)

subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])

# ---------- Translation setup ----------

# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load models once
print("Loading models...")
model_id = "google/gemma-2-9b-it"  # "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)

lang_map = {
    "English": ("en", "en_core_web_sm"),
    "Russian": ("ru", "ru_core_news_sm"),
    "German": ("de", "de_core_news_sm"),
    "Japanese": ("ja", "ja_core_news_sm"),
    "Korean": ("ko", "ko_core_news_sm"),
    "Spanish": ("es", "es_core_news_sm"),
    "Simplified Chinese": ("zh", "zh_core_web_sm"),
    "Traditional Chinese": ("zh", "zh_core_web_sm")
}

def get_lang_and_nlp(language):
    if language not in lang_map:
        raise ValueError(f"Unsupported language: {language}")
    lang_code, model_name = lang_map[language]
    return lang_code, spacy.load(model_name)

def segment_sentences_by_punctuation(text, src_nlp):
    """Split text into paragraphs, then into sentences via the spaCy pipeline."""
    segmented_sentences = []
    paragraphs = text.split('\n')
    for paragraph in paragraphs:
        if paragraph.strip():
            doc = src_nlp(paragraph)
            for sent in doc.sents:
                segmented_sentences.append(sent.text.strip())
    return segmented_sentences

def generate_translation(system_prompt, prompt):
    full_prompt = f"System: {system_prompt}\nUser: {prompt}\nAssistant:"
    # With device_map="auto", inputs should follow the model's device rather than a fixed one.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    # Decode only the newly generated tokens, skipping the prompt.
    translation = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return translation

def check_token_length(text, max_tokens=1024):
    # Character-count guard used as a cheap proxy for token length.
    return len(text) <= max_tokens

def get_user_session(state=None):
    """Return a stable per-session id, minting one in the Gradio state dict if needed."""
    if state is None or not isinstance(state, dict):
        state = {}
    if not state.get("session_id"):
        state["session_id"] = uuid.uuid4().hex
    return state["session_id"]
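# Illustrative use of the session helper with the Gradio state dict (hypothetical values):
#   state = {}
#   sid = get_user_session(state)          # mints a new hex session id
#   assert get_user_session(state) == sid  # later calls reuse the cached id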
# ---------- Translation functions ----------

def mpc_initial_translate(source_sentence, src_language, tgt_language):
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Translate this from {src_language} to {tgt_language} and only output the result."
        prompt += f"\n### {src_language}:\n {source_sentence}"
        prompt += f"\n### {tgt_language}:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    print("mpc_initial_translate")
    print(translations)
    return translations

def mpc_improved_translate(source_sentence, current_translation, src_language, tgt_language):
    system_prompts = [
        "You are a meticulous translator. Please improve the following translation by ensuring it is a literal and structurally precise version.",
        "You are a professional translator. Please refine the provided translation to be clear, formal, and accurate.",
        "You are a creative translator. Please enhance the translation so that it is vivid, natural, and engaging."
    ]
    translations = []
    for prompt_style in system_prompts:
        prompt = (f"Source ({src_language}): {source_sentence}\n"
                  f"Current Translation ({tgt_language}): {current_translation}\n"
                  f"Please provide an improved translation into {tgt_language} and only output the result:")
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    print("mpc_improved_translate")
    print(translations)
    return translations

def basic_translate(source_sentence, src_language, tgt_language):
    system_prompts = ["You are a helpful translator and only output the result."]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Translate this from {src_language} to {tgt_language}."
        prompt += f"\n### {src_language}:\n {source_sentence}"
        prompt += f"\n### {tgt_language}:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    return translations

def summary_translate(src_text, temp_tgt_text, tgt_language, session_id):
    """Rephrase a concatenated chunk translation into a single coherent text, then score it."""
    if len(temp_tgt_text.strip()) == 0:
        return "", 0
    system_prompts = ["You are a helpful rephraser. You only output the rephrased result."]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Rephrase the following in {tgt_language}."
        prompt += f"\n### Input:\n {temp_tgt_text}"
        prompt += f"\n### Rephrased:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    try:
        _, score = evaluate_candidates(src_text, translations, tgt_language, session_id)
    except Exception:
        score = 0
    return translations[0], score
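# For reference, the generators above assemble ###-delimited prompts; basic_translate's,
# with illustrative values, looks like:
#   ### Translate this from Traditional Chinese to English.
#   ### Traditional Chinese:
#    <source text>
#   ### English: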
def plan2align_translate_text(text, session_id, model, tokenizer, device, src_language, task_language,
                              max_iterations_value, threshold_value, good_ref_contexts_num_value,
                              reward_model_type):
    result = translate_text(
        text=text,
        model=model,
        tokenizer=tokenizer,
        device=device,
        src_language=src_language,
        task_language=task_language,
        max_iterations_value=max_iterations_value,
        threshold_value=threshold_value,
        good_ref_contexts_num_value=good_ref_contexts_num_value,
        reward_model_type=reward_model_type,
        session_id=session_id
    )
    try:
        _, score = evaluate_candidates(text, [result], task_language, session_id)
    except Exception:
        score = 0
    return result, score

def evaluate_candidates(source, candidates, language, session_id):
    evals = [(source, candidates)]
    best_translations = external_find_best_translation(evals, language, session_id)
    best_candidate, best_score = best_translations[0]
    return best_candidate, best_score

def original_translation(text, src_language, target_language, session_id):
    cand_list = basic_translate(text, src_language, target_language)
    # Guard against an empty candidate list before scoring.
    if not cand_list:
        return "", 0
    best, score = evaluate_candidates(text, cand_list, target_language, session_id)
    return best, score

def best_of_n_translation(text, src_language, target_language, n, session_id):
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0
    candidates = []
    for i in range(n):
        cand_list = basic_translate(text, src_language, target_language)
        if cand_list:
            candidates.append(cand_list[0])
    try:
        best, score = evaluate_candidates(text, candidates, target_language, session_id)
        print("best_of_n evaluate_candidates results:")
        print(best, score)
    except Exception:
        print("evaluate_candidates fail")
        return "Warning: Evaluation failed.", 0
    return best, score

def mpc_translation(text, src_language, target_language, iterations, session_id):
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0
    current_trans = ""
    best_score = None
    for i in range(iterations):
        if i == 0:
            cand_list = mpc_initial_translate(text, src_language, target_language)
        else:
            cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
        try:
            best, score = evaluate_candidates(text, cand_list, target_language, session_id)
            print("mpc evaluate_candidates results:")
            print(best, score)
            current_trans = best
            best_score = score
        except Exception:
            print("evaluate_candidates fail")
            current_trans = cand_list[0]
            best_score = 0
    return current_trans, best_score
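# Sketch of the evaluation flow shared by the search strategies above (illustrative values;
# scoring is delegated to external_find_best_translation and the configured reward model):
#   best, score = evaluate_candidates(source_text, ["candidate A", "candidate B"],
#                                     "English", session_id)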
# ---------- Gradio function ----------

def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                 good_ref_contexts_num_value, translation_methods=None, chunk_size=-1, state=None):
    """
    Given a source text and a target language, produce up to four translations in order:
    1. Original translation
    2. Plan2Align translation
    3. Best-of-N translation
    4. MPC translation
    """
    translation_methods = translation_methods or ["Original", "Plan2Align"]
    session_id = get_user_session(state)
    orig_output = ""
    plan2align_output = ""
    best_of_n_output = ""
    mpc_output = ""
    src_lang, src_nlp = get_lang_and_nlp(src_language)
    source_sentence = text.replace("\n", " ")
    source_segments = segment_sentences_by_punctuation(source_sentence, src_nlp)
    if chunk_size == -1:
        # Translate the full text in one pass.
        if "Original" in translation_methods:
            orig, best_score = original_translation(text, src_language, target_language, session_id)
            orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        if "Plan2Align" in translation_methods:
            plan2align_trans, best_score = plan2align_translate_text(
                text, session_id, model, tokenizer, device,
                src_language, target_language,
                max_iterations_value, threshold_value, good_ref_contexts_num_value,
                "metricx"
            )
            plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        if "Best-of-N" in translation_methods:
            best_candidate, best_score = best_of_n_translation(text, src_language, target_language,
                                                               max_iterations_value, session_id)
            best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        if "MPC" in translation_methods:
            mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
                                                       max_iterations_value, session_id)
            mpc_output = f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
    else:
        # Translate chunk by chunk, then rephrase the concatenated result once per method.
        chunks = [' '.join(source_segments[i:i + chunk_size]) for i in range(0, len(source_segments), chunk_size)]
        org_translated_chunks = []
        p2a_translated_chunks = []
        bfn_translated_chunks = []
        mpc_translated_chunks = []
        for chunk in chunks:
            if "Original" in translation_methods:
                translation, _ = original_translation(chunk, src_language, target_language, session_id)
                org_translated_chunks.append(translation)
            if "Plan2Align" in translation_methods:
                translation, _ = plan2align_translate_text(
                    chunk, session_id, model, tokenizer, device,
                    src_language, target_language,
                    max_iterations_value, threshold_value, good_ref_contexts_num_value,
                    "metricx"
                )
                p2a_translated_chunks.append(translation)
            if "Best-of-N" in translation_methods:
                translation, _ = best_of_n_translation(chunk, src_language, target_language,
                                                       max_iterations_value, session_id)
                bfn_translated_chunks.append(translation)
            if "MPC" in translation_methods:
                translation, _ = mpc_translation(chunk, src_language, target_language,
                                                 max_iterations_value, session_id)
                mpc_translated_chunks.append(translation)
        org_combined_translation = ' '.join(org_translated_chunks)
        p2a_combined_translation = ' '.join(p2a_translated_chunks)
        bfn_combined_translation = ' '.join(bfn_translated_chunks)
        mpc_combined_translation = ' '.join(mpc_translated_chunks)
        orig, best_score = summary_translate(text, org_combined_translation, target_language, session_id)
        orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        plan2align_trans, best_score = summary_translate(text, p2a_combined_translation, target_language, session_id)
        plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        best_candidate, best_score = summary_translate(text, bfn_combined_translation, target_language, session_id)
        best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        mpc_candidate, best_score = summary_translate(text, mpc_combined_translation, target_language, session_id)
        mpc_output = f"{mpc_candidate}\n\nScore: {best_score:.2f}"
    return orig_output, plan2align_output, best_of_n_output, mpc_output
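# Chunking example (illustrative): with chunk_size=2, a five-sentence source yields
# ["s1 s2", "s3 s4", "s5"]; each chunk is translated independently, the pieces are
# joined with spaces, and summary_translate rephrases the concatenation once.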
# ---------- Gradio UI ----------

target_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]
src_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]

with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    state = gr.State({})
    gr.Markdown("# Translation Demo: Multiple Translation Methods")
    gr.Markdown("Select the translation methods to run (any subset, or all):")
    with gr.Row():
        with gr.Column(scale=1):
            source_text = gr.Textbox(
                label="Source Text",
                placeholder="Enter text...",
                lines=5
            )
            src_language_input = gr.Dropdown(
                choices=src_languages,
                value="Traditional Chinese",
                label="Source Language"
            )
            task_language_input = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language"
            )
            max_iterations_input = gr.Number(label="Max Iterations", value=6)
            threshold_input = gr.Number(label="Threshold", value=0.7)
            good_ref_contexts_num_input = gr.Number(label="Good Ref Contexts Num", value=5)
            translation_methods_input = gr.CheckboxGroup(
                choices=["Original", "Plan2Align", "Best-of-N", "MPC"],
                value=["Original", "Plan2Align"],
                label="Translation Methods"
            )
            chunk_size_input = gr.Number(  # chunk-size control for long inputs
                label="Chunk Size (-1 for all)",
                value=-1
            )
            translate_button = gr.Button("Translate")
        with gr.Column(scale=2):
            original_output = gr.Textbox(label="Original Translation", lines=5, interactive=False)
            plan2align_output = gr.Textbox(label="Plan2Align Translation", lines=5, interactive=False)
            best_of_n_output = gr.Textbox(label="Best-of-N Translation", lines=5, interactive=False)
            mpc_output = gr.Textbox(label="MPC Translation", lines=5, interactive=False)

    translate_button.click(
        fn=process_text,
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            chunk_size_input,
            state
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output]
    )

    gr.Examples(
        examples=[
            ["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。",
             "Traditional Chinese", "English", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。",
             "Traditional Chinese", "Japanese", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。",
             "Traditional Chinese", "Korean", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            # ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Traditional Chinese", "Japanese", 3, 0.7, 3, ["Original", "Plan2Align"], -1],
            # ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Traditional Chinese", "Korean", 5, 0.7, 5, ["Original", "Plan2Align"], -1]
        ],
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            chunk_size_input
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output],
        fn=process_text
    )

    gr.Markdown("## How It Works")
    gr.Markdown("""
    1. **Original Translation:** Generates candidates with a fixed prompt and takes the first candidate as the baseline translation.
    2. **Plan2Align Translation:** Translates with context alignment and a self-rewriting strategy; suited to long texts.
    3. **Best-of-N Translation:** Samples multiple candidates and picks the highest-scoring one; suited to short texts.
    4. **MPC Translation:** Improves the translation iteratively, scoring each round's candidates and feeding the best one into the next round; suited to short texts.

    If the input text exceeds the length limit (4096 characters), the Best-of-N and MPC methods return a warning message.
    """)

if __name__ == "__main__":
    demo.launch(share=True)
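# To run locally (assuming this file is saved as app.py): `python app.py`.
# Set the LA_NAME environment variable to a Hugging Face token before launching.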