import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaNextForConditionalGeneration
import spaces
# Load the processor and model
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
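# Run inference on the GPU when one is available; otherwise fall back to CPU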
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
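# On Hugging Face Spaces, @spaces.GPU requests a ZeroGPU slot for the duration of each call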
@spaces.GPU()
def llava_inference(image: Image.Image, prompt: str):
    # Format the input as a conversation
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
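    # Render the conversation with the model's chat template (adds the image placeholder token)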
    formatted_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=formatted_prompt, return_tensors="pt").to(device)
    # Generate a response with a max token limit
    output_ids = model.generate(**inputs, max_new_tokens=100)
    # Decode only the newly generated tokens, skipping the echoed prompt
    output_text = processor.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return output_text
# Gradio interface: image and text prompt in, generated text out
demo = gr.Interface(
    fn=llava_inference,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
    ],
    outputs=gr.Text(label="Output Response"),
    title="LLaVA-1.6 Gradio Demo",
    description="Upload an image and enter a prompt. The model will generate a response using LLaVA-1.6.",
)
if __name__ == "__main__":
    demo.launch()