Bird Captioning and Classification Model (CUB-200-2011)

This is a VisionEncoderDecoderModel based on nlpconnect/vit-gpt2-image-captioning, fine-tuned on the CUB-200-2011 dataset for joint bird image captioning and species classification.

Model Description

  • Base Model: ViT-GPT2 (nlpconnect/vit-gpt2-image-captioning)
  • Tasks:
    • Generates descriptive captions for bird images, including species and attributes.
    • Classifies images into one of 200 bird species.
  • Dataset: CUB-200-2011 (11,788 images, 200 bird species)
  • Training: 10 epochs, batch size 16, mixed precision, AdamW optimizer (lr=3e-5), combined loss (caption + 0.5 * classification).
  • Best Validation Loss: 0.0690 (Epoch 3)
  • Size: 239M parameters (F32, safetensors format)

Files

  • model.safetensors: Trained model weights
  • config.json: Model configuration
  • preprocessor_config.json: ViTImageProcessor settings
  • tokenizer_config.json, vocab.json: GPT2 tokenizer files
  • species_mapping.txt: Mapping of class indices to bird species names (expected file format shown after this list)
  • cub200_captions.csv: Generated captions for the dataset
  • model.py: Custom BirdCaptioningModel class definition
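
Based on the loading code below, species_mapping.txt is expected to contain one comma-separated index/name pair per line. The exact index base and name formatting come from the file itself, so this is only an illustrative line:

6,Least Auklet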

Usage

Prerequisites

pip install transformers torch huggingface_hub
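
model.py and species_mapping.txt ship in the same repo as the weights. One way to pull the model definition into the working directory (so the import below resolves) is hf_hub_download with local_dir:

from huggingface_hub import hf_hub_download

# Download the custom class definition next to your script so `from model import ...` works
hf_hub_download(repo_id="INVERTO/bird-captioning-cub200", filename="model.py", local_dir=".")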

Load Model and Dependencies

from transformers import ViTImageProcessor, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
from model import BirdCaptioningModel  # model.py downloaded in the prerequisites step

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model
model = BirdCaptioningModel.from_pretrained("INVERTO/bird-captioning-cub200").to(device)
image_processor = ViTImageProcessor.from_pretrained("INVERTO/bird-captioning-cub200")
tokenizer = AutoTokenizer.from_pretrained("INVERTO/bird-captioning-cub200")
model.eval()

# Load species mapping (fetch species_mapping.txt from the same repo)
mapping_path = hf_hub_download(repo_id="INVERTO/bird-captioning-cub200", filename="species_mapping.txt")
species_mapping = {}
with open(mapping_path, "r") as f:
    for line in f:
        idx, name = line.strip().split(",", 1)
        species_mapping[int(idx)] = name

Inference

from PIL import Image

def predict_bird_image(image_path):
    image = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        # Beam-search caption generation through the underlying vision-encoder-decoder
        output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
        # The forward pass also returns caption outputs; only the class logits are used here
        _, class_logits = model(pixel_values)
        probs = torch.nn.functional.softmax(class_logits, dim=1)
        predicted_class_idx = probs.argmax(dim=1).item()
        confidence = probs[0, predicted_class_idx].item() * 100
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    species = species_mapping.get(predicted_class_idx, "Unknown")
    return caption, species, confidence

# Example
caption, species, confidence = predict_bird_image("/kaggle/input/cub2002011/CUB_200_2011/images/006.Least_Auklet/Least_Auklet_0007_795123.jpg")
print(f"Caption: {caption}")
print(f"Species: {species}")
print(f"Confidence: {confidence:.2f}%")

Dataset

  • CUB-200-2011: 11,788 images of 200 bird species with attribute annotations.
  • Captions were generated from species names and attribute annotations (e.g., bill shape, wing color); a sketch of this templating approach follows below.
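
The caption-generation script itself is not included in this repo; the following is only a plausible minimal sketch of the template-based approach described above (the attribute phrases are hypothetical examples, and the 5-attribute cap matches the limitation noted below):

def build_caption(species, attributes, max_attrs=5):
    # Compose a caption from a species name and up to max_attrs attribute phrases
    attrs = attributes[:max_attrs]
    name = species.replace("_", " ")
    if attrs:
        return f"A {name} with {', '.join(attrs)}."
    return f"A {name}."

# Hypothetical attribute phrases in the spirit of CUB annotations
print(build_caption("Least_Auklet", ["a dagger-shaped bill", "grey wings"]))
# -> "A Least Auklet with a dagger-shaped bill, grey wings."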

Training Details

  • Loss: Combined captioning (cross-entropy) and classification (cross-entropy) loss, weighted as caption + 0.5 * classification (sketched below).
  • Optimizer: AdamW (lr=3e-5)
  • Scheduler: CosineAnnealingLR
  • Hardware: GPU (CUDA)
  • Training Time: ~5 min/epoch
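
A minimal sketch of the combined objective (tensor names are hypothetical; the 0.5 weight is the one stated in the Model Description):

import torch.nn.functional as F

# Assumed shapes: caption_logits (batch, seq_len, vocab), caption_labels (batch, seq_len),
# class_logits (batch, 200), class_labels (batch,)
caption_loss = F.cross_entropy(
    caption_logits.reshape(-1, caption_logits.size(-1)),
    caption_labels.reshape(-1),
)
class_loss = F.cross_entropy(class_logits, class_labels)
loss = caption_loss + 0.5 * class_loss  # caption + 0.5 * classification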

Limitations

  • May overfit beyond Epoch 3 (validation loss increases after that point).
  • Captions are limited to the species name plus up to 5 attributes.
  • Classification accuracy has not been explicitly reported; only the combined validation loss is given.

License

MIT License

Contact

For issues, contact INVERTO on Hugging Face.
