fevot commited on
Commit
50ea088
·
verified ·
1 Parent(s): 404eacf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -54
app.py CHANGED
@@ -6,8 +6,11 @@ import numpy as np
6
  import json
7
  from torchvision import models
8
  import librosa
 
 
 
9
 
10
- # Define the BirdCallRNN model
11
  class BirdCallRNN(nn.Module):
12
  def __init__(self, resnet, num_features, num_classes):
13
  super(BirdCallRNN, self).__init__()
@@ -21,77 +24,87 @@ class BirdCallRNN(nn.Module):
21
  features = self.resnet(x)
22
  features = features.view(batch, seq_len, -1)
23
  rnn_out, _ = self.rnn(features)
24
- output = self.fc(rnn_out[:, -1, :]) # Note: We’ll use this for single-segment sequences
25
  return output
26
 
27
- # Function to convert MP3 to mel spectrogram (unchanged)
28
- def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
29
- y, sr = librosa.load(mp3_file, sr=None)
30
- S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
31
- log_S = librosa.power_to_db(S, ref=np.max)
32
- current_time_steps = log_S.shape[1]
33
- target_time_steps = target_shape[1]
34
- if current_time_steps < target_time_steps:
35
- pad_width = target_time_steps - current_time_steps
36
- log_S_resized = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
37
- elif current_time_steps > target_time_steps:
38
- log_S_resized = log_S[:, :target_time_steps]
39
- else:
40
- log_S_resized = log_S
41
- log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
42
- return log_S_resized
43
 
44
- # Load class mapping globally
45
  with open('class_mapping.json', 'r') as f:
46
  class_names = json.load(f)
47
 
48
- # Revised inference function to predict per segment
49
- def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
50
- model.eval()
51
- # Load audio and compute mel spectrogram
52
- y, sr = librosa.load(mp3_file, sr=None)
53
- S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
54
- log_S = librosa.power_to_db(S, ref=np.max)
55
- # Segment the spectrogram
56
- num_segments = log_S.shape[1] // segment_length
57
- if num_segments == 0:
58
- segments = [log_S]
59
- else:
60
- segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
61
-
62
- predictions = []
63
- # Process each segment individually
64
- for seg in segments:
65
- seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
66
- seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
67
- # Create a tensor with batch size 1 and sequence length 1
68
- seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, 3, 224, 224)
69
- output = model(seg_tensor)
70
- pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
71
- predicted_bird = class_names[str(pred)] # Convert pred to string to match JSON keys
72
- predictions.append(predicted_bird)
73
- return predictions
74
-
75
  # Initialize the model
76
  resnet = models.resnet50(weights='IMAGENET1K_V2')
77
  num_features = resnet.fc.in_features
78
  resnet.fc = nn.Identity()
79
- num_classes = len(class_names) # Should be 114
80
  model = BirdCallRNN(resnet, num_features, num_classes)
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
  model.to(device)
83
  model.load_state_dict(torch.load('model_weights.pth', map_location=device))
84
  model.eval()
85
 
86
- # Prediction function for Gradio
87
- def predict_bird(file_path):
88
- predictions = infer_birdcall(model, file_path, segment_length=500, device=str(device))
89
- return ", ".join(predictions) # Join predictions into a single string
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- # Launch Gradio interface
92
  interface = gr.Interface(
93
  fn=predict_bird,
94
- inputs=gr.File(label="Upload MP3 file", file_types=['.mp3']),
95
- outputs=gr.Textbox(label="Predicted Bird Species")
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
  interface.launch()
 
6
  import json
7
  from torchvision import models
8
  import librosa
9
+ import matplotlib.pyplot as plt
10
+ from io import BytesIO
11
+ import PIL.Image
12
 
13
+ # Define the BirdCallRNN model class
14
  class BirdCallRNN(nn.Module):
15
  def __init__(self, resnet, num_features, num_classes):
16
  super(BirdCallRNN, self).__init__()
 
24
  features = self.resnet(x)
25
  features = features.view(batch, seq_len, -1)
26
  rnn_out, _ = self.rnn(features)
27
+ output = self.fc(rnn_out[:, -1, :])
28
  return output
29
 
30
+ # Function to plot mel spectrogram
31
+ def plot_spectrogram(log_S, sr):
32
+ fig, ax = plt.subplots(figsize=(10, 4))
33
+ img = librosa.display.specshow(log_S, sr=sr, x_axis='time', y_axis='mel', ax=ax)
34
+ fig.colorbar(img, ax=ax, format='%+2.0f dB')
35
+ ax.set_title('Mel Spectrogram')
36
+ buf = BytesIO()
37
+ plt.savefig(buf, format='png')
38
+ buf.seek(0)
39
+ img = PIL.Image.open(buf)
40
+ plt.close(fig)
41
+ return img
 
 
 
 
42
 
43
+ # Load class mapping
44
  with open('class_mapping.json', 'r') as f:
45
  class_names = json.load(f)
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Initialize the model
48
  resnet = models.resnet50(weights='IMAGENET1K_V2')
49
  num_features = resnet.fc.in_features
50
  resnet.fc = nn.Identity()
51
+ num_classes = len(class_names)
52
  model = BirdCallRNN(resnet, num_features, num_classes)
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  model.to(device)
55
  model.load_state_dict(torch.load('model_weights.pth', map_location=device))
56
  model.eval()
57
 
58
+ # Prediction function
59
+ def predict_bird(audio):
60
+ # Load audio file
61
+ y, sr = librosa.load(audio, sr=None)
62
+ S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
63
+ log_S = librosa.power_to_db(S, ref=np.max)
64
+
65
+ # Generate spectrogram image
66
+ spectrogram_img = plot_spectrogram(log_S, sr)
67
+
68
+ # Segment audio and predict
69
+ predictions = []
70
+ segment_length = 500
71
+ num_segments = log_S.shape[1] // segment_length
72
+ segments = [log_S] if num_segments == 0 else [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
73
+ for seg in segments:
74
+ seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
75
+ seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
76
+ seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)
77
+ output = model(seg_tensor)
78
+ probs = torch.softmax(output, dim=1)
79
+ confidence, pred = torch.max(probs, dim=1)
80
+ pred = pred.cpu().numpy()[0]
81
+ confidence = confidence.cpu().numpy()[0]
82
+ predicted_bird = class_names[str(pred)]
83
+ predictions.append((predicted_bird, confidence))
84
+
85
+ # Format predictions as HTML
86
+ predictions_html = "<ol>"
87
+ for i, (bird, conf) in enumerate(predictions, 1):
88
+ predictions_html += f"<li>{bird} (Confidence: {conf*100:.1f}%)</li>"
89
+ predictions_html += "</ol>"
90
+
91
+ return spectrogram_img, predictions_html
92
 
93
+ # Gradio interface
94
  interface = gr.Interface(
95
  fn=predict_bird,
96
+ inputs=gr.Audio(label="Upload MP3 file", type="filepath"),
97
+ outputs=[
98
+ gr.Image(label="Mel Spectrogram"),
99
+ gr.HTML(label="Predicted Bird Species")
100
+ ],
101
+ description="""
102
+ <h3>Bird Species</h3>
103
+ <img src='1.jpeg' width='300'>
104
+ <h3>Bird Description</h3>
105
+ <img src='2.jpeg' width='300'>
106
+ <h3>Bird Origins</h3>
107
+ <img src='3.jpeg' width='300'>
108
+ """
109
  )
110
  interface.launch()