# vqa-demo-cop / app.py
# Uploaded by duyan2803
# Commit b8d2cfc (verified): "Update app.py"
import gradio as gr
import torch
from transformers import ViTModel, BertModel, BertTokenizer
from torchvision import transforms
from PIL import Image
import json
from torch import nn
from huggingface_hub import hf_hub_download
# Model definition
class VQAModel(nn.Module):
    """Visual question answering classifier built on a ViT and a BERT encoder.

    The hidden state at token index 0 of each encoder (the [CLS] position for
    both ViT and BERT) is used as the image / question representation. The
    two vectors and their element-wise product are concatenated and fed
    through dropout and one linear layer that scores every candidate answer.

    Args:
        num_answers: Number of answer classes, i.e. the size of the output
            logit vector.
        vit_name: Checkpoint name or path for the image encoder.
        bert_name: Checkpoint name or path for the text encoder.
        dropout: Dropout probability applied before the linear layer.

    Raises:
        ValueError: If the two encoders report different hidden sizes; the
            element-wise product in ``forward`` needs matching widths.
    """

    def __init__(
        self,
        num_answers: int,
        *,
        vit_name: str = 'google/vit-base-patch16-224',
        bert_name: str = 'bert-base-uncased',
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.vit = ViTModel.from_pretrained(vit_name)
        self.bert = BertModel.from_pretrained(bert_name)

        # Read the width from the encoders instead of hard-coding 768 so
        # other checkpoint sizes work; the defaults both report 768.
        hidden_size = self.vit.config.hidden_size
        if self.bert.config.hidden_size != hidden_size:
            raise ValueError(
                "Image and text encoders must share a hidden size, got "
                f"{hidden_size} and {self.bert.config.hidden_size}"
            )

        # Keep the Sequential layout (dropout at index 0, linear at index 1)
        # so parameter names stay "classifier.1.*" for saved state dicts.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 3, num_answers),
        )

    def forward(self, image, input_ids, attention_mask) -> torch.Tensor:
        """Return answer logits for a batch of image/question pairs.

        Args:
            image: Batch of preprocessed images passed straight to the ViT
                encoder.
            input_ids: Token ids of the tokenized questions.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            The classifier output: one row of ``num_answers`` logits per
            batch element.
        """
        image_features = self.vit(image).last_hidden_state[:, 0, :]
        text_features = self.bert(
            input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0, :]

        # Fuse both modalities: raw features plus their multiplicative
        # interaction, concatenated along the feature axis.
        combined = torch.cat(
            [image_features, text_features, image_features * text_features],
            dim=1,
        )
        return self.classifier(combined)
# Load the model from the Hugging Face Hub
# Hub repository that stores the fine-tuned weights, tokenizer files,
# config.json and answer_list.json used by this demo.
repo_id = "duyan2803/vqa-model-vilt-bert-color-optim"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # config.json supplies the size of the answer vocabulary.
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    num_answers = config["num_answers"]

    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    model = VQAModel(num_answers=num_answers)
    # weights_only=True limits unpickling to tensors and plain containers,
    # which avoids executing arbitrary code from a downloaded checkpoint.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference only: disables dropout

    tokenizer = BertTokenizer.from_pretrained(repo_id)

    # answer_list.json maps a predicted class index to its answer text.
    answer_list_path = hf_hub_download(repo_id=repo_id, filename="answer_list.json")
    with open(answer_list_path, "r", encoding="utf-8") as f:
        answer_list = json.load(f)
except Exception as e:
    # The app cannot work without these artifacts: report, then re-raise
    # the original exception with its traceback untouched.
    print(f"Lỗi khi load mô hình: {e}")
    raise
# Prediction function
def predict(image, question):
    """Answer ``question`` about ``image`` using the module-level VQA model.

    Args:
        image: PIL image to analyse. ``None`` (e.g. nothing was uploaded in
            the UI) is reported as an error message.
        question: Question text; ``None`` is treated as an empty string.

    Returns:
        The entry of ``answer_list`` at the highest-scoring class index, or
        a string starting with ``"Lỗi khi dự đoán: "`` describing the
        failure. Exceptions are never propagated to the caller, so the UI
        always receives text to display.
    """
    try:
        if image is None:
            # Fail with a readable message instead of an opaque error from
            # inside the transform pipeline.
            raise ValueError("no image was provided")

        # NOTE(review): ImageNet mean/std; presumably matches the
        # preprocessing used when the model was trained - confirm.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Normalize is configured for exactly three channels, so grayscale,
        # palette or RGBA inputs are converted to RGB first.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

        tokenized = tokenizer(
            question or "",
            padding='max_length',
            truncation=True,
            max_length=32,
            return_tensors='pt',
        )
        input_ids = tokenized['input_ids'].to(device)
        attention_mask = tokenized['attention_mask'].to(device)

        with torch.no_grad():
            output = model(image_tensor, input_ids, attention_mask)

        # Batch size is 1, so argmax over the class axis yields one index.
        pred_idx = output.argmax(dim=1).item()
        return answer_list[pred_idx]
    except Exception as e:
        return f"Lỗi khi dự đoán: {e}"
# Gradio interface
# Web UI: one image input and one text input wired to predict(); the
# returned string is shown in a single text box.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Textbox(label="Ask a question"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="VQA Demo - Car Recognition",
    description="Upload an image of a car and ask a question (e.g., 'What color is this car?' or 'What is this car?').",
)

# Start the server only when run as a script, so importing this module
# (for tests or reuse of predict) does not block on a running web server.
if __name__ == "__main__":
    interface.launch()