# Scraped from the Hugging Face Space page (author: joshuaberkowitzus).
# Latest commit: "back to 2B" (4360f38, verified).
# app.py for Hugging Face Space
# Make sure to add 'gradio', 'transformers', 'torch' (this app uses PyTorch tensors, so
# TensorFlow/Flax will not work), 'accelerate' (required by device_map="auto"),
# and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.
# Using gr.DataFrame does not require adding pandas if using list-of-lists format.
from huggingface_hub import login
import gradio as gr
import torch # Or tensorflow/flax depending on backend
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download # Import hub download function
import json # Import json library
import os # Import os library for path joining
# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
# MODEL_NAME = "google/txgemma-9b-predict"  # Alternative, larger checkpoint
PROMPT_FILENAME = "tdc_prompts.json"  # Prompt-template file downloaded from the MODEL_NAME repo
MODEL_CACHE = "model_cache"  # Optional: define a cache directory
# MAX_EXAMPLES is no longer strictly limiting the display, but can be used if needed later
MAX_EXAMPLES = 600  # Keep variable definition, but DataFrame handles scrolling
EXAMPLE_SMILES = "C1=CC=CC=C1"  # Default SMILES for examples (Benzene)
DATAFRAME_HEADERS = ["Task Name", "Prompt Template"]
DATAFRAME_ROW_COUNT = 8  # Number of rows to display initially in the DataFrame

# --- Authentication ---
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    # huggingface_hub.login() without a token falls back to an interactive
    # prompt (widget or terminal), which nobody can answer inside a Space.
    # Skip it and let downloads fail with a clear error if a token is needed.
    print("Warning: HF_TOKEN is not set; skipping Hugging Face login. Downloading gated models may fail.")
# --- Load Model, Tokenizer, and Prompts ---
def _build_dataframe_rows(prompts):
    """Convert parsed prompt JSON into ``[task_name, prompt_template]`` rows.

    Args:
        prompts: Object decoded from PROMPT_FILENAME. A dict mapping task
            names to template strings is expected.

    Returns:
        A list of two-element lists for gr.DataFrame. Items whose key or
        value is not a string are skipped with a warning. A non-dict payload
        yields an empty list.
    """
    if not isinstance(prompts, dict):
        print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(prompts)}. Cannot load examples.")
        return []

    print(f"Processing {len(prompts)} prompts from dictionary for DataFrame...")
    rows = []
    for task_name, prompt_template in prompts.items():
        if isinstance(prompt_template, str) and isinstance(task_name, str):
            rows.append([task_name, prompt_template])
        else:
            print(f"Warning: Skipping invalid item in prompts dictionary: key={task_name}, value_type={type(prompt_template)}")
    print(f"Prepared {len(rows)} rows for DataFrame.")
    return rows


print(f"Loading model: {MODEL_NAME}...")
tdc_prompts_data = None  # Raw decoded contents of PROMPT_FILENAME
dataframe_data = []  # DataFrame content: [[task_name, prompt_template], ...]
try:
    # Reported for information only; placement is handled by device_map="auto".
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE)
    print("Tokenizer loaded.")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=MODEL_CACHE,
        device_map="auto",  # Automatically distribute model across available devices (GPU/CPU)
    )
    print("Model loaded.")

    # The prompt templates ship alongside the model weights in the same repo.
    print(f"Downloading {PROMPT_FILENAME}...")
    prompts_file_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename=PROMPT_FILENAME,
        cache_dir=MODEL_CACHE,
    )
    print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(prompts_file_path, "r", encoding="utf-8") as f:
        tdc_prompts_data = json.load(f)
    print(f"Loaded prompts data from {PROMPT_FILENAME}.")

    dataframe_data = _build_dataframe_rows(tdc_prompts_data)
except Exception as e:
    print(f"Error loading model, tokenizer, or prompts: {e}")
    # Setup failures are fatal. gr.Error is meant for event handlers, so
    # raise a plain exception and keep the original traceback chained.
    raise RuntimeError(f"Failed during setup. Check logs for details. Error: {e}") from e
