# vqa-demo / app.py
# pmhanh
# Update app.py
# a288221
import gradio as gr
import torch
from transformers import ViTModel, BertModel, BertTokenizer
from torchvision import transforms
from PIL import Image
import json
from torch import nn
from huggingface_hub import hf_hub_download
# Model definition
class VQAModel(nn.Module):
    """Visual question answering classifier over a fixed answer vocabulary.

    The [CLS] vector of a ViT image encoder and the [CLS] vector of a BERT
    question encoder are fused by concatenating them together with their
    element-wise product; a dropout + linear head maps the fused vector to
    one logit per candidate answer.
    """

    def __init__(self, num_answers):
        """Build both encoders and the classification head.

        Args:
            num_answers: number of candidate answers, i.e. output logits.
        """
        super().__init__()
        # Both encoders start from public pretrained checkpoints; fine-tuned
        # weights, if any, have to be loaded afterwards with load_state_dict.
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Image and text [CLS] vectors are 768-wide each; the fused vector is
        # [image, text, image * text], hence three times that width.
        fused_width = 768 * 3
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(fused_width, num_answers),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return answer logits for a batch of images and tokenised questions.

        Args:
            image: pixel tensor fed directly to the ViT encoder (presumably
                (batch, 3, 224, 224), matching the demo's preprocessing —
                confirm for other callers).
            input_ids: BERT token ids for the questions.
            attention_mask: BERT attention mask aligned with ``input_ids``.

        Returns:
            Logits of shape (batch, num_answers).
        """
        img_cls = self.vit(image).last_hidden_state[:, 0, :]
        txt_cls = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        fused = torch.cat((img_cls, txt_cls, img_cls * txt_cls), dim=1)
        return self.classifier(fused)
# Load the fine-tuned model and its companion files from the Hugging Face Hub.
repo_id = "duyan2803/vqa-model-vit-bert"
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    # config.json provides the size of the answer vocabulary (classifier width).
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    # Explicit UTF-8: the platform default encoding is not UTF-8 everywhere.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    num_answers = config["num_answers"]

    # Build the architecture, then overwrite it with the fine-tuned weights.
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    model = VQAModel(num_answers=num_answers)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference mode for layers such as Dropout
    print("Đã load mô hình thành công!")  # "Model loaded successfully!"

    # The tokenizer files live in the same Hub repo as the weights.
    tokenizer = BertTokenizer.from_pretrained(repo_id)
    print("Đã load tokenizer thành công!")  # "Tokenizer loaded successfully!"

    # Predicted class indices are looked up in this list to get the answer text
    # (answers may contain non-ASCII text, hence the explicit encoding).
    answer_list_path = hf_hub_download(repo_id=repo_id, filename="answer_list.json")
    with open(answer_list_path, "r", encoding="utf-8") as f:
        answer_list = json.load(f)
    if len(answer_list) != num_answers:
        # A mismatch means class indices and answers are out of sync: some
        # predictions would map to the wrong answer or fall off the list.
        print(
            f"Warning: answer_list has {len(answer_list)} entries but the model "
            f"outputs {num_answers} classes"
        )
    print("Đã load answer list thành công!")  # "Answer list loaded successfully!"
except Exception as e:
    # "Error while loading the model or files" — report, then re-raise so the
    # app fails fast instead of starting without a model.
    print(f"Lỗi khi load mô hình hoặc file: {str(e)}")
    raise
# Prediction
# The preprocessing pipeline is identical for every request, so build it once.
_IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel statistics — presumably the values used during
    # training; confirm before changing.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(image, question):
    """Answer a free-text question about an image.

    Args:
        image: PIL image (the Gradio input is declared with ``type="pil"``).
        question: question text; padded/truncated to 32 BERT tokens.

    Returns:
        The predicted answer string. On any failure, returns an error message
        prefixed with "Lỗi khi dự đoán:" ("Prediction error:") instead of
        raising, so the UI shows the problem rather than crashing.
    """
    try:
        # Force 3 channels: RGBA, grayscale or palette images would otherwise
        # become 4- or 1-channel tensors and break the 3-channel Normalize.
        rgb_image = image.convert("RGB")
        image_tensor = _IMAGE_TRANSFORM(rgb_image).unsqueeze(0).to(device)

        # Tokenise the question to a fixed length of 32 tokens.
        tokenized = tokenizer(question, padding='max_length', truncation=True, max_length=32, return_tensors='pt')
        input_ids = tokenized['input_ids'].to(device)
        attention_mask = tokenized['attention_mask'].to(device)

        # Forward pass without gradient tracking; take the top-scoring answer.
        with torch.no_grad():
            output = model(image_tensor, input_ids, attention_mask)
        pred_idx = output.argmax(dim=1).item()
        return answer_list[pred_idx]
    except Exception as e:
        return f"Lỗi khi dự đoán: {str(e)}"
# Gradio UI: an uploaded image plus a question in, the predicted answer out.
# Components are created in the same order as before: image, question, answer.
_image_input = gr.Image(type="pil", label="Upload an image")
_question_input = gr.Textbox(label="Ask a question")
_answer_output = gr.Textbox(label="Answer")

interface = gr.Interface(
    fn=predict,
    inputs=[_image_input, _question_input],
    outputs=_answer_output,
    title="VQA Demo - Car Recognition",
    description="Upload an image of a car and ask a question (e.g., 'What color is this car?' or 'What is this car?').",
)
interface.launch()