# ocr_florence2 / app.py
import gradio as gr
import torch
import os
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
import requests
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import io
import base64
import numpy as np
token = os.getenv('HF_TOKEN')
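# Note: HF_TOKEN is loaded but not forwarded to any of the calls below; if a
# gated or private checkpoint were used, it would be passed explicitly, e.g.
# (assuming the current `token` keyword in transformers):
#   AutoProcessor.from_pretrained(model_id, token=token, trust_remote_code=True)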
# Initialize models
# OCR model for text extraction
ocr_model = pipeline("document-question-answering", model="impira/layoutlm-document-qa")
# Florence-2 model for image understanding
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
florence_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    torch_dtype=torch_dtype,
    trust_remote_code=True
).to(device)
# TinyLlama chat model (a small LLaMA-family model) for game control interface reasoning
llm_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
llm_model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").to(device)
def preprocess_image(image):
    """Convert the image to the PIL format required by the models"""
    if isinstance(image, str):  # If the image arrives as a base64 data URL
        image = Image.open(io.BytesIO(base64.b64decode(image.split(",")[1])))
    return image
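# The function accepts either a PIL.Image (as supplied by gr.Image) or a
# base64 data URL; a hypothetical string input would look like
#   preprocess_image("data:image/png;base64,iVBORw0KGgo...")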
def extract_text_from_image(image):
    """Extract text from the image using OCR and Florence-2's OCR capabilities"""
    image = preprocess_image(image)
    # Use LayoutLM for document text extraction; the document-question-answering
    # pipeline returns a list of answer dicts, so take the top answer
    layout_result = ocr_model(image=image, question="What text is in this image?")
    layout_text = layout_result[0]["answer"] if layout_result else ""
    # Also use Florence-2 for text detection via its <OCR> task token
    prompt = "<OCR>"
    # Process with Florence-2 for OCR
    inputs = florence_processor(
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(device, torch_dtype)
    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
    # Decode and post-process the generated text; post_process_generation
    # returns a dict keyed by the task token
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(generated_text, task="<OCR>")
    florence_text = parsed.get("<OCR>", "")
    # Combine results from both models
    combined_text = f"LayoutLM OCR: {layout_text}\n\nFlorence-2 OCR: {florence_text}"
    return combined_text
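# Direct-usage sketch (hypothetical file name):
#   text = extract_text_from_image(Image.open("screenshot.png"))
# which yields a string of the form
#   "LayoutLM OCR: <top answer>\n\nFlorence-2 OCR: <raw OCR text>"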
def analyze_image(image):
    """Analyze image content using Florence-2 for object detection"""
    image = preprocess_image(image)
    # Use Object Detection task with Florence-2
    prompt = "<OD>"  # Object Detection task token
    # Process the image with Florence-2
    inputs = florence_processor(
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(device, torch_dtype)
    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
    # Decode and post-process; for <OD> the processor returns a dict of the form
    # {"<OD>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": [...]}}
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        generated_text,
        task="<OD>",
        image_size=(image.width, image.height)
    )
    detections = parsed.get("<OD>", {"bboxes": [], "labels": []})
    # Categories considered relevant for game analysis
    game_related_categories = [
        "person", "player", "character", "enemy", "button", "screen", "display",
        "health", "bar", "score", "menu", "weapon", "item", "obstacle", "platform",
        "text", "number", "icon", "power-up", "door", "key", "coin", "vehicle"
    ]
    # Filter and organize detected objects. Florence-2 does not return
    # per-detection confidence scores, so a fixed placeholder of 1.0 is used.
    detected_elements = {}
    confidence_sum = 0.0
    count = 0
    for category, box in zip(detections["labels"], detections["bboxes"]):
        confidence = 1.0
        # Check if this is a game-related object
        for game_category in game_related_categories:
            if game_category in category.lower():
                if category not in detected_elements or confidence > detected_elements[category]["confidence"]:
                    detected_elements[category] = {
                        "confidence": confidence,
                        "box": box  # Keep the bounding box information
                    }
                confidence_sum += confidence
                count += 1
                break
    # Calculate average confidence (trivially 1.0 with the placeholder scores)
    avg_confidence = confidence_sum / max(count, 1)
    return {
        "detected_elements": list(detected_elements.keys()),
        "element_details": detected_elements,
        "confidence": avg_confidence
    }
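# Example return value (hypothetical detections):
#   {
#       "detected_elements": ["button", "enemy"],
#       "element_details": {"button": {"confidence": 1.0, "box": [10.0, 20.0, 64.0, 44.0]}, ...},
#       "confidence": 1.0,
#   }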
def generate_game_control(text_content, image_analysis, user_input):
    """Generate game control interface suggestions from the OCR text and Florence-2 analysis"""
    # Extract detailed information from the image analysis
    detected_elements = image_analysis["detected_elements"]
    element_details = image_analysis["element_details"]
    # Create a detailed prompt for the language model with positional information
    detailed_elements = []
    for element in detected_elements:
        if element in element_details:
            box = element_details[element]["box"]
            confidence = element_details[element]["confidence"]
            position = f"at position x:{box[0]:.1f}-{box[2]:.1f}, y:{box[1]:.1f}-{box[3]:.1f}"
            detailed_elements.append(f"{element} ({position}, confidence: {confidence:.2f})")
    # Format the detailed elements text
    detailed_elements_text = "\n - ".join([""] + detailed_elements) if detailed_elements else "None detected with high confidence"
    # Prepare a comprehensive prompt for the language model
    prompt = f"""
You are an AI game assistant that helps players understand game screenshots and provides control suggestions.
Game screenshot analysis:
- Text content detected:
{text_content}
- Visual elements detected: {detailed_elements_text}
- Overall detection confidence: {image_analysis['confidence']:.2f}
User query: {user_input}
Based on the game screenshot analysis above, provide specific game control suggestions.
Focus on:
1. What UI elements the player should interact with
2. Which buttons or controls they should use
3. Gameplay strategy based on what's visible
4. Clear next steps or actions
Your response:
"""
    # Generate with TinyLlama (sampling only; beam search is dropped because
    # mixing num_beams with temperature/top_p sampling is not meaningful here)
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text generated after the prompt's final marker
    try:
        suggestions = response.split("Your response:")[1].strip()
    except IndexError:
        suggestions = response
    return suggestions
def process_game_screenshot(image, user_input):
    """Main function to process a game screenshot and generate the control interface"""
    if image is None:
        return "Please upload a game screenshot."
    # Extract text from the image
    text_content = extract_text_from_image(image)
    # Analyze image content
    image_analysis = analyze_image(image)
    # Use Florence-2 for a general image description as well;
    # this gives additional context about the game scene
    image_desc = get_florence_image_description(image)
    # Generate game control interface suggestions
    control_suggestions = generate_game_control(text_content, image_analysis, user_input)
    # Create a comprehensive response
    detected_elements_formatted = []
    for elem in image_analysis["detected_elements"]:
        if elem in image_analysis["element_details"]:
            conf = image_analysis["element_details"][elem]["confidence"]
            detected_elements_formatted.append(f"{elem} (confidence: {conf:.2f})")
    elements_text = "\n- ".join([""] + detected_elements_formatted) if detected_elements_formatted else "None detected with high confidence"
    response = f"""
## Game Screenshot Analysis
### Scene Description:
{image_desc}
### Text Content Detected:
{text_content}
### Visual Elements Detected:
{elements_text}
## Game Control Suggestions:
{control_suggestions}
"""
    return response
def get_florence_image_description(image):
    """Get a general description of the image using Florence-2's image captioning capability"""
    image = preprocess_image(image)
    # Use the Image Captioning task with Florence-2; the model's caption
    # task token is "<CAPTION>" ("<IC>" is not a recognized task)
    prompt = "<CAPTION>"
    # Process the image with Florence-2
    inputs = florence_processor(
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(device, torch_dtype)
    with torch.no_grad():
        # Florence-2's reference settings use deterministic beam search
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=512,
            do_sample=False,
            num_beams=3,
        )
    # Decode the generated text and extract the caption string
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(generated_text, task="<CAPTION>")
    caption = parsed.get("<CAPTION>", "")
    return caption
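# For reference, Florence-2's documented task tokens include "<OCR>",
# "<OCR_WITH_REGION>", "<OD>", "<CAPTION>", "<DETAILED_CAPTION>",
# "<MORE_DETAILED_CAPTION>" and "<DENSE_REGION_CAPTION>"; swapping the prompt
# above for "<MORE_DETAILED_CAPTION>" would yield a longer scene description.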
def create_api():
    """Create the Gradio app and expose the API endpoint"""
    with gr.Blocks(title="Game Control Interface AI") as app:
        gr.Markdown("# Game Control Interface AI")
        gr.Markdown("Upload a game screenshot and provide your query to get game control suggestions")
        with gr.Row():
            with gr.Column(scale=2):
                image_input = gr.Image(type="pil", label="Game Screenshot")
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Your Query",
                    placeholder="e.g., 'How do I defeat this enemy?', 'What should I do next?'",
                    lines=3
                )
                submit_button = gr.Button("Analyze Screenshot", variant="primary")
        # Add example queries to help users
        example_queries = [
            ["What should I do next in this game?"],
            ["How do I defeat this enemy?"],
            ["What items should I collect in this scene?"],
            ["How do I solve this puzzle?"],
            ["What controls should I use in this situation?"]
        ]
        gr.Examples(
            examples=example_queries,
            inputs=text_input
        )
        with gr.Row():
            with gr.Column():
                # Add tabs for different views
                with gr.Tabs():
                    with gr.TabItem("Game Control Suggestions"):
                        output = gr.Markdown(label="Game Control Interface Suggestions")
                    with gr.TabItem("Raw Analysis Data"):
                        with gr.Accordion("OCR Results", open=False):
                            ocr_output = gr.Textbox(label="Text Detection Results", lines=5)
                        with gr.Accordion("Object Detection", open=False):
                            object_output = gr.JSON(label="Detected Objects")
        # Define the processing function with multiple outputs
        def process_with_details(image, user_input):
            if image is None:
                return "Please upload a game screenshot.", "No text detected", {}
            # Extract text from the image
            text_content = extract_text_from_image(image)
            # Analyze image content
            image_analysis = analyze_image(image)
            # Use Florence-2 for a general image description
            image_desc = get_florence_image_description(image)
            # Generate game control interface suggestions
            control_suggestions = generate_game_control(text_content, image_analysis, user_input)
            # Format the main response
            detected_elements_formatted = []
            for elem in image_analysis["detected_elements"]:
                if elem in image_analysis["element_details"]:
                    conf = image_analysis["element_details"][elem]["confidence"]
                    detected_elements_formatted.append(f"{elem} (confidence: {conf:.2f})")
            elements_text = "\n- ".join([""] + detected_elements_formatted) if detected_elements_formatted else "None detected with high confidence"
            response = f"""
## Game Screenshot Analysis
### Scene Description:
{image_desc}
### Text Content Detected:
{text_content}
### Visual Elements Detected:
{elements_text}
## Game Control Suggestions:
{control_suggestions}
"""
            return response, text_content, image_analysis["element_details"]
        # Connect the interface
        submit_button.click(
            fn=process_with_details,
            inputs=[image_input, text_input],
            outputs=[output, ocr_output, object_output]
        )
    # A plain gr.Interface wrapping process_game_screenshot as a minimal API
    # front end. It is created outside the Blocks context and deliberately not
    # launched here: calling .launch() inside create_api would block before
    # `return app` and conflict with the launch in the entry point. Launch it
    # separately (on another port) if desired; the placeholder example paths
    # must point at real files first.
    api_demo = gr.Interface(
        fn=process_game_screenshot,
        inputs=[
            gr.Image(type="pil", label="Game Screenshot"),
            gr.Textbox(label="User Query")
        ],
        outputs=gr.Markdown(label="Game Control Interface Suggestions"),
        title="Game Control Interface AI API",
        description="API for game screenshot analysis and control suggestions",
        examples=[
            ["path/to/example_screenshot.jpg", "What should I do next?"],
            ["path/to/example_boss_battle.jpg", "How do I defeat this boss?"]
        ]
    )
    return app
# Entry point
if __name__ == "__main__":
    app = create_api()
    app.launch(share=True, show_api=True)
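# Client-side sketch once the app is running (assumes the gradio_client
# package and the default local URL; confirm the endpoint name on the app's
# "Use via API" page):
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(
#       handle_file("screenshot.png"), "What should I do next?",
#       api_name="/process_with_details",
#   )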