# MiniCPM-v-2_6 / handler.py
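"""Custom handler for a Hugging Face Inference Endpoint serving MiniCPM-V 2.6.

Expects a base64-encoded image and a text prompt in the request payload and
returns the text generated by the model's chat interface.
"""
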
import torch
from PIL import Image
import base64
from io import BytesIO
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
def __init__(self, path="/repository"):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the model with SDPA attention; bfloat16 on GPU, float32 on CPU
self.model = AutoModel.from_pretrained(
path,
trust_remote_code=True,
attn_implementation='sdpa',
torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32,
).to(self.device)
self.model.eval()
# Load the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
path,
trust_remote_code=True,
)

    def __call__(self, data):
# Extract image and text from the input data
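        # Expected payload shape, matching the reads below:
        # {"inputs": {"image": "<base64-encoded image>", "text": "<prompt>"}}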
        inputs = data.get("inputs", {})
        image_data = inputs.get("image", "")
        text_prompt = inputs.get("text", "")
if not image_data or not text_prompt:
return {"error": "Both 'image' and 'text' must be provided in the input data."}
        # Decode the base64 payload into an RGB PIL image
try:
image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes)).convert("RGB")
except Exception as e:
return {"error": f"Failed to process image data: {e}"}
# Prepare the messages for the model
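        # MiniCPM-V's chat API takes the image inside the message content
        # list (a PIL image followed by the prompt string), which is why
        # `image=None` is passed to `chat` below.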
msgs = [{'role': 'user', 'content': [image, text_prompt]}]
# Generate output
with torch.no_grad():
res = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
sampling=True,
temperature=0.7,
top_p=0.95,
                # Bound the generation itself; max_length would also count
                # the (long, image-bearing) prompt toward the limit.
                max_new_tokens=2000,
)
        # chat() returns the decoded string directly when streaming is off
        return {"generated_text": res}
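

# --- Local smoke test (illustrative sketch, not part of the endpoint) ---
# Assumptions: a model snapshot in ./model and a sample image sample.jpg;
# both paths are hypothetical stand-ins for your own files.
if __name__ == "__main__":
    with open("sample.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(path="./model")
    result = handler({"inputs": {"image": encoded, "text": "Describe this image."}})
    print(result)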