"""Gradio demo: visual question answering (ViT image encoder + BERT text encoder).

Downloads a fine-tuned VQA classifier from the Hugging Face Hub and serves it
through a simple image + question -> answer web interface.
"""
import json

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import transforms
from transformers import BertModel, BertTokenizer, ViTModel

# Hub repository holding config.json, pytorch_model.bin, answer_list.json and
# the tokenizer files of the fine-tuned model.
repo_id = "duyan2803/vqa-model-vit-bert"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Questions are padded/truncated to exactly this many BERT tokens.
MAX_QUESTION_LENGTH = 32

# Image preprocessing, built once at import time rather than on every request:
# resize to the 224x224 ViT input size, convert to a tensor, then normalize with
# the mean/std values below (ImageNet statistics).
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


# Model definition
class VQAModel(nn.Module):
    """VQA classifier that fuses ViT and BERT [CLS] embeddings.

    The image embedding ``img`` and question embedding ``txt`` are combined as
    ``[img, txt, img * txt]`` and scored by a dropout + linear head over
    ``num_answers`` candidate answers.
    """

    def __init__(self, num_answers):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        hidden_size = self.vit.config.hidden_size
        # The element-wise product in forward() requires both encoders to
        # produce vectors of the same width.
        if self.bert.config.hidden_size != hidden_size:
            raise ValueError(
                f"ViT hidden size ({hidden_size}) != BERT hidden size "
                f"({self.bert.config.hidden_size})"
            )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(hidden_size * 3, num_answers),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return answer logits of shape (batch, num_answers).

        Args:
            image: normalized pixel tensor, presumably (batch, 3, 224, 224).
            input_ids: BERT token ids, (batch, seq_len).
            attention_mask: BERT attention mask matching ``input_ids``.
        """
        # Position 0 of each encoder's last hidden state is its [CLS] summary.
        image_features = self.vit(image).last_hidden_state[:, 0, :]
        text_features = self.bert(
            input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0, :]
        combined = torch.cat(
            [image_features, text_features, image_features * text_features], dim=1
        )
        return self.classifier(combined)


# Load the model, tokenizer and answer vocabulary from the Hugging Face Hub.
try:
    # config.json supplies the size of the answer vocabulary.
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    num_answers = config["num_answers"]

    # Fine-tuned weights overwrite the pretrained backbones loaded in __init__.
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    model = VQAModel(num_answers=num_answers)
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("Đã load mô hình thành công!")

    tokenizer = BertTokenizer.from_pretrained(repo_id)
    print("Đã load tokenizer thành công!")

    # answer_list maps a predicted class index to its answer text.
    answer_list_path = hf_hub_download(repo_id=repo_id, filename="answer_list.json")
    with open(answer_list_path, "r", encoding="utf-8") as f:
        answer_list = json.load(f)
    print("Đã load answer list thành công!")
    if len(answer_list) != num_answers:
        # A mismatch means some predicted indices have no answer text.
        print(
            f"Warning: answer_list has {len(answer_list)} entries but the model "
            f"predicts {num_answers} classes."
        )
except Exception as e:
    print(f"Lỗi khi load mô hình hoặc file: {str(e)}")
    raise


# Prediction function
def predict(image, question):
    """Answer ``question`` about ``image`` with the loaded VQA model.

    Args:
        image: PIL image from the Gradio upload widget, or None if nothing
            was uploaded.
        question: free-text question.

    Returns:
        The predicted answer, or a human-readable message. This is the UI
        boundary, so errors are shown in the output box instead of raised.
    """
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question."
    try:
        # Force 3-channel RGB: PNG uploads can be RGBA / LA / P / L, which
        # would yield a tensor whose channel count does not match the
        # 3-element Normalize statistics or the ViT input.
        image_tensor = image_transform(image.convert("RGB")).unsqueeze(0).to(device)

        tokenized = tokenizer(
            question,
            padding="max_length",
            truncation=True,
            max_length=MAX_QUESTION_LENGTH,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].to(device)
        attention_mask = tokenized["attention_mask"].to(device)

        with torch.no_grad():
            logits = model(image_tensor, input_ids, attention_mask)
        pred_idx = logits.argmax(dim=1).item()
        return answer_list[pred_idx]
    except Exception as e:
        return f"Lỗi khi dự đoán: {str(e)}"


# Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Textbox(label="Ask a question"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="VQA Demo - Car Recognition",
    description="Upload an image of a car and ask a question (e.g., 'What color is this car?' or 'What is this car?').",
)

interface.launch()