Spaces: Running on Zero
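
Below is a minimal app.py for a Space on ZeroGPU ("Zero" hardware): it loads LLaVA-1.6 (Mistral-7B) with transformers and serves it through a Gradio Interface. The spaces import and the @spaces.GPU decorator on the inference function are what request a GPU on demand when the Space runs on Zero.
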
import spaces  # ZeroGPU helper; import before torch so CUDA initialization can be intercepted
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, LlavaNextForConditionalGeneration

# Load the processor and model once at startup
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

@spaces.GPU  # request a ZeroGPU device for the duration of each call
def llava_inference(image: Image.Image, prompt: str) -> str:
    # Format the input as a single-turn conversation
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    formatted_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # Keyword arguments avoid ambiguity in the processor's (images, text) ordering
    inputs = processor(images=image, text=formatted_prompt, return_tensors="pt").to(device)
    # Generate a response with a max token limit
    output_ids = model.generate(**inputs, max_new_tokens=100)
    output_text = processor.decode(output_ids[0], skip_special_tokens=True)
    return output_text

# Build the Gradio interface
demo = gr.Interface(
    fn=llava_inference,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
    ],
    outputs=gr.Textbox(label="Output Response"),
    title="LLaVA-1.6 Gradio Demo",
    description="Upload an image and enter a prompt. The model will generate a response using LLaVA-1.6.",
)

if __name__ == "__main__":
    demo.launch()
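
To deploy this, the script above would be saved as app.py in a Gradio-SDK Space with ZeroGPU hardware selected in the Space settings. A minimal requirements.txt sketch follows; the entries are illustrative assumptions, not from the original (gradio itself is installed from the Space's sdk_version, and the spaces package is provided on ZeroGPU hardware):

```
# requirements.txt -- illustrative, pin versions as needed
torch
transformers
accelerate   # needed for low_cpu_mem_usage=True
pillow
```

If a single generation needs more than the default GPU allocation window, @spaces.GPU also accepts a duration argument, e.g. @spaces.GPU(duration=120).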