PawMatchAI / app.py
DawnC's picture
Upload 4 files
406922d
raw
history blame
11 kB
import os
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
#from data_manager import get_dog_description
# Class-index -> breed-label lookup table. predict() indexes this list with the
# arg-max class id produced by the model, so the order of the entries must match
# the label order the checkpoint was trained with (assumed -- TODO confirm against
# the training label map). The list holds 120 entries, matching num_classes.
# NOTE(review): the text inside the parentheses looks like mis-encoded (mojibake)
# non-ASCII breed names -- confirm the file is stored and read as UTF-8. The
# strings are runtime data and are intentionally left untouched here.
dog_breeds = ["Afghan_Hound(้˜ฟๅฏŒๆฑ—็ต็Šฌ)", "African_Hunting_Dog(้žๆดฒ้‡Ž็Šฌ)", "Airedale(่‰พ็ˆพ่ฐท็Šฌ)",
"American_Staffordshire_Terrier(็พŽๅœ‹ๆ–ฏๅก”็ฆ้ƒกๆข—)", "Appenzeller(ไบž่ณ“ๆพค็ˆพ็Šฌ)",
"Australian_Terrier(ๆพณๅคงๅˆฉไบžๆข—)", "Bedlington_Terrier(่ฒๅพท้ˆ้ “ๆข—)",
"Bernese_Mountain_Dog(ไผฏๆฉๅฑฑ็Šฌ)", "Blenheim_Spaniel(ๅธƒ่Šๅฐผๅง†็ต็Šฌ)",
"Border_Collie(้‚Šๅขƒ็‰ง็พŠ็Šฌ)", "Border_Terrier(้‚Šๅขƒๆข—)", "Boston_Bull(ๆณขๅฃซ้ “ๆข—)",
"Bouvier_Des_Flandres(ๆณ•่˜ญๅพทๆ–ฏ็‰ง็พŠ็Šฌ)", "Brabancon_Griffon(ๅธƒ้ญฏๅกž็ˆพๆ ผ้‡Œ่Šฌ็Šฌ)",
"Brittany_Spaniel(ๅธƒๅˆ—ๅก”ๅฐผ็ต็Šฌ)", "Cardigan(ๅก่ฟชๆ นๅจ็ˆพๅฃซๆŸฏๅŸบ็Šฌ)",
"Chesapeake_Bay_Retriever(ๅˆ‡่–ฉ็šฎๅ…‹็ฃ็ต็Šฌ)", "Chihuahua(ๅ‰ๅจƒๅจƒ)",
"Dandie_Dinmont(ไธน็ฌฌไธ่’™ๆข—)", "Doberman(ๆœ่ณ“็Šฌ)", "English_Foxhound(่‹ฑๅœ‹็ต็‹็Šฌ)",
"English_Setter(่‹ฑๅœ‹้›ช้”็Šฌ)", "English_Springer(่‹ฑๅœ‹่ทณ็ต็Šฌ)",
"EntleBucher(ๆฉ็‰น้›ทๅธƒ่ตซๅฑฑๅœฐ็Šฌ)", "Eskimo_Dog(ๆ„›ๆ–ฏๅŸบๆ‘ฉ็Šฌ)", "French_Bulldog(ๆณ•ๅœ‹้ฌฅ็‰›็Šฌ)",
"German_Shepherd(ๅพทๅœ‹็‰ง็พŠ็Šฌ)", "German_Short-Haired_Pointer(ๅพทๅœ‹็Ÿญๆฏ›ๆŒ‡็คบ็Šฌ)",
"Gordon_Setter(ๆˆˆ็™ป้›ช้”็Šฌ)", "Great_Dane(ๅคงไธน็Šฌ)", "Great_Pyrenees(ๅคง็™ฝ็†Š็Šฌ)",
"Greater_Swiss_Mountain_Dog(ๅคง็‘žๅฃซๅฑฑๅœฐ็Šฌ)", "Ibizan_Hound(ไพๆฏ”ๆฒ™็ต็Šฌ)",
"Irish_Setter(ๆ„›็ˆพ่˜ญ้›ช้”็Šฌ)", "Irish_Terrier(ๆ„›็ˆพ่˜ญๆข—)",
"Irish_Water_Spaniel(ๆ„›็ˆพ่˜ญๆฐด็ต็Šฌ)", "Irish_Wolfhound(ๆ„›็ˆพ่˜ญ็ต็‹ผ็Šฌ)",
"Italian_Greyhound(็พฉๅคงๅˆฉ็ฐ็‹—)", "Japanese_Spaniel(ๆ—ฅๆœฌ็‹†)",
"Kerry_Blue_Terrier(ๅ‡ฑๅˆฉ่—ๆข—)", "Labrador_Retriever(ๆ‹‰ๅธƒๆ‹‰ๅคšๅฐ‹ๅ›ž็Šฌ)",
"Lakeland_Terrier(ๆน–็•”ๆข—)", "Leonberg(็…ๆฏ›็‹—)", "Lhasa(ๆ‹‰่–ฉ็Šฌ)",
"Maltese_Dog(้ฆฌ็ˆพๆฟŸๆ–ฏ็Šฌ)", "Mexican_Hairless(ๅขจ่ฅฟๅ“ฅ็„กๆฏ›็Šฌ)", "Newfoundland(็ด่Šฌ่˜ญ็Šฌ)",
"Norfolk_Terrier(่ซพ็ฆๅ…‹ๆข—)", "Norwegian_Elkhound(ๆŒชๅจ็ต้บ‹็Šฌ)",
"Norwich_Terrier(่ซพๅˆฉๆฒปๆข—)", "Old_English_Sheepdog(ๅคไปฃ่‹ฑๅœ‹็‰ง็พŠ็Šฌ)",
"Pekinese(ๅŒ—ไบฌ็Šฌ)", "Pembroke(ๅจ็ˆพๅฃซๆŸฏๅŸบ็Šฌ)", "Pomeranian(ๅš็พŽ็Šฌ)",
"Rhodesian_Ridgeback(็พ…ๅพ—่ฅฟไบž่„Š่ƒŒ็Šฌ)", "Rottweiler(็พ…ๅจ็ด็Šฌ)",
"Saint_Bernard(่–ไผฏ็ด็Šฌ)", "Saluki(่–ฉ่ทฏๅŸบ็ต็Šฌ)", "Samoyed(่–ฉๆ‘ฉ่€ถ็Šฌ)",
"Scotch_Terrier(่˜‡ๆ ผ่˜ญๆข—)", "Scottish_Deerhound(่˜‡ๆ ผ่˜ญ็ต้นฟ็Šฌ)",
"Sealyham_Terrier(้Œซๅˆฉๅ“ˆๅง†ๆข—)", "Shetland_Sheepdog(่จญๅพ—่˜ญ็‰ง็พŠ็Šฌ)",
"Shih-Tzu(่ฅฟๆ–ฝ็Šฌ)", "Siberian_Husky(่ฅฟไผฏๅˆฉไบžๅ“ˆๅฃซๅฅ‡)",
"Staffordshire_Bullterrier(ๆ–ฏๅก”็ฆ้ƒก้ฌฅ็‰›ๆข—)", "Sussex_Spaniel(่˜‡ๅกžๅ…‹ๆ–ฏ็ต็Šฌ)",
"Tibetan_Mastiff(่—็’)", "Tibetan_Terrier(่ฅฟ่—ๆข—)", "Walker_Hound(ๆฒƒๅ…‹็ต็Šฌ)",
"Weimaraner(ๅจ็‘ช็Šฌ)", "Welsh_Springer_Spaniel(ๅจ็ˆพๅฃซ่ทณ็ต็Šฌ)",
"West_Highland_White_Terrier(่ฅฟ้ซ˜ๅœฐ็™ฝๆข—)", "Yorkshire_Terrier(็ด„ๅ…‹ๅคๆข—)",
"Affenpinscher(็Œด็Šฌ)", "Basenji(ๅทด่พ›ๅ‰็Šฌ)", "Basset(ๅทดๅ‰ๅบฆ็ต็Šฌ)", "Beagle(ๆฏ”ๆ ผ็Šฌ)",
"Black-and-Tan_Coonhound(้ป‘่ค็ตๆตฃ็†Š็Šฌ)", "Bloodhound(ๅฐ‹่ก€็ต็Šฌ)",
"Bluetick(ๅธƒ้ญฏๆๅ…‹็ต็Šฌ)", "Borzoi(ไฟ„็พ…ๆ–ฏ็ต็‹ผ็Šฌ)", "Boxer(ๆ‹ณๅธซ็Šฌ)", "Briard(ๅธƒ้‡Œไบž็Šฌ)",
"Bull_Mastiff(็’็Šฌ)", "Cairn(ๅ‡ฑๆฉๆข—)", "Chow(้ฌ†็…็Šฌ)", "Clumber(ๅ…‹ๅ€ซไผฏ็ต็Šฌ)",
"Cocker_Spaniel(ๅฏๅก็ต็Šฌ)", "Collie(ๆŸฏๅˆฉ็‰ง็พŠ็Šฌ)", "Curly-Coated_Retriever(ๆฒๆฏ›ๅฐ‹ๅ›ž็Šฌ)",
"Dhole(่ฑบ)", "Dingo(ๆพณๆดฒ้‡Ž็Šฌ)", "Flat-Coated_Retriever(ๅนณๆฏ›ๅฐ‹ๅ›ž็Šฌ)",
"Giant_Schnauzer(ๅคงๅž‹้›ช็ด็‘ž็Šฌ)", "Golden_Retriever(้ปƒ้‡‘็ต็Šฌ)",
"Groenendael(ๆฏ”ๅˆฉๆ™‚็‰ง็พŠ็Šฌ)", "Keeshond(่ท่˜ญๆฏ›็…็Šฌ)", "Kelpie(ๆพณๆดฒๅก็ˆพๆฏ”็Šฌ)",
"Komondor(ๅŒˆ็‰™ๅˆฉ็‰ง็พŠ็Šฌ)", "Kuvasz(ๅบซ็“ฆ่Œฒ็Šฌ)", "Malamute(้˜ฟๆ‹‰ๆ–ฏๅŠ ้›ชๆฉ‡็Šฌ)",
"Malinois(ๆฏ”ๅˆฉๆ™‚็‘ชๅˆฉ่ซพ็Šฌ)", "Miniature_Pinscher(่ฟทไฝ ๆœ่ณ“็Šฌ)",
"Miniature_Poodle(่ฟทไฝ ่ฒด่ณ“็Šฌ)", "Miniature_Schnauzer(่ฟทไฝ ้›ช็ด็‘ž็Šฌ)",
"Otterhound(ๆฐด็บ็ต็Šฌ)", "Papillon(่ด่ถ็Šฌ)", "Pug(ๅทดๅ“ฅ็Šฌ)", "Redbone(็ด…้ชจ็ตๆตฃ็†Š็Šฌ)",
"Schipperke(่ˆ’ๆŸๅฅ‡็Šฌ)", "Silky_Terrier(็ตฒๆฏ›ๆข—)",
"Soft-Coated_Wheaten_Terrier(ๆ„›็ˆพ่˜ญ่ปŸๆฏ›ๆข—)", "Standard_Poodle(ๆจ™ๆบ–่ฒด่ณ“็Šฌ)",
"Standard_Schnauzer(ๆจ™ๆบ–้›ช็ด็‘ž็Šฌ)", "Toy_Poodle(็Žฉๅ…ท่ฒด่ณ“็Šฌ)", "Toy_Terrier(็Žฉๅ…ทๆข—)",
"Vizsla(็ถญ่Œฒๆ‹‰็Šฌ)", "Whippet(ๆƒ ๆฏ”็‰น็Šฌ)", "Wire-Haired_Fox_Terrier(็กฌๆฏ›็ต็‹ๆข—)"]
class MultiHeadAttention(nn.Module):
    """Attention-style mixing block applied to one feature vector per sample.

    The input of shape ``(N, in_dim)`` is projected to ``num_heads * head_dim``
    features and reshaped to ``(N, num_heads, head_dim)``; the ``num_heads``
    axis is the one the query/key scores are computed over. The result is
    flattened and projected back to ``in_dim``.

    NOTE(review): the second einsum uses the pattern ``"nqk,nvd->nqd"``, in
    which ``k`` and ``v`` are independent summed indices. Because the softmax
    weights sum to 1 over ``k``, every output row ``q`` ends up equal to the sum
    of the value rows, i.e. the attention weights do not influence the output.
    ``"nqk,nkd->nqd"`` was probably intended, but the computation is kept
    exactly as-is: changing it would alter the outputs of weights that were
    presumably trained with this form -- confirm before touching it.

    Args:
        in_dim: Size of the incoming (and outgoing) feature vector.
        num_heads: Number of chunks the projected features are split into.
    """

    def __init__(self, in_dim, num_heads=8):
        super().__init__()
        # Never let a head collapse to width 0 when in_dim < num_heads.
        head_dim = max(1, in_dim // num_heads)
        width = head_dim * num_heads

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scaled_dim = width

        # Attribute names and creation order define the state_dict layout and
        # the order random initial weights are drawn in; keep both stable.
        self.fc_in = nn.Linear(in_dim, width)
        self.query = nn.Linear(width, width)
        self.key = nn.Linear(width, width)
        self.value = nn.Linear(width, width)
        self.fc_out = nn.Linear(width, in_dim)

    def forward(self, x):
        """Return a tensor of shape ``(N, in_dim)`` for an ``(N, in_dim)`` input."""
        batch = x.shape[0]
        per_head = (batch, self.num_heads, self.head_dim)

        hidden = self.fc_in(x)
        q = self.query(hidden).view(per_head)
        k = self.key(hidden).view(per_head)
        v = self.value(hidden).view(per_head)

        # Scaled dot-product scores between every pair of heads: (N, heads, heads).
        scores = torch.einsum("nqd,nkd->nqk", [q, k])
        weights = F.softmax(scores / (self.head_dim ** 0.5), dim=2)
        # See the class docstring: k and v are both summed out here.
        mixed = torch.einsum("nqk,nvd->nqd", [weights, v])

        return self.fc_out(mixed.reshape(batch, self.scaled_dim))
class BaseModel(nn.Module):
    """EfficientNetV2-M backbone followed by an attention block and a linear head.

    The torchvision classifier of the backbone is replaced by ``nn.Identity`` so
    the backbone emits raw feature vectors of size ``feature_dim``; these are
    passed through ``MultiHeadAttention`` and then a LayerNorm/Dropout/Linear
    classifier.

    Args:
        num_classes: Number of output logits.
        device: Device the module is moved to and inputs are sent to in
            ``forward``. The default is resolved once, when the class is defined.
    """

    def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
        super().__init__()
        self.device = device

        # Start from the ImageNet-pretrained weights and strip the original head.
        backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
        self.feature_dim = backbone.classifier[1].in_features
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        # Derive num_heads dynamically: one head per 64 features, clamped to [1, 8].
        self.num_heads = max(1, min(8, self.feature_dim // 64))
        self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)

        head_layers = [
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self.to(device)

    def forward(self, x):
        """Return ``(logits, attended_features)`` for a batch of images ``x``."""
        features = self.backbone(x.to(self.device))
        attended_features = self.attention(features)
        return self.classifier(attended_features), attended_features
# Build the classifier and restore the trained weights at import time.
num_classes = 120
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BaseModel(num_classes=num_classes, device=device)

# The default checkpoint location is specific to one machine's Drive mount, so
# it can be overridden with the PAWMATCH_CHECKPOINT environment variable when
# the app runs elsewhere. With the variable unset the original path is used.
checkpoint_path = os.environ.get(
    'PAWMATCH_CHECKPOINT',
    '/content/drive/Othercomputers/ๆˆ‘็š„ MacBook Pro/Learning/Cats_Dogs_Detector/best_model/best_model_81_dog.pth',
)

# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source. Tensors are mapped to CPU first; load_state_dict then copies
# them into the parameters that already live on `device`.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

# Put the model into evaluation mode.
model.eval()
# Image preprocessing function
# The pipeline is stateless, so it is built once instead of on every call:
# resize to 224x224, convert to a float tensor, then normalise each channel
# with the mean/std values below.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert an input image into a normalised ``(1, 3, 224, 224)`` tensor.

    Args:
        image: A ``numpy.ndarray`` (converted with ``Image.fromarray``) or a
            ``PIL.Image.Image``.

    Returns:
        The transformed image with a leading batch dimension of size 1.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided")

    # Convert a numpy.ndarray into a PIL.Image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Normalize() is configured for exactly three channels; RGBA, grayscale or
    # palette images would otherwise fail with a channel-count mismatch.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    return _PREPROCESS(image).unsqueeze(0)
# Keys rendered, in this order, when the breed description is a dict.
_DESCRIPTION_FIELDS = (
    "Breed",
    "Size",
    "Lifespan",
    "Temperament",
    "Care Level",
    "Good with Children",
    "Exercise Needs",
    "Grooming Needs",
    "Description",
)


def predict(image):
    """Classify an uploaded image and return Markdown text describing the breed.

    This is the Gradio callback. The image is preprocessed, run through the
    model, and the arg-max class id is mapped to a label via ``dog_breeds``.
    If a ``get_dog_description`` helper is available in this module its result
    is formatted for display; otherwise only the predicted breed is shown.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A Markdown string. Failures are reported as an
        ``"An error occurred: ..."`` string rather than raised, so the UI always
        has something to render.
    """
    if image is None:
        return "Please upload an image."

    try:
        image_tensor = preprocess_image(image)
        with torch.no_grad():
            logits, _ = model(image_tensor)
        _, predicted = torch.max(logits, 1)
        breed = dog_breeds[predicted.item()]  # Map label to breed name
        breed_only = f"**Breed**: {breed}"

        # The import of get_dog_description is commented out at the top of the
        # file; without this guard every prediction ended in a NameError message.
        try:
            describe = get_dog_description
        except NameError:
            return breed_only

        # Retrieve breed description
        description = describe(breed)
        if isinstance(description, dict):
            # One "**Field**: value" paragraph per field; a missing key still
            # surfaces through the error message below.
            return "".join(
                f"**{field}**: {description[field]}\n\n"
                for field in _DESCRIPTION_FIELDS
            )
        # Non-dict descriptions are shown verbatim; fall back to the breed name
        # when the lookup produced nothing.
        return description if description else breed_only
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing the app.
        return f"An error occurred: {e}"
# Gradio UI definition: a single image input wired to predict(), with the
# result rendered as Markdown. Title/description are inline-styled HTML and
# the `css` string restyles the output text, buttons and boxes.
# NOTE(review): the example image paths point into a "/content/drive/..."
# directory -- presumably a notebook Drive mount; they must exist on the machine
# the app runs on. Confirm before deploying elsewhere.
iface = gr.Interface(
fn=predict,
inputs=gr.Image(label="Upload an image", type="numpy"), # Supports drag-and-drop and image editing
outputs="markdown",
title="<span style='font-family:Roboto; font-weight:bold; color:#2C3E50;'>Dog Breed Classifier</span>",
description="<span style='font-family:Open Sans; color:#34495E;'>Upload an image, and the system will predict the breed and provide detailed information from the database.</span>",
examples=['/content/drive/Othercomputers/ๆˆ‘็š„ MacBook Pro/Learning/Cats_Dogs_Detector/Border_Collie.jpg',
'/content/drive/Othercomputers/ๆˆ‘็š„ MacBook Pro/Learning/Cats_Dogs_Detector/Golden_Retriever.jpeg',
'/content/drive/Othercomputers/ๆˆ‘็š„ MacBook Pro/Learning/Cats_Dogs_Detector/Saint_Bernard.jpeg',
'/content/drive/Othercomputers/ๆˆ‘็š„ MacBook Pro/Learning/Cats_Dogs_Detector/French_Bulldog.jpeg',
'/content/drive/Othercomputers/ๆˆ‘็š„ MacBook Pro/Learning/Cats_Dogs_Detector/Samoyed.jpg'],
css="""
.output-markdown {
font-family: Noto Sans, sans-serif;
line-height: 1.6;
}
.gr-button {
background-color: #3498DB;
color: white;
border-radius: 8px;
box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.2);
padding: 10px 20px;
}
.gr-button:hover {
background-color: #2980B9;
}
.gr-box {
background: linear-gradient(to bottom, #f2f4f5, #ffffff);
border-radius: 10px;
padding: 20px;
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}
""",
theme="default"
)
# Launch the app (only when executed as a script, not when imported).
if __name__ == "__main__":
    iface.launch()