Respair commited on
Commit
8248296
·
verified ·
1 Parent(s): ecfb9a9

Create meldataset.py

Browse files
Hiformer_Checkpoint_Libri_24khz/meldataset.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ import numpy as np
7
+ from librosa.util import normalize
8
+ from scipy.io.wavfile import read
9
+ import torchaudio
10
+ import librosa
11
+ from librosa.filters import mel as librosa_mel_fn
12
+
13
+ MAX_WAV_VALUE = 32768.0
14
+ import soundfile as sf
15
+
16
+
17
+ def normalize_audio(wav):
18
+ return wav / torch.max(torch.abs(torch.from_numpy(wav))) # Correct peak normalization
19
+
20
+ def load_wav_librosa(full_path):
21
+ data, sampling_rate = librosa.load(full_path, sr=24_000)
22
+ return data, sampling_rate
23
+
24
+
25
+
26
+ def load_wav_scipy(full_path):
27
+ sampling_rate, data = read(full_path)
28
+ return data, sampling_rate
29
+
30
+ def load_wav(full_path):
31
+ try:
32
+ return load_wav_scipy(full_path)
33
+ except:
34
+ # print('using librosa...')
35
+ return load_wav_librosa(full_path)
36
+
37
+
38
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
39
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
40
+
41
+
42
+ def dynamic_range_decompression(x, C=1):
43
+ return np.exp(x) / C
44
+
45
+
46
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
47
+ return torch.log(torch.clamp(x, min=clip_val) * C)
48
+
49
+
50
+ def dynamic_range_decompression_torch(x, C=1):
51
+ return torch.exp(x) / C
52
+
53
+
54
+ def spectral_normalize_torch(magnitudes):
55
+ output = dynamic_range_compression_torch(magnitudes)
56
+ return output
57
+
58
+
59
+ def spectral_de_normalize_torch(magnitudes):
60
+ output = dynamic_range_decompression_torch(magnitudes)
61
+ return output
62
+
63
+
64
+ mel_basis = {}
65
+ hann_window = {}
66
+
67
+
68
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
69
+
70
+
71
+ # y = torch.clamp(y, -1, 1)
72
+
73
+
74
+ # if torch.min(y) < -1.:
75
+ # # y = torch.clamp(y, min = -1)
76
+ # # print('min value is ', torch.min(y))
77
+ # if torch.max(y) > 1.:
78
+ # y = torch.clamp(y, max = -1)
79
+ # print('max value is ', torch.max(y))
80
+
81
+ global mel_basis, hann_window
82
+ if fmax not in mel_basis:
83
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
84
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
85
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
86
+
87
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
88
+ y = y.squeeze(1)
89
+
90
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
91
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
92
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
93
+ spec = torch.view_as_real(spec)
94
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
95
+
96
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
97
+ spec = spectral_normalize_torch(spec)
98
+
99
+ return spec
100
+
101
+
102
+
103
+ to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
104
+
105
+
106
+ # to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
107
+
108
+ mean, std = -4, 4
109
+ 5
110
+ def preproces(wave,to_mel=to_mel, device='cpu'):
111
+
112
+ to_mel = to_mel.to(device)
113
+ # wave_tensor = torch.from_numpy(wave).float()
114
+ mel_tensor = to_mel(wave)
115
+ mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std
116
+ return mel_tensor
117
+
118
+
119
+ def get_dataset_filelist(a):
120
+ with open(a.input_training_file, 'r', encoding='utf-8') as fi:
121
+ training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ('' if '' not in x else ''))
122
+ for x in fi.read().split('\n') if len(x) > 0]
123
+
124
+ with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
125
+ validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ('' if '' not in x else ''))
126
+ for x in fi.read().split('\n') if len(x) > 0]
127
+ return training_files, validation_files
128
+
129
+
130
+ class MelDataset(torch.utils.data.Dataset):
131
+ def __init__(self, training_files, segment_size, n_fft, num_mels,
132
+ hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
133
+ device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
134
+ self.audio_files = training_files
135
+ random.seed(1234)
136
+ if shuffle:
137
+ random.shuffle(self.audio_files)
138
+ self.segment_size = segment_size
139
+ self.sampling_rate = sampling_rate
140
+ self.split = split
141
+ self.n_fft = n_fft
142
+ self.num_mels = num_mels
143
+ self.hop_size = hop_size
144
+ self.win_size = win_size
145
+ self.fmin = fmin
146
+ self.fmax = fmax
147
+ self.fmax_loss = fmax_loss
148
+ self.cached_wav = None
149
+ self.n_cache_reuse = n_cache_reuse
150
+ self._cache_ref_count = 0
151
+ self.device = device
152
+ self.fine_tuning = fine_tuning
153
+ self.base_mels_path = base_mels_path
154
+
155
+ def __getitem__(self, index):
156
+ filename = self.audio_files[index]
157
+ if self._cache_ref_count == 0:
158
+ audio, sampling_rate = load_wav(filename)
159
+ audio = audio / MAX_WAV_VALUE
160
+ if not self.fine_tuning:
161
+ audio = normalize(audio) * 0.95
162
+ self.cached_wav = audio
163
+ if sampling_rate != self.sampling_rate:
164
+ audio = librosa.resample(audio, orig_sr= sampling_rate, target_sr= self.sampling_rate)
165
+ # raise ValueError("{} SR doesn't match target {} SR, {}".format(
166
+ # sampling_rate, self.sampling_rate, filename))
167
+ self._cache_ref_count = self.n_cache_reuse
168
+ else:
169
+ audio = self.cached_wav
170
+ self._cache_ref_count -= 1
171
+
172
+ audio = torch.FloatTensor(audio)
173
+ audio = audio.unsqueeze(0)
174
+
175
+ if not self.fine_tuning:
176
+ if self.split:
177
+ if audio.size(1) >= self.segment_size:
178
+ max_audio_start = audio.size(1) - self.segment_size
179
+ audio_start = random.randint(0, max_audio_start)
180
+ audio = audio[:, audio_start:audio_start+self.segment_size]
181
+ else:
182
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
183
+
184
+ # mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
185
+ # self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
186
+ # center=False)
187
+
188
+ mel = preproces(audio)
189
+ else:
190
+ mel = np.load(
191
+ os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
192
+ mel = torch.from_numpy(mel)
193
+
194
+ if len(mel.shape) < 3:
195
+ mel = mel.unsqueeze(0)
196
+
197
+ if self.split:
198
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
199
+
200
+ if audio.size(1) >= self.segment_size:
201
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
202
+ mel = mel[:, :, mel_start:mel_start + frames_per_seg]
203
+ audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
204
+ else:
205
+ mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
206
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
207
+
208
+ mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
209
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
210
+ center=False)
211
+
212
+ # mel_loss = mel_spectrogram(audio)
213
+ if mel.shape[-1] != mel_loss.shape[-1]:
214
+ mel = mel[..., :mel_loss.shape[-1]]
215
+
216
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
217
+
218
+ def __len__(self):
219
+ return len(self.audio_files)