Antoni Bigata commited on
Commit
1982e66
·
1 Parent(s): d9970e1

requirements

Browse files
Files changed (1) hide show
  1. WavLM_modules.py +1 -3
WavLM_modules.py CHANGED
@@ -450,9 +450,7 @@ class MultiheadAttention(nn.Module):
450
  relative_position_bucket = self._relative_positions_bucket(
451
  relative_position, bidirectional=True
452
  )
453
- # relative_position_bucket = relative_position_bucket.to(
454
- # self.relative_attention_bias.weight.device
455
- # )
456
  values = self.relative_attention_bias(relative_position_bucket)
457
  values = values.permute([2, 0, 1])
458
  return values
 
450
  relative_position_bucket = self._relative_positions_bucket(
451
  relative_position, bidirectional=True
452
  )
453
+ relative_position_bucket.cuda()
 
 
454
  values = self.relative_attention_bias(relative_position_bucket)
455
  values = values.permute([2, 0, 1])
456
  return values