import torch | |
import torch.nn as nn | |
from torch_audiomentations import Compose, Gain, PolarityInversion, AddColoredNoise, PitchShift, PeakNormalization, PitchShift | |
# TODO: add a reference to the source this code was adapted from.
class AUG(nn.Module):
    """Random waveform augmentation built on ``torch_audiomentations.Compose``.

    The composed pipeline contains these transforms, each gated by its own
    probability:

    * ``AddColoredNoise`` (probability ``prob``)
    * ``PitchShift`` by -1 to +1 semitones (probability ``prob``)
    * ``PeakNormalization`` (probability ``peak_norm_prob``)
    * ``Gain`` between -6 dB and +6 dB (probability ``prob``)

    Args:
        prob: Probability applied to the noise, pitch-shift and gain
            transforms.
        sample_rate: Sample rate in Hz of the audio given to ``forward``.
            It is stored once and used both to configure ``PitchShift`` and
            when calling the composed transform, so the two cannot drift
            apart. Defaults to 16000, the value previously hard-coded.
        peak_norm_prob: Probability of applying ``PeakNormalization``.
            Defaults to 0.1, the value previously hard-coded.
    """

    def __init__(self, prob=0.3, sample_rate=16000, peak_norm_prob=0.1):
        super().__init__()
        # Single source of truth for the sample rate; read again in forward().
        self.sample_rate = sample_rate
        self.aug = Compose(
            transforms=[
                AddColoredNoise(p=prob),
                PitchShift(
                    sample_rate=sample_rate,
                    min_transpose_semitones=-1,
                    max_transpose_semitones=1,
                    p=prob,
                ),
                PeakNormalization(p=peak_norm_prob),
                Gain(min_gain_in_db=-6, max_gain_in_db=6, p=prob),
            ]
        )

    def forward(self, x):
        """Run the augmentation pipeline on ``x`` and return its output.

        NOTE(review): torch_audiomentations documents its input as a
        (batch, channels, samples) tensor; confirm callers pass that shape.
        """
        return self.aug(x, sample_rate=self.sample_rate)