import gradio as gr
import torch
import os
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import io
import base64
import numpy as np

# Hugging Face token loaded from .env; pass it to from_pretrained() if gated
# or private models are ever used (the models below are public).
token = os.getenv("HF_TOKEN")
# Initialize models

# OCR model for text extraction
ocr_model = pipeline("document-question-answering", model="impira/layoutlm-document-qa")

# Florence-2 model for image understanding
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
florence_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)

# TinyLlama chat model for game control interface reasoning
llm_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
llm_model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").to(device)
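# Note: all three models load eagerly at import time, which is the usual
# pattern for a Hugging Face Space. On CPU-only hardware the Florence-2 and
# TinyLlama loads can take several minutes and a few GB of RAM.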

def preprocess_image(image):
    """Convert the incoming image to a PIL image in the format the models expect."""
    if isinstance(image, str):  # base64-encoded data URI
        image = Image.open(io.BytesIO(base64.b64decode(image.split(",")[1])))
    elif isinstance(image, np.ndarray):  # gr.Image may also deliver a numpy array
        image = Image.fromarray(image)
    return image.convert("RGB")  # Florence-2's processor expects 3-channel RGB
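# Usage sketch (hypothetical base64 data URI; everything before the comma is
# the header that preprocess_image strips off):
# img = preprocess_image("data:image/png;base64,iVBORw0KGgo...")  # -> RGB PIL.Image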

def extract_text_from_image(image):
    """Extract text from the image using LayoutLM OCR and Florence-2's <OCR> task."""
    image = preprocess_image(image)

    # Use LayoutLM for document text extraction; the pipeline returns a list
    # of answer dicts, so take the top-ranked one
    layout_result = ocr_model(image=image, question="What text is in this image?")
    layout_text = layout_result[0]["answer"] if layout_result else ""

    # Also use Florence-2 for text detection via its <OCR> task token
    prompt = "<OCR>"
    inputs = florence_processor(
        text=prompt,
        images=image,
        return_tensors="pt",
    ).to(device, torch_dtype)
    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )

    # Decode and post-process; for <OCR> the processor returns {"<OCR>": "<text>"}
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        generated_text, task="<OCR>", image_size=(image.width, image.height)
    )
    parsed_text = parsed.get("<OCR>", "")

    # Combine results from both models
    combined_text = f"LayoutLM OCR: {layout_text}\n\nFlorence-2 OCR: {parsed_text}"
    return combined_text

def analyze_image(image):
    """Analyze image content using Florence-2's object detection task."""
    image = preprocess_image(image)

    # Use the Object Detection task with Florence-2
    prompt = "<OD>"  # Object Detection task token
    inputs = florence_processor(
        text=prompt,
        images=image,
        return_tensors="pt",
    ).to(device, torch_dtype)
    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )

    # Decode and post-process; for <OD> the processor returns
    # {"<OD>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": [...]}}
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    od_result = florence_processor.post_process_generation(
        generated_text,
        task="<OD>",
        image_size=(image.width, image.height),
    )["<OD>"]

    # Process the detected objects for game analysis
    game_related_categories = [
        "person", "player", "character", "enemy", "button", "screen", "display",
        "health", "bar", "score", "menu", "weapon", "item", "obstacle", "platform",
        "text", "number", "icon", "power-up", "door", "key", "coin", "vehicle",
    ]

    # Filter and organize detected objects. Florence-2's <OD> output carries no
    # per-detection scores, so a placeholder confidence of 1.0 is recorded to
    # keep the downstream reporting format intact.
    detected_elements = {}
    confidence_sum = 0.0
    count = 0
    for box, category in zip(od_result["bboxes"], od_result["labels"]):
        confidence = 1.0  # placeholder; see note above
        # Check if this is a game-related object
        for game_category in game_related_categories:
            if game_category in category.lower():
                if category not in detected_elements or confidence > detected_elements[category]["confidence"]:
                    detected_elements[category] = {
                        "confidence": confidence,
                        "box": box,  # bounding box in image coordinates
                    }
                confidence_sum += confidence
                count += 1
                break

    # Calculate average confidence
    avg_confidence = confidence_sum / max(count, 1)
    return {
        "detected_elements": list(detected_elements.keys()),
        "element_details": detected_elements,
        "confidence": avg_confidence,
    }
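# Illustrative shape of the post-processed "<OD>" payload consumed above
# (coordinate values are made up; the structure matches Florence-2's output):
# {"bboxes": [[34.2, 160.1, 597.9, 371.8], [272.4, 18.0, 419.7, 75.4]],
#  "labels": ["person", "button"]}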

def generate_game_control(text_content, image_analysis, user_input):
    """Generate game control suggestions using TinyLlama over the Florence-2 analysis."""
    # Extract detailed information from the image analysis
    detected_elements = image_analysis["detected_elements"]
    element_details = image_analysis["element_details"]

    # Build a detailed prompt section with positional information
    detailed_elements = []
    for element in detected_elements:
        if element in element_details:
            box = element_details[element]["box"]
            confidence = element_details[element]["confidence"]
            position = f"at position x:{box[0]:.1f}-{box[2]:.1f}, y:{box[1]:.1f}-{box[3]:.1f}"
            detailed_elements.append(f"{element} ({position}, confidence: {confidence:.2f})")
    detailed_elements_text = "\n - ".join([""] + detailed_elements) if detailed_elements else "None detected with high confidence"

    # Prepare a comprehensive prompt. TinyLlama is a chat model; a plain
    # completion prompt is used here for simplicity.
    prompt = f"""
You are an AI game assistant that helps players understand game screenshots and provides control suggestions.

Game screenshot analysis:
- Text content detected:
{text_content}
- Visual elements detected: {detailed_elements_text}
- Overall detection confidence: {image_analysis['confidence']:.2f}

User query: {user_input}

Based on the game screenshot analysis above, provide specific game control suggestions.
Focus on:
1. What UI elements the player should interact with
2. Which buttons or controls they should use
3. Gameplay strategy based on what's visible
4. Clear next steps or actions

Your response:
"""

    # Generate with plain sampling (combining beam search with sampling, as in
    # the original draft, is rarely intended, so num_beams is left at its default)
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text includes the prompt; keep only what follows the cue
    try:
        suggestions = response.split("Your response:")[1].strip()
    except IndexError:
        suggestions = response
    return suggestions

def process_game_screenshot(image, user_input):
    """Main function to process a game screenshot and generate the control interface."""
    if image is None:
        return "Please upload a game screenshot."

    # Extract text from the image
    text_content = extract_text_from_image(image)
    # Analyze image content
    image_analysis = analyze_image(image)
    # Use Florence-2 for a general image description as well;
    # this gives additional context about the game scene
    image_desc = get_florence_image_description(image)
    # Generate game control interface suggestions
    control_suggestions = generate_game_control(text_content, image_analysis, user_input)

    # Create a comprehensive response
    detected_elements_formatted = []
    for elem in image_analysis["detected_elements"]:
        if elem in image_analysis["element_details"]:
            conf = image_analysis["element_details"][elem]["confidence"]
            detected_elements_formatted.append(f"{elem} (confidence: {conf:.2f})")
    elements_text = "\n- ".join([""] + detected_elements_formatted) if detected_elements_formatted else "None detected with high confidence"

    response = f"""
## Game Screenshot Analysis

### Scene Description:
{image_desc}

### Text Content Detected:
{text_content}

### Visual Elements Detected:
{elements_text}

## Game Control Suggestions:
{control_suggestions}
"""
    return response

def get_florence_image_description(image):
    """Get a general description of the image using Florence-2's captioning task."""
    image = preprocess_image(image)

    # Florence-2's captioning task token is <CAPTION>; the original <IC> is not
    # a Florence-2 task and would not be post-processed correctly
    prompt = "<CAPTION>"
    inputs = florence_processor(
        text=prompt,
        images=image,
        return_tensors="pt",
    ).to(device, torch_dtype)
    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=512,
            do_sample=False,  # deterministic beam search, as in the model card example
            num_beams=3,
        )

    # Decode and post-process; the result is {"<CAPTION>": "<caption text>"}
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        generated_text, task="<CAPTION>", image_size=(image.width, image.height)
    )
    return parsed.get("<CAPTION>", "")

def create_api():
    """Create the Gradio app and expose the analysis endpoint."""
    with gr.Blocks(title="Game Control Interface AI") as app:
        gr.Markdown("# Game Control Interface AI")
        gr.Markdown("Upload a game screenshot and provide your query to get game control suggestions")

        with gr.Row():
            with gr.Column(scale=2):
                image_input = gr.Image(type="pil", label="Game Screenshot")
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Your Query",
                    placeholder="e.g., 'How do I defeat this enemy?', 'What should I do next?'",
                    lines=3,
                )
                submit_button = gr.Button("Analyze Screenshot", variant="primary")

        # Example queries to help users
        example_queries = [
            ["What should I do next in this game?"],
            ["How do I defeat this enemy?"],
            ["What items should I collect in this scene?"],
            ["How do I solve this puzzle?"],
            ["What controls should I use in this situation?"],
        ]
        gr.Examples(
            examples=example_queries,
            inputs=text_input,
        )
        with gr.Row():
            with gr.Column():
                # Tabs for different views
                with gr.Tabs():
                    with gr.TabItem("Game Control Suggestions"):
                        output = gr.Markdown(label="Game Control Interface Suggestions")
                    with gr.TabItem("Raw Analysis Data"):
                        with gr.Accordion("OCR Results", open=False):
                            ocr_output = gr.Textbox(label="Text Detection Results", lines=5)
                        with gr.Accordion("Object Detection", open=False):
                            object_output = gr.JSON(label="Detected Objects")
        # Define the processing function with multiple outputs
        def process_with_details(image, user_input):
            if image is None:
                return "Please upload a game screenshot.", "No text detected", {}

            # Extract text from the image
            text_content = extract_text_from_image(image)
            # Analyze image content
            image_analysis = analyze_image(image)
            # Use Florence-2 for a general image description
            image_desc = get_florence_image_description(image)
            # Generate game control interface suggestions
            control_suggestions = generate_game_control(text_content, image_analysis, user_input)

            # Format the main response
            detected_elements_formatted = []
            for elem in image_analysis["detected_elements"]:
                if elem in image_analysis["element_details"]:
                    conf = image_analysis["element_details"][elem]["confidence"]
                    detected_elements_formatted.append(f"{elem} (confidence: {conf:.2f})")
            elements_text = "\n- ".join([""] + detected_elements_formatted) if detected_elements_formatted else "None detected with high confidence"

            response = f"""
## Game Screenshot Analysis

### Scene Description:
{image_desc}

### Text Content Detected:
{text_content}

### Visual Elements Detected:
{elements_text}

## Game Control Suggestions:
{control_suggestions}
"""
            return response, text_content, image_analysis["element_details"]
        # Connect the interface; api_name exposes this handler as a named API
        # endpoint when the app is launched with show_api=True
        submit_button.click(
            fn=process_with_details,
            inputs=[image_input, text_input],
            outputs=[output, ocr_output, object_output],
            api_name="analyze_screenshot",
        )
    # NOTE: the original draft created and launched a second gr.Interface around
    # process_game_screenshot here, which would start a duplicate server from
    # inside create_api(). The click handler above is exposed as the API
    # endpoint instead, and the app is launched once from the entry point.
    return app

# Entry point
if __name__ == "__main__":
    app = create_api()
    app.launch(share=True, show_api=True)