Bird Captioning and Classification Model (CUB-200-2011)
This is a fine-tuned VisionEncoderDecoderModel based on nlpconnect/vit-gpt2-image-captioning, trained on the CUB-200-2011 dataset for bird species classification and image captioning.
Model Description
- Base Model: ViT-GPT2 (nlpconnect/vit-gpt2-image-captioning)
- Tasks:
  - Generates descriptive captions for bird images, including species and attributes.
  - Classifies images into one of 200 bird species.
- Dataset: CUB-200-2011 (11,788 images, 200 bird species)
- Training: 10 epochs, batch size 16, mixed precision, AdamW optimizer (lr=3e-5), combined loss (caption + 0.5 * classification).
- Best Validation Loss: 0.0690 (Epoch 3)
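The combined loss listed above can be written out as plain arithmetic. The following is a minimal, dependency-free sketch rather than the training code itself: the helper names `cross_entropy` and `combined_loss` are hypothetical, and how per-token caption losses are averaged is not stated in this card, so the caption term is simply passed in as a number.

```python
import math

CLASSIFICATION_WEIGHT = 0.5  # weight on the classification term, as stated in this card


def cross_entropy(logits, target_index):
    """Cross-entropy of one example: -log(softmax(logits)[target_index])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[target_index]


def combined_loss(caption_loss, classification_loss, weight=CLASSIFICATION_WEIGHT):
    """Combine the two terms as: caption + weight * classification."""
    return caption_loss + weight * classification_loss
```

With a weight of 0.5, the classification error contributes half as much to the total as the captioning error of the same size.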
Files
- model.safetensors: Trained model weights
- config.json: Model configuration
- preprocessor_config.json: ViTImageProcessor settings
- tokenizer_config.json, vocab.json: GPT2 tokenizer files
- species_mapping.txt: Mapping of class indices to bird species names
- cub200_captions.csv: Generated captions for the dataset
- model.py: Custom BirdCaptioningModel class definition
Usage
Prerequisites
pip install transformers torch huggingface_hub pillow
Load Model and Dependencies
from transformers import ViTImageProcessor, AutoTokenizer
from huggingface_hub import PyTorchModelHubMixin
import torch
from model import BirdCaptioningModel # Save model.py locally
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model
model = BirdCaptioningModel.from_pretrained("INVERTO/bird-captioning-cub200").to(device)
image_processor = ViTImageProcessor.from_pretrained("INVERTO/bird-captioning-cub200")
tokenizer = AutoTokenizer.from_pretrained("INVERTO/bird-captioning-cub200")
model.eval()
# Load species mapping
species_mapping = {}
with open("species_mapping.txt", "r") as f:
    for line in f:
        idx, name = line.strip().split(",", 1)
        species_mapping[int(idx)] = name
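The loading loop above assumes every line of species_mapping.txt has the form `index,name`, so a blank or trailing empty line would raise an error. A slightly more defensive variant might look like the sketch below; `parse_species_mapping` is a hypothetical helper, and the file format is inferred from the loading code rather than documented separately.

```python
def parse_species_mapping(lines):
    """Parse 'index,name' lines into a {int: str} dict, skipping blank lines."""
    mapping = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank or trailing empty lines
        idx, name = line.split(",", 1)  # split once, so names may contain commas
        mapping[int(idx)] = name.strip()
    return mapping


# Usage (assumes species_mapping.txt has been saved locally):
# with open("species_mapping.txt", "r", encoding="utf-8") as f:
#     species_mapping = parse_species_mapping(f)
```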
Inference
from PIL import Image
def predict_bird_image(image_path):
    image = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
        _, class_logits = model(pixel_values)
    predicted_class_idx = torch.argmax(class_logits, dim=1).item()
    confidence = torch.nn.functional.softmax(class_logits, dim=1)[0, predicted_class_idx].item() * 100
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    species = species_mapping.get(predicted_class_idx, "Unknown")
    return caption, species, confidence
# Example
caption, species, confidence = predict_bird_image("/kaggle/input/cub2002011/CUB_200_2011/images/006.Least_Auklet/Least_Auklet_0007_795123.jpg")
print(f"Caption: {caption}")
print(f"Species: {species}")
print(f"Confidence: {confidence:.2f}%")
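The confidence value printed above is the softmax probability of the top class. When several species look alike, it can help to inspect the runners-up as well. The sketch below spells out that arithmetic in plain Python so it can be checked by hand; `softmax` and `top_k_species` are hypothetical helpers, not part of model.py.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities that sum to 1."""
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_species(logits, species_mapping, k=3):
    """Return the k most likely (species, confidence %) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(species_mapping.get(i, "Unknown"), probs[i] * 100) for i in ranked[:k]]
```

Assuming `class_logits` has shape (1, 200) as in `predict_bird_image`, it could be passed in as `top_k_species(class_logits[0].tolist(), species_mapping)`.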
Dataset
- CUB-200-2011: 11,788 images of 200 bird species with attribute annotations.
- Captions were generated based on species names and attributes (e.g., bill shape, wing color).
Training Details
- Loss: Combined captioning (CrossEntropy) and classification (CrossEntropy) loss.
- Optimizer: AdamW (lr=3e-5)
- Scheduler: CosineAnnealingLR
- Hardware: GPU (CUDA)
- Training Time: ~5 min/epoch
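To illustrate how a cosine annealing schedule shapes the learning rate, the sketch below evaluates the closed-form cosine curve starting from the stated lr of 3e-5. The scheduler's settings are not given in this card, so `T_MAX = 10` (one step per epoch) and `ETA_MIN = 0.0` (PyTorch's default) are assumptions, and `cosine_annealing_lr` is a hypothetical helper.

```python
import math

BASE_LR = 3e-5   # learning rate stated in this card
T_MAX = 10       # assumption: one scheduler step per epoch over the 10 epochs
ETA_MIN = 0.0    # assumption: PyTorch's default minimum learning rate


def cosine_annealing_lr(step, base_lr=BASE_LR, t_max=T_MAX, eta_min=ETA_MIN):
    """Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Print the learning rate at each epoch boundary under the assumptions above.
for epoch in range(T_MAX + 1):
    print(f"epoch {epoch:2d}: lr = {cosine_annealing_lr(epoch):.2e}")
```

Under these assumptions the rate starts at the base value, reaches half of it at the midpoint, and decays to the minimum by the final step.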
Limitations
- May overfit after epoch 3: the best validation loss (0.0690) occurs at epoch 3, and validation loss increases afterward.
- Captions are limited to species and up to 5 attributes.
- Classification accuracy is not reported.
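Because this card gives no classification accuracy, readers who need one would have to measure it on their own held-out images. A minimal sketch of top-1 accuracy follows; `top1_accuracy` is a hypothetical helper, and the true labels must use the same class indices as species_mapping.txt.

```python
def top1_accuracy(predicted_indices, true_indices):
    """Fraction of examples whose predicted class index equals the true one."""
    if len(predicted_indices) != len(true_indices):
        raise ValueError("predictions and labels must have the same length")
    if not true_indices:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(p == t for p, t in zip(predicted_indices, true_indices))
    return correct / len(true_indices)
```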
License
MIT License
Contact
For issues, contact INVERTO on Hugging Face.