Antoni Bigata commited on
Commit
9ebfa23
·
1 Parent(s): 1982e66

requirements

Browse files
Files changed (1) hide show
  1. WavLM_modules.py +1 -2
WavLM_modules.py CHANGED
@@ -450,8 +450,7 @@ class MultiheadAttention(nn.Module):
450
  relative_position_bucket = self._relative_positions_bucket(
451
  relative_position, bidirectional=True
452
  )
453
- relative_position_bucket.cuda()
454
- values = self.relative_attention_bias(relative_position_bucket)
455
  values = values.permute([2, 0, 1])
456
  return values
457
 
 
450
  relative_position_bucket = self._relative_positions_bucket(
451
  relative_position, bidirectional=True
452
  )
453
+ values = self.relative_attention_bias(relative_position_bucket.cuda())
 
454
  values = values.permute([2, 0, 1])
455
  return values
456