# JJeDIkUYeCqVU7B / main.py
# Last commit: "Update main.py" (be7c8ee, verified) by wBfvtNqNhb
# Size: 3.53 kB
import base64
import os
from io import BytesIO

import torch
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Initialize the FastAPI application.
app = FastAPI()

# Model checkpoint to serve. Defaults to the 7B instruct model; set the
# QWEN_VL_CHECKPOINT environment variable to serve another size, e.g.
# "Qwen/Qwen2.5-VL-3B-Instruct" or "Qwen/Qwen2.5-VL-72B-Instruct".
checkpoint = os.environ.get("QWEN_VL_CHECKPOINT", "Qwen/Qwen2.5-VL-7B-Instruct")

# Image size bounds handed to the processor, written as multiples of 28*28.
# Per the Qwen2.5-VL model card this corresponds to roughly 256-1280 visual
# tokens per image; lower max_pixels to trade detail for memory/speed.
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28

processor = AutoProcessor.from_pretrained(
    checkpoint,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
    # Previously misspelled "use_fase", so the fast processor was never requested.
    use_fast=True,
)

# Load weights in bfloat16 and let device_map="auto" place them on the
# available device(s).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# JSON body accepted by POST /predict.
class ImageRequest(BaseModel):
    # The image to analyse, encoded as a base64 string.
    image_base64: str
    # The instruction or question to ask about the image.
    prompt: str
# Landing route: returns a fixed message showing the service is up and
# pointing callers at the /predict endpoint.
@app.get("/")
def read_root():
    landing = {"message": "API is live. Use the /predict endpoint."}
    return landing
@app.post("/predict")
async def predict(request: ImageRequest):
    """Answer ``request.prompt`` about the image in ``request.image_base64``.

    The image may be sent either as a bare base64 string or as a data URL
    (``data:image/png;base64,...``); a data-URL header is stripped before
    decoding.

    Returns:
        ``{"response": <generated text>}`` on success, or
        ``{"error": "Invalid base64 image data: ..."}`` when the payload
        cannot be decoded into an image.
    """
    # b64decode (validate=False) silently discards non-alphabet characters,
    # so a data-URL header such as "data:image/png;base64," would be mangled
    # into the payload and decode to garbage. Strip it first. A bare base64
    # string can never start with "data:" (":" is not in the alphabet), so
    # plain payloads are unaffected.
    payload = request.image_base64.strip()
    if payload.startswith("data:"):
        payload = payload.partition(",")[2]
    try:
        image_data = base64.b64decode(payload)
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        return {"error": f"Invalid base64 image data: {e}"}

    # Chat-formatted request: a system prompt, then the image and the user's text.
    messages = [
        {"role": "system", "content": "You are a helpful assistant with vision abilities."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": request.prompt},
            ],
        },
    ]

    # Render the chat template to a prompt string, extract the vision inputs
    # from the messages, and tokenize everything onto the model's device.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # NOTE(review): generate() is a blocking call inside an ``async def``
    # endpoint, so it stalls the event loop (including GET /) until it
    # finishes. Offloading it to a worker thread would fix that but would also
    # allow concurrent generations on one model/GPU; decide deliberately.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=4096)

    # Each output row starts with its prompt tokens; keep only the new ones.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts[0]}
# JSON body accepted by POST /summary.
class SummaryRequest(BaseModel):
    # The text the model should summarize.
    prompt: str
@app.post("/summary")
async def summary(request: SummaryRequest):
    # Text-only endpoint: wrap the caller's text in a chat whose system prompt
    # asks for a summary, run generation, and return the model's reply.
    system_turn = {"role": "system", "content": "You are a helpful assistant that summarizes text."}
    user_turn = {"role": "user", "content": [{"type": "text", "text": request.prompt}]}
    chat = [system_turn, user_turn]

    # Render the chat template into one prompt string, then tokenize it
    # (no image/video inputs) and move the tensors to the model's device.
    prompt_text = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = processor(text=[prompt_text], padding=True, return_tensors="pt")
    encoded = encoded.to(model.device)

    # Generate up to 4096 new tokens with gradient tracking disabled.
    with torch.no_grad():
        sequences = model.generate(**encoded, max_new_tokens=4096)

    # Each returned sequence begins with its prompt; slice that prefix off so
    # only the newly generated tokens are decoded.
    completions = [full[len(prompt_ids):] for prompt_ids, full in zip(encoded.input_ids, sequences)]
    decoded = processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return {"response": decoded[0]}