# Gradio-Transart / app.py
# Hugging Face Space by Raveheart1 — commit 52ce338 ("Update app.py"), 2.87 kB.
import requests
import io
from PIL import Image, UnidentifiedImageError
import gradio as gr
from transformers import MarianMTModel, MarianTokenizer
import os
# Multilingual-to-English MarianMT checkpoint. The tokenizer and model are
# module-level globals consumed by translate_text().
model_name = "Helsinki-NLP/opus-mt-mul-en"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
def translate_text(input_text, language=None):
    """Translate ``input_text`` into English with the mul->en MarianMT model.

    Args:
        input_text: Text in one of the languages offered by the UI.
        language: Display name of the source language chosen in the UI
            (e.g. "Tamil", "French", "Hindi", "German", or None when nothing
            was selected). Informational only; kept for call-site
            compatibility. No ``>>xx<<`` language token is added, because
            ``opus-mt-mul-en`` has a single target language (English). Per
            the MarianMT / Helsinki-NLP model-card docs, those tokens are
            only needed for models with several target languages.

    Returns:
        The decoded English translation, or ``""`` for empty/blank input.
    """
    # Nothing to translate: avoid running the model on empty text.
    if not input_text or not input_text.strip():
        return ""
    # truncation=True keeps over-long inputs within the tokenizer's
    # model_max_length instead of failing inside generate().
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
    translated_tokens = model.generate(**inputs)
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
def query_gemini_api(translated_text, gemini_api_key, timeout=60):
    """Ask Gemini to continue a story that begins with ``translated_text``.

    Args:
        translated_text: English sentence used as the opening of the story.
        gemini_api_key: Google Generative Language API key. When empty or
            None, no request is sent.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The generated continuation text, or a string starting with
        ``"Error:"``. The error string describes a missing key, a network
        failure, a non-200 status, or a response without candidate text.
    """
    if not gemini_api_key:
        return "Error: GEMINI_API_KEY is not set"
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent"
    # Pass the key in the x-goog-api-key header rather than a ?key= query
    # parameter, so it does not appear in URLs echoed by exception messages.
    headers = {"Content-Type": "application/json", "x-goog-api-key": gemini_api_key}
    prompt = f"Based on the following sentence, continue the story: {translated_text}"
    payload = {
        "contents": [{"parts": [{"text": prompt}]}]
    }
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        return f"Error: request failed - {exc}"
    if response.status_code != 200:
        return f"Error: {response.status_code} - {response.text}"
    try:
        result = response.json()
        return result['candidates'][0]['content']['parts'][0]['text']
    except (ValueError, KeyError, IndexError, TypeError):
        # Non-JSON body, or no candidates/parts in the reply (presumably a
        # blocked prompt — verify against the API's finishReason).
        return f"Error: unexpected response - {response.text}"
def query_image(payload, timeout=120):
    """POST ``payload`` to the Hugging Face Inference API for FLUX.1-dev.

    Args:
        payload: JSON request body, e.g. ``{"inputs": "<prompt>"}``.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The raw response body as bytes: image bytes on success, or whatever
        body the server sent on an HTTP error. Returns ``b""`` if the
        request itself fails (connection error, timeout). Callers are
        expected to treat undecodable bytes as "no image".
    """
    huggingface_api_key = os.getenv('HUGGINGFACE_API_KEY')
    API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
    # Only send Authorization when a key is configured (avoids "Bearer None").
    headers = {"Authorization": f"Bearer {huggingface_api_key}"} if huggingface_api_key else {}
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException:
        return b""
    return response.content
def process_input(tamil_input, language):
    """Run the pipeline: translate, continue the story, generate an image.

    Args:
        tamil_input: Text entered by the user (any UI-supported language,
            despite the parameter name).
        language: Language name from the dropdown, forwarded to
            ``translate_text``.

    Returns:
        ``(translated_text, creative_text, image)``. ``image`` is a PIL image,
        or None when the image endpoint did not return decodable image data.
        Blank input returns ``("", "", None)`` without calling anything.
    """
    # Blank input: skip the translation model and both remote API calls.
    if not tamil_input or not tamil_input.strip():
        return "", "", None
    gemini_api_key = os.getenv('GEMINI_API_KEY')
    translated_output = translate_text(tamil_input, language)
    creative_output = query_gemini_api(translated_output, gemini_api_key)
    image_bytes = query_image({"inputs": translated_output})
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so truncated/corrupt data is caught
        # here instead of failing later when the output is rendered.
        image.load()
    except (UnidentifiedImageError, OSError):  # UnidentifiedImageError subclasses OSError
        image = None
    return translated_output, creative_output, image
# Gradio interface setup: a text box and a language dropdown feed
# process_input(), whose three return values fill the three outputs below.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Textbox(label="Input Text"),
        # Preselect a language so the dropdown never submits None.
        gr.Dropdown(
            label="Select Language",
            choices=["Tamil", "French", "Hindi", "German"],
            value="Tamil",
        ),
    ],
    outputs=[
        gr.Textbox(label="Translated Text"),
        gr.Textbox(label="Creative Text"),
        gr.Image(label="Generated Image"),
    ],
    title="TRANSART🎨 BY Sakthi",
    description="Enter text to translate into English and generate an image based on the translated text.",
)

# Start the server only when run as a script (`python app.py`), so importing
# this module does not launch the UI as a side effect.
if __name__ == "__main__":
    iface.launch()