fevot committed · verified
Commit 60eeb55 · Parent(s): 1526231

Update app.py

Files changed (1)
  1. app.py (+14, -18)
app.py CHANGED
@@ -7,7 +7,7 @@ import json
 from torchvision import models
 import librosa
 
-# Define the BirdCallRNN model (unchanged)
+# Define the BirdCallRNN model
 class BirdCallRNN(nn.Module):
     def __init__(self, resnet, num_features, num_classes):
         super(BirdCallRNN, self).__init__()
@@ -21,7 +21,7 @@ class BirdCallRNN(nn.Module):
         features = self.resnet(x)
         features = features.view(batch, seq_len, -1)
         rnn_out, _ = self.rnn(features)
-        output = self.fc(rnn_out[:, -1, :])
+        output = self.fc(rnn_out[:, -1, :])  # Note: we'll use this for single-segment sequences
         return output
 
 # Function to convert MP3 to mel spectrogram (unchanged)
@@ -45,7 +45,7 @@ def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
 with open('class_mapping.json', 'r') as f:
     class_names = json.load(f)
 
-# Revised inference function to include confidence scores
+# Revised inference function to predict per segment
 def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
     model.eval()
     # Load audio and compute mel spectrogram
@@ -67,13 +67,9 @@ def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
         # Create a tensor with batch size 1 and sequence length 1
         seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)  # Shape: (1, 1, 3, 224, 224)
         output = model(seg_tensor)
-        # Apply softmax to get probabilities
-        probs = torch.softmax(output, dim=1)
-        confidence, pred_idx = torch.max(probs, dim=1)
-        pred_idx = pred_idx.cpu().numpy()[0]
-        confidence = confidence.cpu().numpy()[0]
-        predicted_bird = class_names[str(pred_idx)]
-        predictions.append((predicted_bird, confidence))
+        pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
+        predicted_bird = class_names[str(pred)]  # Convert pred to string to match JSON keys
+        predictions.append(predicted_bird)
     return predictions
 
 # Initialize the model
@@ -87,20 +83,20 @@ model.to(device)
 model.load_state_dict(torch.load('model_weights.pth', map_location=device))
 model.eval()
 
-# Prediction function with confidence scores
+# Prediction function for Gradio
 def predict_bird(file_path):
     predictions = infer_birdcall(model, file_path, segment_length=500, device=str(device))
-    # Format predictions as a numbered list with confidence scores
-    formatted_predictions = "\n".join([f"{i+1}. {pred} (Confidence: {conf*100:.2f}%)" for i, (pred, conf) in enumerate(predictions)])
-    return formatted_predictions
+    # Format predictions as a numbered list
+    formatted_predictions = "\n".join([f"{i+1}. {pred}" for i, pred in enumerate(predictions)])
+    return formatted_predictions  # Return formatted list of predictions
 
-# Custom Gradio interface
+# Custom Gradio interface with additional components
 def gradio_interface(file_path):
-    # Predict bird species with confidence
+    # Predict bird species
    prediction = predict_bird(file_path)
 
     # Display the uploaded MP3 file with a play button
-    audio_player = gr.Audio(file_path, label="Uploaded MP3 File", visible=True, autoplay=False)
+    audio_player = gr.Audio(file_path, label="Uploaded MP3 File", visible=True, autoplay=True)
 
     # Display images with titles
     bird_species_image = gr.Image("1.jpg", label="Bird Species")
@@ -121,4 +117,4 @@ interface = gr.Interface(
         gr.Image(label="Bird Origins")
     ]
 )
-interface.launch(share=True)
+interface.launch()
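
For context on the change inside infer_birdcall: a minimal sketch (dummy logits, not the repo's real model output) showing that the old softmax path and the new argmax path select the same class index, so the commit only stops reporting a confidence score; the predicted species is unchanged.

import torch

# Dummy logits standing in for model(seg_tensor); shape (1, num_classes)
logits = torch.tensor([[1.2, 3.4, 0.5]])

# Old path: softmax probabilities, keeping the winning index and its confidence
probs = torch.softmax(logits, dim=1)
confidence, pred_idx = torch.max(probs, dim=1)

# New path: argmax straight over the raw logits; no confidence is computed
pred = torch.max(logits, dim=1)[1]

# Softmax is monotonic, so both paths agree on the class index
assert pred_idx.item() == pred.item()

Because softmax preserves the ordering of the logits, dropping it (together with the confidence formatting in predict_bird) changes only what is displayed, not which bird is predicted.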