xls-r-et-V-3 / augmentation.py
vasilis's picture
Training in progress, step 500
0eb8a8d
raw
history blame contribute delete
687 Bytes
import torch
import torch.nn as nn
from torch_audiomentations import Compose, Gain, PolarityInversion, AddColoredNoise, PitchShift, PeakNormalization, PitchShift
# TODO: add a reference to the source this code was copied/adapted from.
class AUG(nn.Module):
    """Randomly augment audio with a torch_audiomentations ``Compose`` pipeline.

    The pipeline applies, in this order: coloured noise, a small pitch shift,
    peak normalization and a random gain. Each transform is given its own
    application probability ``p``.

    Args:
        prob: Probability ``p`` used for AddColoredNoise, PitchShift and Gain.
            PeakNormalization always uses ``p=0.1``.
        sample_rate: Sample rate (Hz) of the incoming audio. It is used both
            to construct PitchShift and when invoking the pipeline, so the two
            always agree. Defaults to 16000, the previously hard-coded value.
    """

    def __init__(self, prob=0.3, sample_rate=16000):
        super().__init__()
        # Single source of truth for the sample rate; forward() reuses it.
        self.sample_rate = sample_rate
        self.aug = Compose(
            transforms=[
                AddColoredNoise(p=prob),
                # Transpose by at most one semitone in either direction.
                PitchShift(
                    sample_rate=sample_rate,
                    min_transpose_semitones=-1,
                    max_transpose_semitones=1,
                    p=prob,
                ),
                PeakNormalization(p=0.1),
                # Random gain in the range [-6 dB, +6 dB].
                Gain(min_gain_in_db=-6, max_gain_in_db=6, p=prob),
            ]
        )

    def forward(self, x):
        """Run the augmentation pipeline on ``x``.

        Args:
            x: Audio tensor. torch_audiomentations presumably expects shape
                (batch, channels, samples). TODO: confirm against callers.

        Returns:
            The output of the ``Compose`` pipeline. Depending on the installed
            torch_audiomentations version / ``output_type`` setting, this may
            be a tensor or a dict-like object. Verify which one callers expect.
        """
        return self.aug(x, sample_rate=self.sample_rate)