Spaces: Bradarr · Running on Zero

Bradarr committed (verified) · Commit 6c77205 · 1 Parent: 107a152

Upload app2.py

Files changed (1): app2.py (+194 −0)
app2.py ADDED
@@ -0,0 +1,194 @@
+ import logging
+ import os
+
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+ import torchaudio
+ import whisper
+ from generator import Segment, load_csm_1b
+ from huggingface_hub import hf_hub_download, login
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+ from watermarking import watermark
+
+ # Configure logging
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
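+
+ # Configuration expected from the environment (assumed to be set as Space secrets):
+ #   HF_TOKEN      - Hugging Face token (needed for the gated CSM and Gemma checkpoints)
+ #   WATERMARK_KEY - space-separated integers forming the audio watermark key
+ #   GPU_TIMEOUT   - optional ZeroGPU allocation window in seconds (default 120)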
+
+ # --- Authentication and Configuration --- (Moved BEFORE model loading)
+ try:
+     api_key = os.getenv("HF_TOKEN")
+     if not api_key:
+         raise ValueError("HF_TOKEN not found in environment variables.")
+     login(token=api_key)
+
+     # Check the variable before splitting so a missing key raises a clear
+     # ValueError instead of an AttributeError on None.
+     watermark_key = os.getenv("WATERMARK_KEY")
+     if not watermark_key:
+         raise ValueError("WATERMARK_KEY not found or invalid in environment variables.")
+     CSM_1B_HF_WATERMARK = list(map(int, watermark_key.split(" ")))
+
+     gpu_timeout = int(os.getenv("GPU_TIMEOUT", 120))
+ except (ValueError, TypeError) as e:
+     logging.error(f"Configuration error: {e}")
+     raise
+
+ SPACE_INTRO_TEXT = """
+ # Sesame CSM 1B - Conversational Demo
+
+ This demo lets you hold a conversation with Sesame CSM 1B, using Whisper for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.
+
+ *Disclaimer: This demo relies on several large models. Expect longer processing times and potential resource limitations.*
+ """
+
+ # --- Model Loading --- (Moved INSIDE infer function)
+
+ # --- Constants --- (Constants can stay outside)
+ SPEAKER_ID = 0
+ MAX_CONTEXT_SEGMENTS = 1
+ MAX_GEMMA_LENGTH = 150
+
+ # --- Global Conversation History ---
+ conversation_history = []
+
+ # --- Helper Functions ---
+
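+ # The helpers below take their models as explicit parameters: under ZeroGPU
+ # everything is loaded inside the @spaces.GPU-decorated call, so no model can
+ # be initialised at module import time.
+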
+ def transcribe_audio(audio_path: str, whisper_model) -> str:  # Pass whisper_model
+     try:
+         audio = whisper.load_audio(audio_path)
+         audio = whisper.pad_or_trim(audio)
+         result = whisper_model.transcribe(audio)
+         return result["text"]
+     except Exception as e:
+         logging.error(f"Whisper transcription error: {e}")
+         return "Error: Could not transcribe audio."
+
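+ # Note: pad_or_trim above fits the clip to Whisper's fixed 30-second window,
+ # so longer recordings are truncated; acceptable for short conversational turns.
+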
+ def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:  # Pass model and tokenizer
+     try:
+         # Gemma 3 chat template format; add_generation_prompt=True appends the
+         # assistant turn marker so the model answers rather than continues.
+         messages = [{"role": "user", "content": text}]
+         input_ids = tokenizer_gemma.apply_chat_template(
+             messages, add_generation_prompt=True, return_tensors="pt"
+         ).to(device)  # renamed from `input` to avoid shadowing the builtin
+         generation_config = GenerationConfig(
+             max_new_tokens=MAX_GEMMA_LENGTH,
+             early_stopping=True,
+         )
+
+         generated_output = model_gemma.generate(input_ids, generation_config=generation_config)
+         # Decode only the newly generated tokens; decoding the whole sequence
+         # would echo the prompt, and the TTS step would speak it back.
+         return tokenizer_gemma.decode(generated_output[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
+         # Earlier plain-prompt variant, kept for reference:
+         # input_text = "Respond to the user's prompt: " + text
+         # inputs = tokenizer_gemma(input_text, return_tensors="pt").to(device)
+         # generated_output = model_gemma.generate(**inputs, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
+         # return tokenizer_gemma.decode(generated_output[0], skip_special_tokens=True)
+     except Exception as e:
+         logging.error(f"Gemma response generation error: {e}")
+         return "I'm sorry, I encountered an error generating a response."
+
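+ # For reference, the chat template above expands a single user turn to roughly
+ # "<start_of_turn>user\n{text}<end_of_turn>\n<start_of_turn>model\n" (Gemma's
+ # turn markers; the exact string comes from the tokenizer's template).
+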
+ def load_audio(audio_path: str, generator) -> torch.Tensor:  # Pass generator
+     try:
+         audio_tensor, sample_rate = torchaudio.load(audio_path)
+         audio_tensor = audio_tensor.mean(dim=0)  # downmix to mono
+         if sample_rate != generator.sample_rate:
+             audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate)
+         return audio_tensor
+     except Exception as e:
+         logging.error(f"Audio loading error: {e}")
+         raise gr.Error("Could not load or process the audio file.") from e
+
+ def clear_history():
+     global conversation_history
+     conversation_history = []
+     logging.info("Conversation history cleared.")
+     return "Conversation history cleared."
+
+ # --- Main Inference Function ---
+
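+ # @spaces.GPU requests a ZeroGPU slot only for the duration of this call, so
+ # models are loaded inside infer once the device is known. hf_hub_download
+ # caches weights on disk, so repeat calls mostly pay the load-to-GPU cost.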
+ @spaces.GPU(duration=gpu_timeout)  # Decorator FIRST
+ def infer(user_audio) -> tuple[int, np.ndarray]:
+     # --- CUDA Availability Check (INSIDE infer) ---
+     if torch.cuda.is_available():
+         print(f"CUDA is available! Device count: {torch.cuda.device_count()}")
+         print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
+         print(f"CUDA version: {torch.version.cuda}")
+         device = "cuda"
+     else:
+         print("CUDA is NOT available. Using CPU.")  # Use CPU, don't raise
+         device = "cpu"
+
+     try:
+         # Validate the input before spending time loading models.
+         if not user_audio:
+             raise ValueError("No audio input received.")
+
+         # --- Model Loading (INSIDE infer, after device is set) ---
+         model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
+         generator = load_csm_1b(model_path, device)
+         logging.info("Sesame CSM 1B loaded successfully.")
+
+         whisper_model = whisper.load_model("small.en", device=device)
+         logging.info("Whisper model loaded successfully.")
+
+         tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
+         model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it").to(device)
+         logging.info("Gemma 3 1B it model loaded successfully.")
+
+         return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device)  # Pass all models
+     except Exception as e:
+         logging.exception(f"Inference error: {e}")
+         raise gr.Error(f"An error occurred during processing: {e}")
+
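+ # _infer runs the per-turn pipeline: transcribe -> generate reply -> synthesize
+ # speech -> update context -> watermark. Kept separate so infer stays a thin
+ # device-and-model setup wrapper.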
+ def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device) -> tuple[int, np.ndarray]:
+     global conversation_history
+
+     try:
+         user_text = transcribe_audio(user_audio, whisper_model)  # Pass whisper_model
+         logging.info(f"User: {user_text}")
+
+         ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device)  # Pass model and tokenizer
+         logging.info(f"AI: {ai_text}")
+
+         try:
+             ai_audio = generator.generate(
+                 text=ai_text,
+                 speaker=SPEAKER_ID,
+                 context=conversation_history,
+                 max_audio_length_ms=10_000,
+             )
+             logging.info("Audio generated successfully.")
+         except Exception as e:
+             logging.error(f"Sesame response generation error: {e}")
+             raise gr.Error(f"Sesame response generation error: {e}")
+
+         # Speaker 1 is the user; SPEAKER_ID (0) is the AI voice.
+         user_segment = Segment(speaker=1, text=user_text, audio=load_audio(user_audio, generator))  # Pass generator
+         ai_segment = Segment(speaker=SPEAKER_ID, text=ai_text, audio=ai_audio)
+         conversation_history.append(user_segment)
+         conversation_history.append(ai_segment)
+
+         # Two segments are appended per turn, so trim in a loop; a single pop
+         # would let the history creep past MAX_CONTEXT_SEGMENTS.
+         while len(conversation_history) > MAX_CONTEXT_SEGMENTS:
+             conversation_history.pop(0)
+
+         audio_tensor, wm_sample_rate = watermark(
+             generator._watermarker, ai_audio, generator.sample_rate, CSM_1B_HF_WATERMARK
+         )
+         audio_tensor = torchaudio.functional.resample(
+             audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
+         )
+
+         # Convert float audio in [-1, 1] to int16 PCM for Gradio.
+         ai_audio_array = (audio_tensor * 32768).to(torch.int16).cpu().numpy()
+         return generator.sample_rate, ai_audio_array
+
+     except Exception as e:
+         logging.exception(f"Error in _infer: {e}")
+         raise gr.Error(f"An error occurred during processing: {e}")
+
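+ # Note: conversation_history is module-global, so concurrent users would share
+ # one conversation; per-session state would need gr.State instead.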
+ # --- Gradio Interface ---
+
+ with gr.Blocks() as app:
+     gr.Markdown(SPACE_INTRO_TEXT)
+     audio_input = gr.Audio(label="Your Input", type="filepath")
+     audio_output = gr.Audio(label="AI Response")
+     clear_button = gr.Button("Clear Conversation History")
+     status_display = gr.Textbox(label="Status", visible=False)
+
+     btn = gr.Button("Generate Response")
+     btn.click(infer, inputs=[audio_input], outputs=[audio_output])
+     clear_button.click(clear_history, outputs=[status_display])
+
+ app.launch(ssr_mode=False, share=True)