import gradio as gr
import torch
from torch import nn
import cv2
import numpy as np
import json
from torchvision import models
import librosa


# Define the BirdCallRNN model
class BirdCallRNN(nn.Module):
    def __init__(self, resnet, num_features, num_classes):
        super(BirdCallRNN, self).__init__()
        self.resnet = resnet
        self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2,
                           batch_first=True, bidirectional=True)
        self.fc = nn.Linear(512, num_classes)  # 512 = 2 * hidden_size (bidirectional LSTM)

    def forward(self, x):
        batch, seq_len, C, H, W = x.size()
        x = x.view(batch * seq_len, C, H, W)
        features = self.resnet(x)
        features = features.view(batch, seq_len, -1)
        rnn_out, _ = self.rnn(features)
        # Use the last time step; we feed single-segment sequences at inference
        output = self.fc(rnn_out[:, -1, :])
        return output


# Function to convert MP3 to mel spectrogram
def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
    y, sr = librosa.load(mp3_file, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    log_S = librosa.power_to_db(S, ref=np.max)
    current_time_steps = log_S.shape[1]
    target_time_steps = target_shape[1]
    # Pad or truncate along the time axis to the target width
    if current_time_steps < target_time_steps:
        pad_width = target_time_steps - current_time_steps
        log_S_resized = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
    elif current_time_steps > target_time_steps:
        log_S_resized = log_S[:, :target_time_steps]
    else:
        log_S_resized = log_S
    log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
    return log_S_resized


# Load class mapping globally (JSON keyed by stringified class index,
# e.g. {"0": ..., "1": ...}, with species names as values)
with open('class_mapping.json', 'r') as f:
    class_names = json.load(f)


# Revised inference function to predict per segment
def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
    model.eval()
    # Load audio and compute mel spectrogram
    y, sr = librosa.load(mp3_file, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    log_S = librosa.power_to_db(S, ref=np.max)
    # Segment the spectrogram; fall back to the whole clip if it is shorter
    # than one segment
    num_segments = log_S.shape[1] // segment_length
    if num_segments == 0:
        segments = [log_S]
    else:
        segments = [log_S[:, i * segment_length:(i + 1) * segment_length]
                    for i in range(num_segments)]
    predictions = []
    # Process each segment individually
    for seg in segments:
        seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
        seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
        # Create a tensor with batch size 1 and sequence length 1,
        # shape: (1, 1, 3, 224, 224)
        seg_tensor = (torch.from_numpy(seg_rgb)
                      .permute(2, 0, 1)
                      .float()
                      .unsqueeze(0)
                      .unsqueeze(0)
                      .to(device))
        with torch.no_grad():
            output = model(seg_tensor)
        pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
        predicted_bird = class_names[str(pred)]  # Convert pred to string to match JSON keys
        predictions.append(predicted_bird)
    return predictions


# Initialize the model
resnet = models.resnet50(weights='IMAGENET1K_V2')
num_features = resnet.fc.in_features
resnet.fc = nn.Identity()  # Strip the classifier head; the LSTM consumes raw features
num_classes = len(class_names)  # Should be 114
model = BirdCallRNN(resnet, num_features, num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
model.eval()
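# Optional sanity check (an illustrative sketch, not part of the original
# pipeline; safe to remove): run a dummy forward pass to confirm the model
# emits one logit per class for the (batch, seq_len, C, H, W) shape that
# infer_birdcall feeds it.
_dummy = torch.randn(1, 1, 3, 224, 224, device=device)
with torch.no_grad():
    assert model(_dummy).shape == (1, num_classes)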
{pred}" for i, pred in enumerate(predictions)]) return formatted_predictions # Return formatted list of predictions # Custom Gradio interface with additional components def gradio_interface(file_path): # Predict bird species prediction = predict_bird(file_path) # Display the uploaded MP3 file with a play button audio_player = gr.Audio(file_path, label="Uploaded MP3 File", visible=True, autoplay=True) # Display images with titles bird_species_image = gr.Image("1.jpg", label="Bird Species") bird_description_image = gr.Image("2.jpg", label="Bird Description") bird_origins_image = gr.Image("3.jpg", label="Bird Origins") return prediction, audio_player, bird_species_image, bird_description_image, bird_origins_image # Launch Gradio interface interface = gr.Interface( fn=gradio_interface, inputs=gr.File(label="Upload MP3 file", file_types=['.mp3']), outputs=[ gr.Textbox(label="Predicted Bird Species"), gr.Audio(label="Uploaded MP3 File"), gr.Image(label="Bird Species"), gr.Image(label="Bird Description"), gr.Image(label="Bird Origins") ] ) interface.launch()