# --- Prediction Function ---
def predict(prompt, max_new_tokens=100, temperature=0.7):
"""
Generates text based on the input prompt using the loaded model.
(Function remains the same as before)
"""
print(f"Received prompt: {prompt}")
print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}")
try:
# Prepare the input for the model
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Move inputs to the model's device
# Generate text
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=int(max_new_tokens), # Ensure it's an integer
temperature=float(temperature), # Ensure it's a float
do_sample=True if float(temperature) > 0 else False, # Only sample if temp > 0
pad_token_id=tokenizer.eos_token_id # Set pad token id
)
# Decode the generated tokens
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text (raw): {generated_text}")
# Remove the prompt from the beginning of the generated text
if generated_text.startswith(prompt):
prompt_length = len(prompt)
result_text = generated_text[prompt_length:].lstrip()
else:
common_prefix = os.path.commonprefix([prompt, generated_text])
if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8:
result_text = generated_text[len(common_prefix):].lstrip()
else:
result_text = generated_text
print(f"Generated text (processed): {result_text}")
return result_text
except Exception as e:
print(f"Error during prediction: {e}")
return f"An error occurred during generation: {e}"
# --- Function to handle DataFrame selection ---
def select_prompt_from_df(evt: gr.SelectData):
"""
Triggered when a row is selected in the DataFrame.
Updates the main prompt input with the selected template, replacing the placeholder.
"""
if evt.index is None or evt.index[0] >= len(dataframe_data):
print("Invalid selection event or index out of bounds.")
return gr.update() # No change
selected_row_index = evt.index[0]
# Get the prompt template from the second column (index 1) of the selected row
prompt_template = dataframe_data[selected_row_index][1]
# Replace the placeholder with the example SMILES string
selected_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES)
print(f"Selected prompt template from row {selected_row_index}, updated input.")
# Return the processed prompt to update the prompt_input textbox
return selected_prompt
# --- Gradio Interface ---
print("Creating Gradio interface...")
# Derive the heading from MODEL_NAME so switching checkpoints (e.g. to the 9B
# model) does not leave a stale title.
MODEL_DISPLAY_NAME = MODEL_NAME.split("/")[-1]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 🤖 {MODEL_DISPLAY_NAME} Property Prediction
        Enter a prompt below, or select a task from the table to load its template, and the model ({MODEL_NAME}) will generate text.
        Adjust the parameters for different results. Prompt templates loaded from `{PROMPT_FILENAME}`.
        Selected templates have their `{{Drug SMILES}}` placeholder filled with the SMILES string `{EXAMPLE_SMILES}` (Benzene).
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Enter your text prompt here, or select a template from the table below...",
                lines=5,
                elem_id="prompt_input_box",  # Stable DOM id for styling/testing
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate after the prompt.",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.7,
                    step=0.05,
                    label="Temperature",
                    info="Controls randomness (0=deterministic, >0=random).",
                )
            submit_button = gr.Button("Generate Text", variant="primary")
        with gr.Column(scale=3):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,  # Adjust height if needed
                interactive=False,
            )

    # --- DataFrame of prompt templates ---
    gr.Markdown("### Select a Prompt Template")
    prompt_df = gr.DataFrame(
        value=dataframe_data,
        headers=DATAFRAME_HEADERS,
        # (initial count, "dynamic"): the row count is not fixed at
        # DATAFRAME_ROW_COUNT, so every loaded template can be listed.
        row_count=(DATAFRAME_ROW_COUNT, "dynamic"),
        col_count=(len(DATAFRAME_HEADERS), "fixed"),  # Exactly the two header columns
        wrap=True,  # Wrap long templates inside cells
        label="Prompt Templates",
    )

    # --- Connect Components ---
    # Generate text from the prompt and slider values; also exposed as the
    # "predict" API endpoint.
    submit_button.click(
        fn=predict,
        inputs=[prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
        api_name="predict",
    )

    # Clicking a table cell calls select_prompt_from_df. Gradio passes the
    # gr.SelectData event to it automatically, and its return value replaces
    # the prompt_input text.
    prompt_df.select(
        fn=select_prompt_from_df,
        inputs=None,  # Event data is passed automatically
        outputs=prompt_input,
        show_progress="hidden",  # Quick update; no progress indicator needed
    )
# --- Launch the App ---
print("Launching Gradio app...")
# Enable request queuing on the Blocks app, then start serving it.
queued_demo = demo.queue()
queued_demo.launch(debug=True)  # Set debug=False for production