All model to have device changed
Browse files- cli/SparkTTS.py +5 -0
cli/SparkTTS.py
CHANGED
@@ -49,6 +49,11 @@ class SparkTTS:
|
|
49 |
self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
|
50 |
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
|
51 |
self.model.to(self.device)
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
def process_prompt(
|
54 |
self,
|
|
|
49 |
self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
|
50 |
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
|
51 |
self.model.to(self.device)
|
52 |
+
|
53 |
+
def to(self, device: torch.device):
|
54 |
+
self.device = device
|
55 |
+
self.model.to(self.device)
|
56 |
+
self.audio_tokenizer.to(self.device)
|
57 |
|
58 |
def process_prompt(
|
59 |
self,
|