# Hugging Face Space by joshuaberkowitzus
# Commit f5c511b (verified): "backup versioned"
# app.py for Hugging Face Space
# Make sure to add 'gradio', 'transformers', and 'torch' (or 'tensorflow'/'flax')
# to your requirements.txt file in the Hugging Face Space repository.
# Gated model: the weights can only be downloaded after access has been granted
# on the Hub, so provide an access token (e.g. via the HF_TOKEN Space secret).
from huggingface_hub import login
# app.py for Hugging Face Space
# Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'),
# and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.
import gradio as gr
import torch # Or tensorflow/flax depending on backend
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download # Import hub download function
import json # Import json library
import os # Import os library for path joining
# --- HF login ---
# The model repository is gated, so authenticate with the HF_TOKEN secret when
# one is configured. Calling login() without a token falls back to an
# interactive prompt, which cannot be answered in a headless Space, so the call
# is skipped instead; downloading gated files will then presumably fail with an
# authorization error that points at the missing token.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF_TOKEN is not set; skipping Hugging Face login. Downloading gated models may fail.")
# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
PROMPT_FILENAME = "tdc_prompts.json"  # Prompt-template file fetched from the model repo
MODEL_CACHE = "model_cache"  # Local cache directory for downloaded files
MAX_EXAMPLES = 100  # Limit the number of examples loaded from the JSON
EXAMPLE_SMILES = "C1=CC=CC=C1"  # Default SMILES for examples (Benzene)
# --- Load Model, Tokenizer, and Prompts ---
print(f"Loading model: {MODEL_NAME}...")
tdc_prompts_data = None  # Parsed contents of PROMPT_FILENAME; stays None until loaded
examples_list = []  # Rows of [prompt, max_new_tokens, temperature] for gr.Examples
try:
    # Reported for information only: weight placement is decided by
    # device_map="auto" below, and predict() sends inputs to model.device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE)
    print("Tokenizer loaded.")

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=MODEL_CACHE,
        device_map="auto",  # Automatically distribute model across available devices (GPU/CPU)
    )
    print("Model loaded.")

    # The prompt templates are downloaded from the model repository itself.
    # Pass force_download=True to hf_hub_download to refresh a stale cached copy.
    print(f"Downloading {PROMPT_FILENAME}...")
    prompts_file_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename=PROMPT_FILENAME,
        cache_dir=MODEL_CACHE,
    )
    print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")

    # Explicit encoding: JSON is UTF-8, but the platform default encoding may not be.
    with open(prompts_file_path, "r", encoding="utf-8") as f:
        tdc_prompts_data = json.load(f)
    print(f"Loaded prompts data from {PROMPT_FILENAME}.")

    # --- Prepare examples for Gradio ---
    # The JSON is expected to be a dict whose values are prompt-template
    # strings; the keys are not used here.
    if isinstance(tdc_prompts_data, dict):
        print(f"Processing {len(tdc_prompts_data)} prompts from dictionary...")
        for prompt_template in tdc_prompts_data.values():
            if len(examples_list) >= MAX_EXAMPLES:
                break
            if not isinstance(prompt_template, str):
                print(f"Warning: Skipping non-string value in prompts dictionary: {prompt_template}")
                continue
            # Only the "{Drug SMILES}" placeholder is filled in; any other
            # placeholders a template may contain are left as-is.
            example_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES)
            # Default parameters mirror the slider defaults: max_new_tokens=100, temperature=0.7
            examples_list.append([example_prompt, 100, 0.7])
        print(f"Prepared {len(examples_list)} examples for Gradio.")
    else:
        # examples_list stays empty; the UI shows a notice instead of examples.
        print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.")
except Exception as e:
    print(f"Error loading model, tokenizer, or prompts: {e}")
    # Abort startup; chaining keeps the original traceback visible in the Space logs.
    raise gr.Error(f"Failed during setup. Check logs for details. Error: {e}") from e
