Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -18,14 +18,14 @@ from duckduckgo_search import DDGS
 import re
 import time
 from huggingface_hub import HfApi
-from spaces import GPU
+from spaces import GPU  # Directly import GPU from spaces - Crucial for HF Spaces

 # --- Constants and Configuration ---
 MODEL_ID = "nvidia/Llama-3.1-Nemotron-8B-UltraLong-4M-Instruct"
 MAX_GPU_MEMORY = "40GiB"  # A100 memory allocation

 # --- Model Loading ---
-@GPU(memory=40)
+@GPU(memory=40)  # ****** THIS DECORATOR IS ESSENTIAL FOR SPACES STARTUP ******
 def load_model():
     """Load the LLM model optimized for A100 GPU using 4-bit quantization."""
     print(f"Attempting to load model: {MODEL_ID} with 4-bit quantization")
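The `from spaces import GPU` import and the `@GPU(memory=40)` decorator are what let this Space run on ZeroGPU hardware: GPU access is only guaranteed inside decorated functions. A minimal sketch of the pattern, assuming the Hugging Face `spaces` package and mirroring the `memory=40` argument used in this commit (the helper function below is hypothetical, not part of app.py):

from spaces import GPU

@GPU(memory=40)  # mirrors this repo's usage; a GPU is attached only while this runs
def gpu_smoke_test():
    import torch
    # Inside a decorated function, CUDA should be visible on ZeroGPU hardware.
    return torch.cuda.is_available()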
@@ -39,39 +39,23 @@ def load_model():
         )

         # Device map will handle placing layers, relying on accelerate
-        # No need to explicitly set max_memory when using device_map="auto" typically
         pipe = pipeline(
             "text-generation",
             model=MODEL_ID,
-            # Note: torch_dtype is sometimes ignored when quantization_config is used,
-            # but specifying compute_dtype in BitsAndBytesConfig is key.
-            # Keep torch_dtype=torch.bfloat16 here for consistency if needed by other parts.
             torch_dtype=torch.bfloat16,
-            device_map="auto",
+            device_map="auto",  # Let accelerate handle layer placement
             model_kwargs={
                 "quantization_config": quantization_config,
                 "use_cache": True,
-                # "trust_remote_code=True" # Add if model requires it (check model card)
             }
         )
         print(f"Model {MODEL_ID} loaded successfully on device: {pipe.device} (using 4-bit quantization)")
         return pipe
     except Exception as e:
         print(f"FATAL Error loading model '{MODEL_ID}' (check memory/config): {e}")
-        # Raise the error to ensure it's visible in Spaces logs
         raise e

-# --- REST OF THE CODE REMAINS THE SAME ---
-# (search_person, create_synthetic_profile, extract_text_from_search_results,
-# parse_llm_output, generate_enhanced_persona, generate_system_prompt_with_llm,
-# generate_response, PersonaChat class, create_interface function, __main__ block)
-# ... include the rest of the Python code from the previous correct version here ...
-
-# Make sure the rest of your app.py file follows this modified load_model function.
-# Keep all other functions and the Gradio interface definition as they were.
-
 # --- Web Search ---
-# (Keep search_person, create_synthetic_profile, extract_text_from_search_results as before)
 def search_person(name, context=""):
     """Search for information about a person using DuckDuckGo."""
     print(f"Searching for: {name} with context: {context}")
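The `quantization_config` passed through `model_kwargs` is built earlier in `load_model()` and is outside this hunk. A minimal sketch of a typical 4-bit setup consistent with the bfloat16 compute dtype used above, assuming the standard `transformers` BitsAndBytesConfig API (the exact options in app.py are not shown in the diff):

import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit weights to fit the 8B model in memory
    bnb_4bit_compute_dtype=torch.bfloat16,  # matches torch_dtype above
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)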
@@ -128,7 +112,7 @@ def create_synthetic_profile(name, context):
             profile["body"] += f"Based on being in {grade}th grade, {name} is likely around {age} years old. "
             profile["body"] += f"Typical interests for this age might include friends, hobbies, school subjects, and developing independence. "
         except ValueError:
-
+            profile["body"] += f"The grade mentioned ('{grade_match.group(1)}') could not be parsed to estimate age. "
     profile["body"] += "Since no public information was found, this profile is based solely on the provided context."
     return [profile]

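For context, `grade`, `age`, and `grade_match` come from earlier lines of `create_synthetic_profile` that are outside this hunk; a hypothetical sketch of that surrounding logic (the regex and the grade-to-age offset are illustrative assumptions, not taken from the diff):

import re

def estimate_age_sentence(name, context):
    body = ""
    grade_match = re.search(r"(\d+)(?:st|nd|rd|th)?\s*grade", context, re.IGNORECASE)  # assumed pattern
    if grade_match:
        try:
            grade = int(grade_match.group(1))
            age = grade + 5  # rough offset, purely illustrative
            body += f"Based on being in {grade}th grade, {name} is likely around {age} years old. "
        except ValueError:
            body += f"The grade mentioned ('{grade_match.group(1)}') could not be parsed to estimate age. "
    return body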
@@ -166,14 +150,15 @@ def parse_llm_output(full_output, input_prompt_list):
     if isinstance(full_output, list) and len(full_output) > 0:
         if isinstance(full_output[0], dict) and "generated_text" in full_output[0]:
             generated_text = full_output[0]["generated_text"]
-        else:
-
-
+        else:
+            return str(full_output)
+    elif isinstance(full_output, str):
+        generated_text = full_output
+    else:
+        return str(full_output)

     last_input_content = ""
     if isinstance(input_prompt_list, list) and input_prompt_list:
-        # Find the last message with 'user' or 'system' role potentially?
-        # Let's stick to finding the last message content for simplicity
         last_input_content = input_prompt_list[-1].get("content", "")

     if last_input_content:
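The new branches make `parse_llm_output` tolerant of the two shapes the text-generation pipeline is commonly seen returning, a list of dicts or a plain string. An illustrative call (the literal strings below are made up):

sample_prompt = [{"role": "user", "content": "Say hi"}]
list_shaped = [{"generated_text": "Say hi Hello there!"}]  # usual list-of-dict output
str_shaped = "Say hi Hello there!"                         # plain-string fallback

print(parse_llm_output(list_shaped, sample_prompt))
print(parse_llm_output(str_shaped, sample_prompt))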
@@ -181,52 +166,41 @@ def parse_llm_output(full_output, input_prompt_list):
         if last_occurrence_index != -1:
             potential_response = generated_text[last_occurrence_index + len(last_input_content):].strip()
             if potential_response:
-                # Basic cleanup
                 potential_response = re.sub(r'^<\/?s?>', '', potential_response).strip()
                 potential_response = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', potential_response).strip()
-                # Check if the response is just whitespace or seems empty after cleanup
                 if potential_response:
                     return potential_response

-    # Fallback or if model correctly outputted only the response
     cleaned_text = generated_text
     if isinstance(input_prompt_list, list) and input_prompt_list:
-
-
-
-        pass # Let's rely more on the end-stripping heuristic above
+        first_prompt_content = input_prompt_list[0].get("content", "")
+        if first_prompt_content and cleaned_text.startswith(first_prompt_content):
+            pass  # Rely on end-stripping heuristic

-    # General cleanup
     cleaned_text = re.sub(r'^<\/?s?>', '', cleaned_text).strip()
     cleaned_text = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', cleaned_text).strip()

-    # If after all this, it's empty, maybe return original generated_text?
-    # Or log a warning and return the cleaned version.
     if not cleaned_text and generated_text:
-
-
+        print("Warning: Parsing resulted in empty string, returning original generation.")
+        return generated_text

-    # If input prompt wasn't found, assume the model outputted only the response (ideal case)
-    # or the whole thing (fallback case). The cleaning helps for the latter.
     if last_input_content and last_occurrence_index == -1:
         print("Warning: Could not find last input prompt in LLM output. Returning cleaned full output.")

     return cleaned_text

-@GPU(memory=40)
-def generate_enhanced_persona(model, name, bio_text, context=""):
+@GPU(memory=40)  # Decorator needed for Spaces resource allocation during calls
+def generate_enhanced_persona(name, bio_text, context=""):
     """Use the LLM to enhance the persona profile."""
+    pipe = load_model()  # Load model within GPU context
     print(f"Generating enhanced persona for {name}...")
-    if model is None: raise ValueError("Model is not loaded.")
-
     enhancement_prompt = [
         {"role": "system", "content": """You are an expert AI character developer. Your task is to synthesize information into a detailed and coherent character profile. Focus on personality, potential interests, speaking style, and mannerisms based ONLY on the provided text. If the text indicates the character is a child, ensure the profile reflects age-appropriate traits. Output ONLY the enhanced character profile description. Do not include conversational introductions, explanations, apologies for limited info, or markdown formatting like headers (e.g., ### Personality). Start directly with the profile text."""},
         {"role": "user", "content": f"""Synthesize the following information about '{name}' into a character profile. Context: {context} Information Found:\n{bio_text}\n\nCreate the profile based *only* on the text above."""}
     ]
-
     try:
         with torch.amp.autocast('cuda', dtype=torch.bfloat16):
-            outputs =
+            outputs = pipe(enhancement_prompt, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9)
             parsed_output = parse_llm_output(outputs, enhancement_prompt)
             print("Enhanced persona generated.")
             return parsed_output if parsed_output else f"Could not generate profile based on:\n{bio_text}"
@@ -235,22 +209,19 @@ def generate_enhanced_persona(model, name, bio_text, context=""):
         print(error_msg)
         return f"Error enhancing profile: {str(e)}\n\nUsing basic info:\n{bio_text}"

-@GPU(memory=40)
-def generate_system_prompt_with_llm(model, name, enhanced_profile, context=""):
+@GPU(memory=40)  # Decorator needed for Spaces resource allocation during calls
+def generate_system_prompt_with_llm(name, enhanced_profile, context=""):
     """Generate an optimized system prompt for the persona."""
+    pipe = load_model()
     print(f"Generating system prompt for {name}...")
-    if model is None: raise ValueError("Model is not loaded.")
-
     fallback_prompt = f"""You are simulating the character '{name}'. Act and respond according to this profile:\n{enhanced_profile}\nAdditional context for the simulation: {context}\n---\nMaintain this persona consistently. Respond naturally based on the profile. Do not mention that you are an AI or a simulation. If asked about details not in the profile, you can be evasive or state you don't know/remember, consistent with the persona."""
-
     prompt = [
         {"role": "system", "content": """You are an expert AI prompt engineer specializing in character simulation. Your task is to create a concise and effective system prompt for an LLM that will simulate a character based on a provided profile. The system prompt should instruct the LLM to embody the character, covering: 1. Core personality, attitude, and speaking style (based on the profile). 2. Key interests or knowledge areas (if mentioned in the profile). 3. How to handle questions outside its knowledge (e.g., be evasive, admit ignorance naturally). 4. Explicitly state it should *not* break character or mention being an AI. 5. Incorporate age-appropriateness if the profile suggests a specific age group. Output ONLY the system prompt itself. Do not add any explanation or introductory text."""},
         {"role": "user", "content": f"""Create a system prompt for an AI to simulate the character '{name}'. Context for simulation: {context} Character Profile:\n{enhanced_profile}\n\nGenerate the system prompt based *only* on the profile and context provided."""}
     ]
-
     try:
         with torch.amp.autocast('cuda', dtype=torch.bfloat16):
-            outputs =
+            outputs = pipe(prompt, max_new_tokens=300, do_sample=True, temperature=0.6)
             parsed_output = parse_llm_output(outputs, prompt)
             print("System prompt generated.")
             return parsed_output if parsed_output else fallback_prompt
@@ -259,25 +230,22 @@ def generate_system_prompt_with_llm(model, name, enhanced_profile, context=""):
         print(error_msg)
         return fallback_prompt

-@GPU(memory=40)
-def generate_response(model, messages):
+@GPU(memory=40)  # Decorator needed for Spaces resource allocation during calls
+def generate_response(messages):
     """Generate a response using the LLM."""
+    pipe = load_model()
     print("Generating response...")
-    if
-
-
+    if not messages:
+        return "Error: No message history provided."
     try:
         with torch.amp.autocast('cuda', dtype=torch.bfloat16):
-
-            outputs = model(
+            outputs = pipe(
                 messages,
                 max_new_tokens=512,
                 do_sample=True,
                 top_p=0.9,
                 temperature=0.7,
-
-                # Check if EOS token is needed for this model/pipeline setup
-                pad_token_id=model.tokenizer.eos_token_id if model.tokenizer.eos_token_id else None
+                pad_token_id=pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id else None
             )
             parsed_output = parse_llm_output(outputs, messages)
             print("Response generated.")
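With this change the generation helpers no longer take a `model` argument; each `@GPU`-decorated function calls `load_model()` itself when it runs. Whether `load_model()` caches the pipeline between calls is not visible in this diff, so repeated loads may be the accepted cost. A sketch of a call under the new signature:

messages = [
    {"role": "system", "content": "You are simulating the character 'Ada'."},
    {"role": "user", "content": "Hello!"},
]
reply = generate_response(messages)  # no pipeline handle is passed in anymore
print(reply)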
@@ -285,50 +253,27 @@ def generate_response(model, messages):
     except Exception as e:
         error_msg = f"Error during response generation: {str(e)}"
         print(error_msg)
-        # Consider if the specific error should be shown to the user
         return f"Sorry, I encountered an error trying to respond."

-
 # --- Persona Chat Class ---
 class PersonaChat:
     def __init__(self):
-        self.model = None
         self.system_prompt = "You are a helpful assistant."
         self.persona_name = "Assistant"
         self.persona_context = ""
         self.messages = []
         self.enhanced_profile = ""
-        self.model_loaded = False
-
-    # No @GPU decorator needed here typically, as it calls functions that ARE decorated
-    def load_model_if_needed(self):
-        """Loads the model if it hasn't been loaded successfully."""
-        if not self.model_loaded or self.model is None: # Check self.model too
-            print("Model not loaded or instance lost. Attempting to load...")
-            # Call the @GPU decorated load_model function
-            self.model = load_model() # This function IS decorated
-            if self.model is None:
-                # load_model now raises error, but double-check here
-                raise RuntimeError("Failed to load the language model. Cannot proceed.")
-            else:
-                self.model_loaded = True
-                print("Model loaded successfully within PersonaChat instance.")
-        # else: print("Model already loaded.") # Reduce log noise

-    # No @GPU decorator needed here typically
     def set_persona(self, name, context=""):
         """Orchestrates persona creation: search, enhance, generate prompt."""
-        # This method calls other functions that have @GPU decorators
         try:
-            self.load_model_if_needed() # Ensures model is ready
-
             self.persona_name = name
             self.persona_context = context
             self.messages = []
             self.enhanced_profile = ""

             status = f"Searching for information about {name}..."
-            yield status, "", "", [{"role": "system", "content": "Initializing persona creation..."}]
+            yield status, "", "", [{"role": "system", "content": "Initializing persona creation..."}]

             search_results = search_person(name, context)
             if isinstance(search_results, str) and search_results.startswith("Error"):
@@ -338,82 +283,58 @@ class PersonaChat:

             bio_text = extract_text_from_search_results(search_results)
             if bio_text.startswith("Could not extract text"):
-
+                yield f"Warning: {bio_text}", "", "", [{"role": "system", "content": bio_text}]

             status = f"Creating enhanced profile for {name}..."
-            yield status, "", bio_text, [{"role": "system", "content": status}]
+            yield status, "", bio_text, [{"role": "system", "content": status}]

-
-            self.enhanced_profile = generate_enhanced_persona(self.model, name, bio_text, context)
+            self.enhanced_profile = generate_enhanced_persona(name, bio_text, context)
             profile_for_prompt = self.enhanced_profile
             if self.enhanced_profile.startswith("Error enhancing profile"):
-
-
+                yield f"Warning: Could not enhance profile. Using basic info.", "", self.enhanced_profile, [{"role": "system", "content": self.enhanced_profile}]
+                profile_for_prompt = bio_text

             status = f"Generating optimal system prompt for {name}..."
-            # Yield the enhanced profile while generating prompt
             yield status, self.enhanced_profile, self.enhanced_profile, [{"role": "system", "content": status}]

-
-            self.system_prompt = generate_system_prompt_with_llm(self.model, name, profile_for_prompt, context)
+            self.system_prompt = generate_system_prompt_with_llm(name, profile_for_prompt, context)
             self.messages = [{"role": "system", "content": self.system_prompt}]

             yield f"Persona set to '{name}'. Ready to chat!", self.system_prompt, self.enhanced_profile, self.messages

-        except RuntimeError as e:
-            error_msg = f"Critical Error: {str(e)}"
-            print(error_msg)
-            yield error_msg, "", "", [{"role": "system", "content": error_msg}]
         except Exception as e:
             error_msg = f"An unexpected error occurred during persona setup: {str(e)}"
             print(error_msg)
-            # Attempt to yield current state even on error
             yield error_msg, self.system_prompt, self.enhanced_profile, [{"role": "system", "content": error_msg}]

-    # No @GPU decorator needed here typically
     def chat(self, user_message):
         """Processes a user message and returns the AI's response."""
-        # This method calls generate_response which has the @GPU decorator
         try:
-            self.load_model_if_needed()
-
             if not self.messages:
-
-
+                print("Error: Chat called before persona was set.")
+                return "Please set a persona first using the controls above."

             print(f"User message: {user_message}")
-
-            # Keep internal history, pass copy to model if needed, but pipeline usually handles state
-            self.messages.append(formatted_message)
+            self.messages.append({"role": "user", "content": user_message})

-
-            response = generate_response(self.model, self.messages)
+            response = generate_response(self.messages)

-            # Append assistant response IF generation succeeded
             if not response.startswith("Sorry, I encountered an error"):
-
-
-                print(f"Assistant response: {response}")
+                self.messages.append({"role": "assistant", "content": response})
+                print(f"Assistant response: {response}")
             else:
-
-                # Do not add the error message itself to the persistent history
-                # Let the UI show the error, but don't make the bot repeat it next turn.
+                print(f"Assistant error response: {response}")

             return response

-        except RuntimeError as e:
-            error_msg = f"Critical Error: {str(e)}. Cannot generate response."
-            print(error_msg)
-            return error_msg
         except Exception as e:
             error_msg = f"Error generating response: {str(e)}"
             print(error_msg)
             return f"Sorry, I encountered an error: {str(e)}"

-
 # --- Gradio Interface ---
 def create_interface():
-    persona_chat = PersonaChat()
+    persona_chat = PersonaChat()  # Instantiate the handler class

     css = """
     .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
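The slimmed-down PersonaChat no longer holds a model reference or a `load_model_if_needed` helper; it only tracks persona state and delegates generation to the decorated module-level functions. An illustrative driver using the method names from this diff (the persona name and context are made up):

chat = PersonaChat()
for status, system_prompt, profile, history in chat.set_persona("Ada Lovelace", context="Victorian mathematician"):
    print(status)  # progress updates that the UI streams as they are yielded
print(chat.chat("What are you working on?"))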
@@ -428,7 +349,6 @@ def create_interface():
     .persona-button { background-color: #4ca1af !important; color: white !important; }
     .system-prompt-display { background-color: #f5f5f5; border-radius: 8px; padding: 15px; margin-top: 15px; border: 1px solid #e0e0e0; font-family: monospace; white-space: pre-wrap; word-wrap: break-word; }
     .footer { text-align: center; margin-top: 20px; font-size: 0.9rem; color: #666; }
-    /* Use default chatbot message styling provided by type='messages' */
     .typing-indicator { color: #aaa; font-style: italic; }
     """

@@ -451,18 +371,15 @@ def create_interface():
                 enhanced_profile_display = gr.TextArea(label="Enhanced Profile (Generated by AI)", interactive=False, lines=10, elem_classes="system-prompt-display")
                 system_prompt_display = gr.TextArea(label="System Prompt (Instructions for the AI)", interactive=False, lines=10, elem_classes="system-prompt-display")

-
             with gr.Column(elem_classes="chat-section"):
                 gr.Markdown("### 2. Chat with Your Character")
                 character_name_display = gr.Markdown(value="*No persona created yet*", elem_id="character-name-display")
-                # ***** FIX GRADIO WARNINGS *****
                 chatbot = gr.Chatbot(
                     label="Conversation",
                     height=450,
                     elem_classes="chat-container",
-
-
-                    type="messages" # ***** USE RECOMMENDED TYPE *****
+                    avatar_images=(None, "🤖"),  # User default, Bot emoji
+                    type="messages"  # Use recommended type
                 )
                 with gr.Row():
                     msg_input = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter...", elem_classes="message-input", scale=4)
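With type="messages", the Chatbot history is a list of role/content dicts rather than the legacy list-of-pairs format, which is exactly what send_message_flow builds further down. For example:

history = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": None},  # None is used as the typing placeholder in this app
]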
@@ -481,7 +398,6 @@ def create_interface():
             initial_character_display = f"### Preparing to chat with {name}..."
             initial_prompt = "System prompt will appear here..."
             initial_profile = "Enhanced profile will appear here..."
-            # Start with empty history for messages type
             initial_history = []

             yield initial_status, initial_prompt, initial_profile, initial_character_display, initial_history
@@ -489,55 +405,46 @@ def create_interface():
            final_status, final_prompt, final_profile = "Error", "", ""
            final_history = initial_history
            try:
-                # Use the PersonaChat instance's method generator
-                # Expected yield order: status, system_prompt, enhanced_profile, messages_list
                for status_update, prompt_update, profile_update, history_update in persona_chat.set_persona(name, context):
                    final_status, final_prompt, final_profile = status_update, prompt_update, profile_update
-                    if isinstance(history_update, list):
+                    if isinstance(history_update, list):
+                        final_history = history_update

                    character_display = f"### Preparing chat with {name}..."
                    if "Ready to chat" in status_update:
                        character_display = f"### Chatting with {name}"
                    elif "Error" in status_update:
-
+                        character_display = f"### Error creating {name}"

                    yield status_update, final_prompt, final_profile, character_display, final_history
-                    time.sleep(0.1)
+                    time.sleep(0.1)  # Small delay for UI update visibility

            except Exception as e:
-
-
-
-                yield error_msg, final_prompt, final_profile, f"### Error creating {name}", final_history
-
+                error_msg = f"Failed to set persona (interface error): {str(e)}"
+                print(error_msg)
+                yield error_msg, final_prompt, final_profile, f"### Error creating {name}", final_history

        def send_message_flow(message, history):
-
-
+            if history is None:
+                history = []
            if not message.strip():
                return "", history

-            # Check if persona is ready (looks for system message in internal state)
            if not persona_chat.messages or persona_chat.messages[0]['role'] != 'system':
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": "Error: Please create a valid persona first."})
                return "", history

-            # Append user message to UI history
            history.append({"role": "user", "content": message})
-
-            history.append({"role": "assistant", "content": None}) # Use None for typing indicator with type='messages'
+            history.append({"role": "assistant", "content": None})  # Typing indicator

-            yield "", history
+            yield "", history  # Show user msg + typing

-            # Call chat method (uses internal state, returns string response)
            response_text = persona_chat.chat(message)

-            # Update the placeholder in UI history with the actual response
            history[-1]["content"] = response_text

-            yield "", history
-
+            yield "", history  # Show final response

        set_persona_button.click(
            set_persona_flow,
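Because send_message_flow is a generator that yields twice (once with the typing placeholder, once with the final reply), it has to be wired to the message events as a streaming handler. A hypothetical wiring, since the actual submit/click lines for the message box are outside the hunks shown here (component names other than msg_input and chatbot are assumptions):

msg_input.submit(send_message_flow, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot])
send_button.click(send_message_flow, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot])  # send_button is hypothetical

The demo.queue() call at launch is what allows these intermediate yields to stream to the browser.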
@@ -561,10 +468,9 @@ def create_interface():
 if __name__ == "__main__":
     print("Starting Gradio application for Hugging Face Spaces...")
     demo = create_interface()
-    demo.queue().launch(
+    demo.queue().launch(
         server_name="0.0.0.0",
         server_port=7860,
-
-
-        debug=True # More verbose logging
+        show_error=True,
+        debug=True
     )