import gradio as gr
import torch
import librosa
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
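
# Load the fine-tuned wav2vec2 classifier and its feature extractor from the Hub.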
model_name = "greenarcade/wav2vec2-vd-bird-sound-classification"

model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)


def predict(audio_file):
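    # Load the clip and resample it to the 16 kHz rate the model expects.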
    audio, sr = librosa.load(audio_file, sr=16000)

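    # Convert the waveform to model inputs, truncated to at most 5 seconds of audio.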
    inputs = feature_extractor(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=16000 * 5,
    )

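    # Run inference without gradient tracking and turn the logits into probabilities.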
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1).squeeze().tolist()

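    # Map class indices to label names and return the top three predictions.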
    predictions = {model.config.id2label[i]: prob for i, prob in enumerate(probs)}
    sorted_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:3]
    return {k: v for k, v in sorted_preds}
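

# Gradio interface: upload an audio clip, get the top three predicted species.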
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=gr.Label(num_top_classes=3),
    title="🦜 Bird Sound Classifier (Indian birds)",
    description="Upload a 5-second audio clip to identify bird species",
    examples=[["greyheron-sample.wav"], ["blue-tail-sample.mp3"]],
)

demo.launch()