# --- Prediction Function ---
def predict(prompt, max_new_tokens=100, temperature=0.7):
    """
    Generate a continuation of ``prompt`` with the loaded model.

    Args:
        prompt (str): The input text prompt.
        max_new_tokens (int): Maximum number of new tokens to generate. Values
            coming from Gradio sliders may arrive as floats and are cast to int.
        temperature (float): Sampling temperature. ``0`` (or less) selects
            greedy, deterministic decoding; any positive value enables sampling.

    Returns:
        str: Only the newly generated text (the prompt is not echoed back),
        a short notice if the prompt is empty, or a human-readable error
        message if generation fails.
    """
    print(f"Received prompt: {prompt}")
    print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}")
    # Guard clause: nothing useful can be generated from an empty prompt.
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        max_new_tokens = int(max_new_tokens)
        temperature = float(temperature)

        # Prepare the input and move it to the device holding the model.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if temperature > 0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = temperature
        else:
            # Greedy decoding. temperature is only meaningful when sampling, so
            # it is not passed at all (passing 0 just triggers a warning).
            generation_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        # A causal LM's generate() output starts with the prompt tokens. Drop
        # them by token count rather than by comparing decoded strings, which
        # fails whenever decoding does not reproduce the prompt text exactly.
        prompt_token_count = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_token_count:]
        result_text = tokenizer.decode(new_tokens, skip_special_tokens=True).lstrip()
        print(f"Generated text (processed): {result_text}")
        return result_text
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing the handler.
        print(f"Error during prediction: {e}")
        return f"An error occurred during generation: {e}"
# --- Gradio Interface ---
# Build the Gradio UI: prompt box, generation sliders and submit button on the
# left, the model output on the right, and clickable example prompts (when any
# were loaded) below. predict() is wired as the handler.
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header text; the f-string interpolates the configuration constants.
    gr.Markdown(
        f"""
# 🤖 TXGemma-2B-Predict Text Generation
Enter a prompt below or select an example, and the model ({MODEL_NAME}) will generate text based on it.
Adjust the parameters for different results. Examples loaded from `{PROMPT_FILENAME}`.
Example prompts use the SMILES string `{EXAMPLE_SMILES}` (Benzene) as a placeholder.
"""
    )
    with gr.Row():
        # Left column: prompt entry, generation controls and the submit button.
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Enter your text prompt here, potentially including a specific Drug SMILES string...",
                lines=5
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10,
                    maximum=500, # Adjust max limit if needed
                    value=100,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate after the prompt."
                )
                temperature_slider = gr.Slider(
                    minimum=0.0, # 0 is allowed: predict() then generates with do_sample=False (deterministic)
                    maximum=1.5,
                    value=0.7,
                    step=0.05, # Finer control for temperature
                    label="Temperature",
                    info="Controls randomness (0=deterministic, >0=random)."
                )
            submit_button = gr.Button("Generate Text", variant="primary")
        # Right column: read-only output area.
        with gr.Column(scale=3):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False # Output is not editable by user
            )
    # --- Connect Components ---
    # The input components are passed positionally to
    # predict(prompt, max_new_tokens, temperature).
    submit_button.click(
        fn=predict,
        inputs=[prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
        api_name="predict" # Name for API endpoint if needed
    )
    # Use the loaded examples if available
    if examples_list:
        gr.Examples(
            examples=examples_list,
            # Each example row is [prompt, max_new_tokens, temperature], matching these inputs in order.
            inputs=[prompt_input, max_tokens_slider, temperature_slider],
            outputs=output_text,
            # NOTE: with cache_examples=False and run_on_click left at its default,
            # clicking an example only fills the inputs; the user still presses
            # "Generate Text" to run predict.
            fn=predict,
            cache_examples=False # Caching would run predict on every example up front, which is slow for an LLM
        )
    else:
        gr.Markdown("_(Could not load examples from JSON file or file format was incorrect.)_")
# --- Launch the App ---
print("Launching Gradio app...")
# queue() routes requests through Gradio's queue so multiple users can be served
# Set share=True if you need a public link, otherwise False or omit
demo.queue().launch(debug=True) # Set debug=False for production