fevot commited on
Commit
1526231
·
verified ·
1 Parent(s): ad952f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -13
app.py CHANGED
@@ -7,7 +7,7 @@ import json
7
  from torchvision import models
8
  import librosa
9
 
10
- # Define the BirdCallRNN model
11
  class BirdCallRNN(nn.Module):
12
  def __init__(self, resnet, num_features, num_classes):
13
  super(BirdCallRNN, self).__init__()
@@ -21,7 +21,7 @@ class BirdCallRNN(nn.Module):
21
  features = self.resnet(x)
22
  features = features.view(batch, seq_len, -1)
23
  rnn_out, _ = self.rnn(features)
24
- output = self.fc(rnn_out[:, -1, :]) # Note: We’ll use this for single-segment sequences
25
  return output
26
 
27
  # Function to convert MP3 to mel spectrogram (unchanged)
@@ -45,12 +45,14 @@ def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224,
45
  with open('class_mapping.json', 'r') as f:
46
  class_names = json.load(f)
47
 
48
- # Revised inference function to predict per segment with confidence scores
49
  def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
50
  model.eval()
 
51
  y, sr = librosa.load(mp3_file, sr=None)
52
  S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
53
  log_S = librosa.power_to_db(S, ref=np.max)
 
54
  num_segments = log_S.shape[1] // segment_length
55
  if num_segments == 0:
56
  segments = [log_S]
@@ -58,43 +60,53 @@ def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
58
  segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
59
 
60
  predictions = []
 
61
  for seg in segments:
62
  seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
63
  seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
64
- seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)
 
65
  output = model(seg_tensor)
 
66
  probs = torch.softmax(output, dim=1)
67
- confidence, pred = torch.max(probs, dim=1)
68
- pred = pred.cpu().numpy()[0]
69
  confidence = confidence.cpu().numpy()[0]
70
- predicted_bird = class_names.get(str(pred), "Unknown")
71
- predictions.append(f"{predicted_bird} ({confidence:.2%} confidence)")
72
  return predictions
73
 
74
  # Initialize the model
75
  resnet = models.resnet50(weights='IMAGENET1K_V2')
76
  num_features = resnet.fc.in_features
77
  resnet.fc = nn.Identity()
78
- num_classes = len(class_names)
79
  model = BirdCallRNN(resnet, num_features, num_classes)
80
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
  model.to(device)
82
  model.load_state_dict(torch.load('model_weights.pth', map_location=device))
83
  model.eval()
84
 
85
- # Prediction function for Gradio
86
  def predict_bird(file_path):
87
  predictions = infer_birdcall(model, file_path, segment_length=500, device=str(device))
88
- formatted_predictions = "\n".join([f"{i+1}. {pred}" for i, pred in enumerate(predictions)])
 
89
  return formatted_predictions
90
 
91
  # Custom Gradio interface
92
  def gradio_interface(file_path):
 
93
  prediction = predict_bird(file_path)
94
- audio_player = gr.Audio(file_path, label="Uploaded MP3 File", visible=True, autoplay=True)
 
 
 
 
95
  bird_species_image = gr.Image("1.jpg", label="Bird Species")
96
  bird_description_image = gr.Image("2.jpg", label="Bird Description")
97
  bird_origins_image = gr.Image("3.jpg", label="Bird Origins")
 
98
  return prediction, audio_player, bird_species_image, bird_description_image, bird_origins_image
99
 
100
  # Launch Gradio interface
@@ -109,4 +121,4 @@ interface = gr.Interface(
109
  gr.Image(label="Bird Origins")
110
  ]
111
  )
112
- interface.launch(share=True)
 
7
  from torchvision import models
8
  import librosa
9
 
10
+ # Define the BirdCallRNN model (unchanged)
11
  class BirdCallRNN(nn.Module):
12
  def __init__(self, resnet, num_features, num_classes):
13
  super(BirdCallRNN, self).__init__()
 
21
  features = self.resnet(x)
22
  features = features.view(batch, seq_len, -1)
23
  rnn_out, _ = self.rnn(features)
24
+ output = self.fc(rnn_out[:, -1, :])
25
  return output
26
 
27
  # Function to convert MP3 to mel spectrogram (unchanged)
 
45
  with open('class_mapping.json', 'r') as f:
46
  class_names = json.load(f)
47
 
48
+ # Revised inference function to include confidence scores
49
  def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
50
  model.eval()
51
+ # Load audio and compute mel spectrogram
52
  y, sr = librosa.load(mp3_file, sr=None)
53
  S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
54
  log_S = librosa.power_to_db(S, ref=np.max)
55
+ # Segment the spectrogram
56
  num_segments = log_S.shape[1] // segment_length
57
  if num_segments == 0:
58
  segments = [log_S]
 
60
  segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
61
 
62
  predictions = []
63
+ # Process each segment individually
64
  for seg in segments:
65
  seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
66
  seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
67
+ # Create a tensor with batch size 1 and sequence length 1
68
+ seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, 3, 224, 224)
69
  output = model(seg_tensor)
70
+ # Apply softmax to get probabilities
71
  probs = torch.softmax(output, dim=1)
72
+ confidence, pred_idx = torch.max(probs, dim=1)
73
+ pred_idx = pred_idx.cpu().numpy()[0]
74
  confidence = confidence.cpu().numpy()[0]
75
+ predicted_bird = class_names[str(pred_idx)]
76
+ predictions.append((predicted_bird, confidence))
77
  return predictions
78
 
79
  # Initialize the model
80
  resnet = models.resnet50(weights='IMAGENET1K_V2')
81
  num_features = resnet.fc.in_features
82
  resnet.fc = nn.Identity()
83
+ num_classes = len(class_names) # Should be 114
84
  model = BirdCallRNN(resnet, num_features, num_classes)
85
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86
  model.to(device)
87
  model.load_state_dict(torch.load('model_weights.pth', map_location=device))
88
  model.eval()
89
 
90
+ # Prediction function with confidence scores
91
  def predict_bird(file_path):
92
  predictions = infer_birdcall(model, file_path, segment_length=500, device=str(device))
93
+ # Format predictions as a numbered list with confidence scores
94
+ formatted_predictions = "\n".join([f"{i+1}. {pred} (Confidence: {conf*100:.2f}%)" for i, (pred, conf) in enumerate(predictions)])
95
  return formatted_predictions
96
 
97
  # Custom Gradio interface
98
  def gradio_interface(file_path):
99
+ # Predict bird species with confidence
100
  prediction = predict_bird(file_path)
101
+
102
+ # Display the uploaded MP3 file with a play button
103
+ audio_player = gr.Audio(file_path, label="Uploaded MP3 File", visible=True, autoplay=False)
104
+
105
+ # Display images with titles
106
  bird_species_image = gr.Image("1.jpg", label="Bird Species")
107
  bird_description_image = gr.Image("2.jpg", label="Bird Description")
108
  bird_origins_image = gr.Image("3.jpg", label="Bird Origins")
109
+
110
  return prediction, audio_player, bird_species_image, bird_description_image, bird_origins_image
111
 
112
  # Launch Gradio interface
 
121
  gr.Image(label="Bird Origins")
122
  ]
123
  )
124
+ interface.launch(share=True)