thunnai commited on
Commit
470e483
·
1 Parent(s): f4176b0

All model to have device changed

Browse files
Files changed (1) hide show
  1. cli/SparkTTS.py +5 -0
cli/SparkTTS.py CHANGED
@@ -49,6 +49,11 @@ class SparkTTS:
49
  self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
50
  self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
51
  self.model.to(self.device)
 
 
 
 
 
52
 
53
  def process_prompt(
54
  self,
 
49
  self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
50
  self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
51
  self.model.to(self.device)
52
+
53
+ def to(self, device: torch.device):
54
+ self.device = device
55
+ self.model.to(self.device)
56
+ self.audio_tokenizer.to(self.device)
57
 
58
  def process_prompt(
59
  self,