from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import base64
from io import BytesIO
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Initialize FastAPI
app = FastAPI()

# Load the model and processor
#checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
#checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
checkpoint = "Qwen/Qwen2.5-VL-32B-Instruct"
#checkpoint = "Qwen/Qwen2.5-VL-72B-Instruct"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28

processor = AutoProcessor.from_pretrained(
    checkpoint,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
    use_fast=True
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Define the request schema
class ImageRequest(BaseModel):
    image_base64: str  # Base64 encoded image
    prompt: str  # Text prompt

@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}

@app.post("/predict")  # Changed from GET to POST
async def predict(request: ImageRequest):
    # Decode the base64 image
    try:
        image_data = base64.b64decode(request.image_base64)
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid base64 image data: {e}")

    # Create message structure
    messages = [
        {"role": "system", "content": "You are a helpful assistant with vision abilities."},
        {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": request.prompt}]},
    ]

    # Process inputs
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # Run inference
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=4096)

    # Process output
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return {"response": output_texts[0]}


class SummaryRequest(BaseModel):
    prompt: str  # Input text to summarize

@app.post("/summary")
async def summary(request: SummaryRequest):
    # Create message structure
    messages = [
        {"role": "system", "content": "You are a helpful assistant that summarizes text."},
        {"role": "user", "content": [{"type": "text", "text": request.prompt}]},
    ]
    
    # Process inputs (text-only)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    
    # Run inference
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=4096)  # Adjust max_new_tokens for summary length
    
    # Process output
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    return {"response": output_texts[0]}