Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -1,37 +1,31 @@
 # --- Required Installs ---
-#

 import gradio as gr
 import transformers
 import torch
-from transformers import pipeline, BitsAndBytesConfig
 from duckduckgo_search import DDGS
 import re
 import time
 from huggingface_hub import HfApi
-

 # --- Constants and Configuration ---
 MODEL_ID = "nvidia/Llama-3.1-Nemotron-8B-UltraLong-4M-Instruct"
 MAX_GPU_MEMORY = "40GiB"  # A100 memory allocation

-# --- GPU Decorator (Placeholder if not using HF Spaces GPU class) ---
-# If not running on Hugging Face Spaces with their specific @GPU decorator,
-# remove or comment out the @GPU decorators below.
-# The resource allocation might need to be handled differently depending on your environment.
-# For simplicity, assuming the decorator exists or is not strictly needed for function.
-try:
-    from spaces import GPU
-except ImportError:
-    print("Warning: 'spaces.GPU' not found. Assuming standard environment.")
-    # Define a dummy decorator if 'spaces' is not available
-    def GPU(memory=None):
-        def decorator(func):
-            return func
-        return decorator
-
 # --- Model Loading ---
-
 def load_model():
     """Load the LLM model optimized for A100 GPU."""
     print(f"Attempting to load model: {MODEL_ID}")
@@ -43,47 +37,41 @@ def load_model():
             "text-generation",
             model=MODEL_ID,
             torch_dtype=torch.bfloat16,
-            device_map="auto",
             model_kwargs={
-                # Use quantization_config instead of load_in_8bit directly
                 "quantization_config": quantization_config,
                 "use_cache": True,
-                # max_memory
-                # but explicitly setting can be safer. Adjust if needed.
-                # "max_memory": {0: MAX_GPU_MEMORY} # Keep if necessary for your setup
             }
         )
-        print(f"Model {MODEL_ID} loaded successfully.")
         return pipe
     except Exception as e:
-        print(f"Error loading model '{MODEL_ID}': {e}")
-        #
-
-        return None #

 # --- Web Search ---
 def search_person(name, context=""):
     """Search for information about a person using DuckDuckGo."""
     print(f"Searching for: {name} with context: {context}")
     results = []
     search_terms = []

-    # Prioritize context-specific search
     if context:
         search_terms.append(f"{name} {context}")
-    # Add grade-specific search if applicable
     grade_match = re.search(r'(\d+)(?:st|nd|rd|th)?\s+grade', context.lower())
     if grade_match:
         grade = grade_match.group(1)
         search_terms.append(f"{name} student {grade} grade")

-
-    search_terms.append(f"{name}")  # Just the name
     search_terms.append(f"{name} biography")
     search_terms.append(f"{name} interests")
     search_terms.append(f"{name} personality")

-    # Remove duplicates
     search_terms = list(dict.fromkeys(search_terms))
     print(f"Using search terms: {search_terms}")
@@ -91,14 +79,13 @@ def search_person(name, context=""):
         with DDGS() as ddgs:
             for term in search_terms:
                 print(f"Searching DDG for: '{term}'")
-                # Fetch fewer results per term to keep context concise
                 search_results = list(ddgs.text(term, max_results=2))
                 results.extend(search_results)
-                time.sleep(0.2)
     except Exception as e:
         error_msg = f"Error during DuckDuckGo search: {str(e)}"
         print(error_msg)
-        return error_msg

     if not results:
         print(f"No search results found for {name}. Creating synthetic profile.")
@@ -114,39 +101,34 @@ def create_synthetic_profile(name, context):
         "href": "",
         "body": f"{name} is a person described with the context: '{context}'. "
     }
-
-    # Try to infer age from grade
     if "grade" in context.lower():
         grade_match = re.search(r'(\d+)(?:st|nd|rd|th)?\s+grade', context.lower())
         if grade_match:
             try:
                 grade = int(grade_match.group(1))
-                age = 5 + grade
                 profile["body"] += f"Based on being in {grade}th grade, {name} is likely around {age} years old. "
                 profile["body"] += f"Typical interests for this age might include friends, hobbies, school subjects, and developing independence. "
             except ValueError:
                 profile["body"] += f"The grade mentioned ('{grade_match.group(1)}') could not be parsed to estimate age. "
-
     profile["body"] += "Since no public information was found, this profile is based solely on the provided context."
-    # Return as a list containing the dictionary, matching search_person's format
     return [profile]

 def extract_text_from_search_results(search_results):
     """Extract relevant text from search results."""
-    if isinstance(search_results, str):
         return f"Could not extract text due to search error: {search_results}"

     combined_text = ""
     seen_bodies = set()
     count = 0
-    max_results_to_process = 5

     for result in search_results:
         if count >= max_results_to_process:
             break
         if isinstance(result, dict) and 'body' in result and result['body']:
             body = result['body'].strip()
-            # Avoid adding duplicate snippets
             if body not in seen_bodies:
                 combined_text += body + "\n\n"
                 seen_bodies.add(body)
@@ -155,184 +137,138 @@ def extract_text_from_search_results(search_results):
|
|
155 |
if not combined_text:
|
156 |
return "No relevant text found in search results."
|
157 |
|
158 |
-
# Basic cleaning
|
159 |
combined_text = re.sub(r'\s+', ' ', combined_text).strip()
|
160 |
-
|
161 |
-
max_length = 2000 # Characters
|
162 |
return combined_text[:max_length] + "..." if len(combined_text) > max_length else combined_text
|
163 |
|
164 |
-
|
165 |
# --- LLM Generation Functions ---
|
166 |
|
167 |
def parse_llm_output(full_output, input_prompt_list):
|
168 |
-
"""
|
169 |
-
Attempts to parse only the newly generated text from the LLM output,
|
170 |
-
assuming the output might contain the input prompt messages.
|
171 |
-
"""
|
172 |
-
# If the output is a list of dicts (as expected from pipeline), get the text
|
173 |
if isinstance(full_output, list) and len(full_output) > 0:
|
174 |
if isinstance(full_output[0], dict) and "generated_text" in full_output[0]:
|
175 |
generated_text = full_output[0]["generated_text"]
|
176 |
-
else:
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
else:
|
181 |
-
return str(full_output) # Unexpected format
|
182 |
-
|
183 |
-
# Heuristic: Find the last message's content from the input prompt
|
184 |
-
# The actual formatting depends on the tokenizer's chat template.
|
185 |
-
# This is a simplified approach.
|
186 |
last_input_content = ""
|
187 |
if isinstance(input_prompt_list, list) and input_prompt_list:
|
|
|
|
|
188 |
last_input_content = input_prompt_list[-1].get("content", "")
|
189 |
|
190 |
-
# Try to find the last input message content in the generated text
|
191 |
-
# If found, take the text after it. This might fail if formatting differs.
|
192 |
if last_input_content:
|
193 |
last_occurrence_index = generated_text.rfind(last_input_content)
|
194 |
if last_occurrence_index != -1:
|
195 |
potential_response = generated_text[last_occurrence_index + len(last_input_content):].strip()
|
196 |
-
|
197 |
-
|
198 |
-
# Simple cleanup for potential role markers if model adds them
|
199 |
potential_response = re.sub(r'^<\/?s?>', '', potential_response).strip()
|
200 |
potential_response = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', potential_response).strip()
|
201 |
-
|
|
|
|
|
202 |
|
203 |
-
# Fallback
|
204 |
-
# Or, if the prompt asked for ONLY the output, the model might have behaved correctly.
|
205 |
-
# Let's clean up potential boilerplate often added by models
|
206 |
cleaned_text = generated_text
|
207 |
if isinstance(input_prompt_list, list) and input_prompt_list:
|
208 |
-
# Remove potential initial prompt remnants if possible (very basic)
|
209 |
first_prompt_content = input_prompt_list[0].get("content", "")
|
210 |
if first_prompt_content and cleaned_text.startswith(first_prompt_content):
|
211 |
-
|
|
|
212 |
|
213 |
-
#
|
214 |
cleaned_text = re.sub(r'^<\/?s?>', '', cleaned_text).strip()
|
215 |
cleaned_text = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', cleaned_text).strip()
|
216 |
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
-
|
|
|
|
|
221 |
def generate_enhanced_persona(model, name, bio_text, context=""):
|
222 |
"""Use the LLM to enhance the persona profile."""
|
223 |
print(f"Generating enhanced persona for {name}...")
|
224 |
-
if model is None:
|
225 |
-
raise ValueError("Model is not loaded.")
|
226 |
|
227 |
enhancement_prompt = [
|
228 |
-
{"role": "system", "content": """You are an expert AI character developer. Your task is to synthesize information into a detailed and coherent character profile. Focus on personality, potential interests, speaking style, and mannerisms based ONLY on the provided text. If the text indicates the character is a child, ensure the profile reflects age-appropriate traits.
|
229 |
-
|
230 |
-
Output ONLY the enhanced character profile description. Do not include conversational introductions, explanations, apologies for limited info, or markdown formatting like headers (e.g., ### Personality). Start directly with the profile text."""},
|
231 |
-
{"role": "user", "content": f"""Synthesize the following information about '{name}' into a character profile.
|
232 |
-
Context: {context}
|
233 |
-
Information Found:
|
234 |
-
{bio_text}
|
235 |
-
|
236 |
-
Create the profile based *only* on the text above."""}
|
237 |
]
|
238 |
|
239 |
try:
|
240 |
-
# Use torch.amp.autocast instead of torch.cuda.amp.autocast
|
241 |
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
242 |
outputs = model(enhancement_prompt, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9)
|
243 |
-
|
244 |
-
# Parse the output
|
245 |
parsed_output = parse_llm_output(outputs, enhancement_prompt)
|
246 |
print("Enhanced persona generated.")
|
247 |
-
|
248 |
-
return parsed_output if parsed_output else bio_text
|
249 |
-
|
250 |
except Exception as e:
|
251 |
error_msg = f"Error generating enhanced persona: {str(e)}"
|
252 |
print(error_msg)
|
253 |
-
# Fallback to the original bio text in case of error
|
254 |
return f"Error enhancing profile: {str(e)}\n\nUsing basic info:\n{bio_text}"
|
255 |
|
256 |
-
|
257 |
def generate_system_prompt_with_llm(model, name, enhanced_profile, context=""):
|
258 |
"""Generate an optimized system prompt for the persona."""
|
259 |
print(f"Generating system prompt for {name}...")
|
260 |
-
if model is None:
|
261 |
-
raise ValueError("Model is not loaded.")
|
262 |
|
263 |
-
fallback_prompt = f"""You are simulating the character '{name}'. Act and respond according to this profile:
|
264 |
-
{enhanced_profile}
|
265 |
-
Additional context for the simulation: {context}
|
266 |
-
---
|
267 |
-
Maintain this persona consistently. Respond naturally based on the profile. Do not mention that you are an AI or a simulation. If asked about details not in the profile, you can be evasive or state you don't know/remember, consistent with the persona."""
|
268 |
|
269 |
prompt = [
|
270 |
-
{"role": "system", "content": """You are an expert AI prompt engineer specializing in character simulation. Your task is to create a concise and effective system prompt for an LLM that will simulate a character based on a provided profile.
|
271 |
-
|
272 |
-
The system prompt should instruct the LLM to embody the character, covering:
|
273 |
-
1. Core personality, attitude, and speaking style (based on the profile).
|
274 |
-
2. Key interests or knowledge areas (if mentioned in the profile).
|
275 |
-
3. How to handle questions outside its knowledge (e.g., be evasive, admit ignorance naturally).
|
276 |
-
4. Explicitly state it should *not* break character or mention being an AI.
|
277 |
-
5. Incorporate age-appropriateness if the profile suggests a specific age group.
|
278 |
-
|
279 |
-
Output ONLY the system prompt itself. Do not add any explanation or introductory text."""},
|
280 |
-
{"role": "user", "content": f"""Create a system prompt for an AI to simulate the character '{name}'.
|
281 |
-
Context for simulation: {context}
|
282 |
-
Character Profile:
|
283 |
-
{enhanced_profile}
|
284 |
-
|
285 |
-
Generate the system prompt based *only* on the profile and context provided."""}
|
286 |
]
|
287 |
|
288 |
try:
|
289 |
-
# Use torch.amp.autocast instead of torch.cuda.amp.autocast
|
290 |
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
291 |
-
outputs = model(prompt, max_new_tokens=300, do_sample=True, temperature=0.6)
|
292 |
-
|
293 |
-
# Parse the output
|
294 |
parsed_output = parse_llm_output(outputs, prompt)
|
295 |
print("System prompt generated.")
|
296 |
-
# Return parsed output or fallback
|
297 |
return parsed_output if parsed_output else fallback_prompt
|
298 |
-
|
299 |
except Exception as e:
|
300 |
error_msg = f"Error generating system prompt: {str(e)}"
|
301 |
print(error_msg)
|
302 |
-
# Fallback to a basic system prompt in case of error
|
303 |
return fallback_prompt
|
304 |
|
305 |
-
|
306 |
def generate_response(model, messages):
|
307 |
"""Generate a response using the LLM."""
|
308 |
print("Generating response...")
|
309 |
-
if model is None:
|
310 |
-
|
311 |
-
if not messages:
|
312 |
-
return "Error: No message history provided."
|
313 |
|
314 |
try:
|
315 |
-
# Use torch.amp.autocast instead of torch.cuda.amp.autocast
|
316 |
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
|
|
317 |
outputs = model(
|
318 |
messages,
|
319 |
-
max_new_tokens=512,
|
320 |
do_sample=True,
|
321 |
top_p=0.9,
|
322 |
temperature=0.7,
|
323 |
use_cache=True,
|
324 |
-
|
|
|
325 |
)
|
326 |
-
|
327 |
-
# Parse the output - expecting only the assistant's new reply
|
328 |
parsed_output = parse_llm_output(outputs, messages)
|
329 |
print("Response generated.")
|
330 |
-
return parsed_output if parsed_output else "..."
|
331 |
-
|
332 |
except Exception as e:
|
333 |
error_msg = f"Error during response generation: {str(e)}"
|
334 |
print(error_msg)
|
335 |
-
|
|
|
336 |
|
337 |
|
338 |
# --- Persona Chat Class ---
|
@@ -344,140 +280,123 @@ class PersonaChat:
|
|
344 |
self.persona_context = ""
|
345 |
self.messages = []
|
346 |
self.enhanced_profile = ""
|
347 |
-
self.model_loaded = False
|
348 |
|
349 |
-
# @GPU
|
350 |
def load_model_if_needed(self):
|
351 |
"""Loads the model if it hasn't been loaded successfully."""
|
352 |
-
if not self.model_loaded:
|
353 |
-
print("Model not loaded
|
354 |
-
#
|
355 |
-
self.model = load_model()
|
356 |
if self.model is None:
|
357 |
-
#
|
358 |
raise RuntimeError("Failed to load the language model. Cannot proceed.")
|
359 |
else:
|
360 |
self.model_loaded = True
|
361 |
print("Model loaded successfully within PersonaChat instance.")
|
362 |
-
else:
|
363 |
-
print("Model already loaded.")
|
364 |
|
365 |
-
#
|
366 |
-
# @GPU(memory=40)
|
367 |
def set_persona(self, name, context=""):
|
368 |
"""Orchestrates persona creation: search, enhance, generate prompt."""
|
|
|
369 |
try:
|
370 |
-
#
|
371 |
-
self.load_model_if_needed() # This will raise RuntimeError if it fails
|
372 |
|
373 |
self.persona_name = name
|
374 |
self.persona_context = context
|
375 |
-
self.messages = []
|
376 |
-
self.enhanced_profile = ""
|
377 |
|
378 |
status = f"Searching for information about {name}..."
|
379 |
-
yield status, "", [{"role": "system", "content": "Initializing persona creation..."}]
|
380 |
|
381 |
search_results = search_person(name, context)
|
382 |
-
|
383 |
-
# Check if search returned an error string
|
384 |
if isinstance(search_results, str) and search_results.startswith("Error"):
|
385 |
error_msg = f"Failed to set persona: {search_results}"
|
386 |
-
yield error_msg, "", [{"role": "system", "content": error_msg}]
|
387 |
-
return
|
388 |
|
389 |
bio_text = extract_text_from_search_results(search_results)
|
390 |
if bio_text.startswith("Could not extract text"):
|
391 |
-
yield f"Warning: {bio_text}", "", [{"role": "system", "content": bio_text}]
|
392 |
-
# Continue with potentially limited info
|
393 |
|
394 |
status = f"Creating enhanced profile for {name}..."
|
395 |
-
yield status, "", [{"role": "system", "content": status}]
|
396 |
|
397 |
-
#
|
398 |
self.enhanced_profile = generate_enhanced_persona(self.model, name, bio_text, context)
|
399 |
-
|
400 |
if self.enhanced_profile.startswith("Error enhancing profile"):
|
401 |
-
yield f"Warning: Could not enhance profile. Using basic info.", "", [{"role": "system", "content": self.enhanced_profile}]
|
402 |
-
|
403 |
-
profile_for_prompt = bio_text
|
404 |
-
else:
|
405 |
-
profile_for_prompt = self.enhanced_profile
|
406 |
-
|
407 |
|
408 |
status = f"Generating optimal system prompt for {name}..."
|
409 |
-
|
|
|
410 |
|
411 |
-
#
|
412 |
self.system_prompt = generate_system_prompt_with_llm(self.model, name, profile_for_prompt, context)
|
413 |
-
|
414 |
-
# Set the initial system message for the chat history
|
415 |
self.messages = [{"role": "system", "content": self.system_prompt}]
|
416 |
|
417 |
-
yield f"Persona set to '{name}'. Ready to chat!", self.system_prompt, self.messages
|
418 |
|
419 |
except RuntimeError as e:
|
420 |
-
# Catch model loading errors from load_model_if_needed
|
421 |
error_msg = f"Critical Error: {str(e)}"
|
422 |
print(error_msg)
|
423 |
-
yield error_msg, "", [{"role": "system", "content": error_msg}]
|
424 |
except Exception as e:
|
425 |
-
# Catch other unexpected errors during persona setting
|
426 |
error_msg = f"An unexpected error occurred during persona setup: {str(e)}"
|
427 |
print(error_msg)
|
428 |
-
yield
|
429 |
-
|
430 |
|
431 |
-
#
|
432 |
-
# @GPU(memory=40)
|
433 |
def chat(self, user_message):
|
434 |
"""Processes a user message and returns the AI's response."""
|
|
|
435 |
try:
|
436 |
-
|
437 |
-
self.load_model_if_needed() # Raises RuntimeError if model failed to load initially
|
438 |
|
439 |
if not self.messages:
|
440 |
-
# This case should ideally be prevented by UI logic
|
441 |
-
# but handle it defensively.
|
442 |
print("Error: Chat called before persona was set.")
|
443 |
return "Please set a persona first using the controls above."
|
444 |
|
445 |
print(f"User message: {user_message}")
|
446 |
-
# Append user message (ensure correct format)
|
447 |
formatted_message = {"role": "user", "content": user_message}
|
|
|
448 |
self.messages.append(formatted_message)
|
449 |
|
450 |
-
#
|
451 |
response = generate_response(self.model, self.messages)
|
452 |
|
453 |
-
# Append assistant response
|
454 |
-
|
455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
|
457 |
-
print(f"Assistant response: {response}")
|
458 |
return response
|
459 |
|
460 |
except RuntimeError as e:
|
461 |
-
# Catch model loading errors
|
462 |
error_msg = f"Critical Error: {str(e)}. Cannot generate response."
|
463 |
print(error_msg)
|
464 |
return error_msg
|
465 |
except Exception as e:
|
466 |
-
# Catch errors during generation itself
|
467 |
error_msg = f"Error generating response: {str(e)}"
|
468 |
print(error_msg)
|
469 |
-
|
470 |
-
# Let's return the error string directly.
|
471 |
-
# We might want to avoid adding the error to self.messages history
|
472 |
-
return error_msg
|
473 |
|
474 |
|
475 |
# --- Gradio Interface ---
|
476 |
def create_interface():
|
477 |
-
# Instantiate the
|
478 |
-
persona_chat = PersonaChat()
|
479 |
|
480 |
-
# Custom CSS (minor adjustments possible)
|
481 |
css = """
|
482 |
.gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
|
483 |
.main-container { max-width: 1200px; margin: auto; padding: 0; }
|
@@ -491,105 +410,50 @@ def create_interface():
|
|
491 |
.persona-button { background-color: #4ca1af !important; color: white !important; }
|
492 |
.system-prompt-display { background-color: #f5f5f5; border-radius: 8px; padding: 15px; margin-top: 15px; border: 1px solid #e0e0e0; font-family: monospace; white-space: pre-wrap; word-wrap: break-word; }
|
493 |
.footer { text-align: center; margin-top: 20px; font-size: 0.9rem; color: #666; }
|
494 |
-
|
495 |
-
.bot-message > .message { background-color: #f1f3f5; border-radius: 15px 15px 15px 0 !important; padding: 10px 15px !important; margin: 8px auto 8px 0 !important; max-width: 80%; float: left; clear: both; color: #333; }
|
496 |
-
.message p { margin: 0 !important; padding: 0 !important; } /* Prevent extra margins in chatbot messages */
|
497 |
.typing-indicator { color: #aaa; font-style: italic; }
|
498 |
"""
|
499 |
|
500 |
with gr.Blocks(css=css, title="AI Persona Simulator") as interface:
|
501 |
with gr.Row(elem_classes="main-container"):
|
502 |
with gr.Column():
|
503 |
-
# Header
|
504 |
with gr.Column(elem_classes="header"):
|
505 |
gr.Markdown("# AI Persona Simulator")
|
506 |
gr.Markdown("Create and interact with AI-driven character simulations")
|
507 |
|
508 |
-
# Setup Section
|
509 |
with gr.Column(elem_classes="setup-section"):
|
510 |
gr.Markdown("### 1. Create Your Persona")
|
511 |
-
gr.Markdown("Enter a name and
|
512 |
-
|
513 |
with gr.Row():
|
514 |
-
name_input = gr.Textbox(
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
placeholder="e.g., Living in 221B Baker Street, London. OR 7th grade, loves math and video games, has a pet cat named Luna. OR A spaceship captain exploring Alpha Centauri.",
|
522 |
-
lines=2,
|
523 |
-
elem_id="context_input"
|
524 |
-
)
|
525 |
-
|
526 |
-
set_persona_button = gr.Button(
|
527 |
-
"Create Persona & Start Chat",
|
528 |
-
variant="primary",
|
529 |
-
elem_classes="persona-button"
|
530 |
-
)
|
531 |
|
532 |
-
status_output = gr.Textbox(
|
533 |
-
label="Status",
|
534 |
-
value="Enter details above and click 'Create Persona'.",
|
535 |
-
interactive=False,
|
536 |
-
elem_classes="status-bar"
|
537 |
-
)
|
538 |
|
539 |
-
with gr.Accordion("View Generated System Prompt", open=False):
|
540 |
-
system_prompt_display = gr.TextArea(
|
541 |
-
label="System Prompt (Instructions for the AI)",
|
542 |
-
interactive=False,
|
543 |
-
lines=10,
|
544 |
-
elem_classes="system-prompt-display" # Use dedicated class
|
545 |
-
)
|
546 |
-
enhanced_profile_display = gr.TextArea(
|
547 |
-
label="Enhanced Profile (Generated by AI)",
|
548 |
-
interactive=False,
|
549 |
-
lines=10,
|
550 |
-
elem_classes="system-prompt-display" # Reuse style or create new
|
551 |
-
)
|
552 |
-
|
553 |
-
|
554 |
-
# Chat Section
|
555 |
with gr.Column(elem_classes="chat-section"):
|
556 |
gr.Markdown("### 2. Chat with Your Character")
|
557 |
-
|
558 |
-
|
559 |
-
value="*No persona created yet*",
|
560 |
-
elem_id="character-name-display"
|
561 |
-
)
|
562 |
-
|
563 |
chatbot = gr.Chatbot(
|
564 |
label="Conversation",
|
565 |
height=450,
|
566 |
elem_classes="chat-container",
|
567 |
-
bubble_full_width=False, #
|
568 |
-
avatar_images=(None, "🤖") # User
|
|
|
569 |
)
|
570 |
-
|
571 |
with gr.Row():
|
572 |
-
msg_input = gr.Textbox(
|
573 |
-
|
574 |
-
|
575 |
-
elem_classes="message-input",
|
576 |
-
scale=4 # Make input wider
|
577 |
-
)
|
578 |
-
send_button = gr.Button(
|
579 |
-
"Send",
|
580 |
-
variant="primary",
|
581 |
-
elem_classes="send-button",
|
582 |
-
scale=1
|
583 |
-
)
|
584 |
-
|
585 |
-
# Footer
|
586 |
with gr.Column(elem_classes="footer"):
|
587 |
gr.Markdown(f"Powered by {MODEL_ID}")
|
588 |
|
589 |
-
|
590 |
# --- Event Handlers ---
|
591 |
-
|
592 |
-
# Generator function for smoother UI updates during persona creation
|
593 |
def set_persona_flow(name, context):
|
594 |
if not name:
|
595 |
yield "Status: Please enter a character name.", "", "", "*No persona created yet*", []
|
@@ -599,89 +463,74 @@ def create_interface():
|
|
599 |
initial_character_display = f"### Preparing to chat with {name}..."
|
600 |
initial_prompt = "System prompt will appear here..."
|
601 |
initial_profile = "Enhanced profile will appear here..."
|
602 |
-
|
|
|
603 |
|
604 |
-
# Initial yield to show activity starting
|
605 |
yield initial_status, initial_prompt, initial_profile, initial_character_display, initial_history
|
606 |
|
607 |
-
# Use the PersonaChat instance's method, which is a generator
|
608 |
final_status, final_prompt, final_profile = "Error", "", ""
|
609 |
final_history = initial_history
|
610 |
try:
|
611 |
-
#
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
if isinstance(history_update, list):
|
616 |
-
|
617 |
-
|
618 |
-
# Determine character display based on status
|
619 |
character_display = f"### Preparing chat with {name}..."
|
620 |
-
if "Ready to chat" in
|
621 |
character_display = f"### Chatting with {name}"
|
622 |
-
elif "Error" in
|
623 |
character_display = f"### Error creating {name}"
|
624 |
|
625 |
-
yield
|
626 |
-
#
|
627 |
-
# time.sleep(0.1)
|
628 |
|
629 |
except Exception as e:
|
630 |
-
|
631 |
-
error_msg = f"Failed to set persona due to an unexpected error: {str(e)}"
|
632 |
print(error_msg)
|
633 |
-
|
|
|
634 |
|
635 |
|
636 |
-
# Function to handle sending messages
|
637 |
def send_message_flow(message, history):
|
|
|
|
|
638 |
if not message.strip():
|
639 |
-
|
640 |
-
return "", history # Return unchanged history and clear input box
|
641 |
|
642 |
-
# Check if persona is ready (
|
643 |
if not persona_chat.messages or persona_chat.messages[0]['role'] != 'system':
|
644 |
-
# Persona not set or history is corrupted
|
645 |
history.append({"role": "user", "content": message})
|
646 |
-
history.append({"role": "assistant", "content": "Error: Please create a valid persona first
|
647 |
-
return "", history
|
648 |
|
649 |
-
# Append user message to
|
650 |
history.append({"role": "user", "content": message})
|
651 |
-
#
|
652 |
-
history.append({"role": "assistant", "content": None}) # Use None for
|
653 |
-
|
654 |
-
# Yield the updated history to show user message and typing indicator
|
655 |
-
yield "", history
|
656 |
|
657 |
-
|
658 |
-
response = persona_chat.chat(message) # This now uses the internal self.messages
|
659 |
|
660 |
-
#
|
661 |
-
|
662 |
|
663 |
-
#
|
664 |
-
|
665 |
-
yield "", history
|
666 |
|
|
|
667 |
|
668 |
-
# Connect Gradio components to functions
|
669 |
|
670 |
-
# Use the generator for persona setting
|
671 |
set_persona_button.click(
|
672 |
set_persona_flow,
|
673 |
inputs=[name_input, context_input],
|
674 |
outputs=[status_output, system_prompt_display, enhanced_profile_display, character_name_display, chatbot]
|
675 |
)
|
676 |
-
|
677 |
-
# Use the generator for sending messages
|
678 |
send_button.click(
|
679 |
send_message_flow,
|
680 |
inputs=[msg_input, chatbot],
|
681 |
outputs=[msg_input, chatbot]
|
682 |
)
|
683 |
-
|
684 |
-
# Allow submitting message with Enter key
|
685 |
msg_input.submit(
|
686 |
send_message_flow,
|
687 |
inputs=[msg_input, chatbot],
|
@@ -692,14 +541,12 @@ def create_interface():

 # --- Main Execution ---
 if __name__ == "__main__":
-    print("Starting Gradio application...")
-    # Ensure necessary packages are installed:
-    # pip install gradio transformers torch duckduckgo_search huggingface_hub accelerate bitsandbytes sentencepiece
     demo = create_interface()
-    demo.queue().launch( #
-        server_name="0.0.0.0",
         server_port=7860,
-        share=False
-        show_error=True, #
-        debug=True #
     )
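Note: the load_model hunks on both sides of this diff reference a quantization_config that is built in an unchanged part of the file the diff does not show. For reference only, a minimal 8-bit BitsAndBytesConfig setup for this kind of pipeline typically looks like the following sketch; it is an assumption, not code from this commit.

# Hypothetical sketch of the kind of quantization setup load_model() relies on.
import torch
from transformers import BitsAndBytesConfig, pipeline

MODEL_ID = "nvidia/Llama-3.1-Nemotron-8B-UltraLong-4M-Instruct"

# 8-bit loading via bitsandbytes keeps the 8B model within a single GPU's memory budget.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

pipe = pipeline(
    "text-generation",
    model=MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    model_kwargs={"quantization_config": quantization_config},
)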
# --- Required Installs ---
+# Ensure these are in your requirements.txt for Hugging Face Spaces
+# gradio
+# transformers
+# torch
+# duckduckgo_search
+# huggingface_hub
+# accelerate
+# bitsandbytes
+# sentencepiece
+# spaces <--- Provided by the Spaces environment

import gradio as gr
import transformers
import torch
+from transformers import pipeline, BitsAndBytesConfig
from duckduckgo_search import DDGS
import re
import time
from huggingface_hub import HfApi
+from spaces import GPU  # Directly import GPU from spaces - Crucial for HF Spaces

# --- Constants and Configuration ---
MODEL_ID = "nvidia/Llama-3.1-Nemotron-8B-UltraLong-4M-Instruct"
MAX_GPU_MEMORY = "40GiB"  # A100 memory allocation

# --- Model Loading ---
+@GPU(memory=40)  # ****** THIS DECORATOR IS ESSENTIAL FOR SPACES STARTUP ******
def load_model():
    """Load the LLM model optimized for A100 GPU."""
    print(f"Attempting to load model: {MODEL_ID}")
            "text-generation",
            model=MODEL_ID,
            torch_dtype=torch.bfloat16,
+           device_map="auto",  # Relies on accelerate
            model_kwargs={
                "quantization_config": quantization_config,
                "use_cache": True,
+               # "max_memory": {0: MAX_GPU_MEMORY} # Often handled by device_map="auto"
            }
        )
+       print(f"Model {MODEL_ID} loaded successfully on device: {pipe.device}")
        return pipe
    except Exception as e:
+       print(f"FATAL Error loading model '{MODEL_ID}': {e}")
+       # Raise the error to potentially get more detailed logs in Spaces
+       raise e
+       # return None # Returning None might hide the root cause in Spaces logs

# --- Web Search ---
+# (Keep search_person, create_synthetic_profile, extract_text_from_search_results as before)
def search_person(name, context=""):
    """Search for information about a person using DuckDuckGo."""
    print(f"Searching for: {name} with context: {context}")
    results = []
    search_terms = []

    if context:
        search_terms.append(f"{name} {context}")
    grade_match = re.search(r'(\d+)(?:st|nd|rd|th)?\s+grade', context.lower())
    if grade_match:
        grade = grade_match.group(1)
        search_terms.append(f"{name} student {grade} grade")

+   search_terms.append(f"{name}")
    search_terms.append(f"{name} biography")
    search_terms.append(f"{name} interests")
    search_terms.append(f"{name} personality")

    search_terms = list(dict.fromkeys(search_terms))
    print(f"Using search terms: {search_terms}")

        with DDGS() as ddgs:
            for term in search_terms:
                print(f"Searching DDG for: '{term}'")
                search_results = list(ddgs.text(term, max_results=2))
                results.extend(search_results)
+               time.sleep(0.2)
    except Exception as e:
        error_msg = f"Error during DuckDuckGo search: {str(e)}"
        print(error_msg)
+       return error_msg

    if not results:
        print(f"No search results found for {name}. Creating synthetic profile.")
        "href": "",
        "body": f"{name} is a person described with the context: '{context}'. "
    }
    if "grade" in context.lower():
        grade_match = re.search(r'(\d+)(?:st|nd|rd|th)?\s+grade', context.lower())
        if grade_match:
            try:
                grade = int(grade_match.group(1))
+               age = 5 + grade
                profile["body"] += f"Based on being in {grade}th grade, {name} is likely around {age} years old. "
                profile["body"] += f"Typical interests for this age might include friends, hobbies, school subjects, and developing independence. "
            except ValueError:
                profile["body"] += f"The grade mentioned ('{grade_match.group(1)}') could not be parsed to estimate age. "
    profile["body"] += "Since no public information was found, this profile is based solely on the provided context."
    return [profile]

def extract_text_from_search_results(search_results):
    """Extract relevant text from search results."""
+   if isinstance(search_results, str):
        return f"Could not extract text due to search error: {search_results}"

    combined_text = ""
    seen_bodies = set()
    count = 0
+   max_results_to_process = 5

    for result in search_results:
        if count >= max_results_to_process:
            break
        if isinstance(result, dict) and 'body' in result and result['body']:
            body = result['body'].strip()
            if body not in seen_bodies:
                combined_text += body + "\n\n"
                seen_bodies.add(body)

    if not combined_text:
        return "No relevant text found in search results."

    combined_text = re.sub(r'\s+', ' ', combined_text).strip()
+   max_length = 2000
    return combined_text[:max_length] + "..." if len(combined_text) > max_length else combined_text

# --- LLM Generation Functions ---

def parse_llm_output(full_output, input_prompt_list):
+   """Attempts to parse only the newly generated text from the LLM output."""
    if isinstance(full_output, list) and len(full_output) > 0:
        if isinstance(full_output[0], dict) and "generated_text" in full_output[0]:
            generated_text = full_output[0]["generated_text"]
+       else: return str(full_output)
+   elif isinstance(full_output, str): generated_text = full_output
+   else: return str(full_output)
+
    last_input_content = ""
    if isinstance(input_prompt_list, list) and input_prompt_list:
+       # Find the last message with 'user' or 'system' role potentially?
+       # Let's stick to finding the last message content for simplicity
        last_input_content = input_prompt_list[-1].get("content", "")

    if last_input_content:
        last_occurrence_index = generated_text.rfind(last_input_content)
        if last_occurrence_index != -1:
            potential_response = generated_text[last_occurrence_index + len(last_input_content):].strip()
+           if potential_response:
+               # Basic cleanup
                potential_response = re.sub(r'^<\/?s?>', '', potential_response).strip()
                potential_response = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', potential_response).strip()
+               # Check if the response is just whitespace or seems empty after cleanup
+               if potential_response:
+                   return potential_response

+   # Fallback or if model correctly outputted only the response
    cleaned_text = generated_text
    if isinstance(input_prompt_list, list) and input_prompt_list:
        first_prompt_content = input_prompt_list[0].get("content", "")
        if first_prompt_content and cleaned_text.startswith(first_prompt_content):
+           # Be careful not to strip if the response happens to start the same way
+           pass  # Let's rely more on the end-stripping heuristic above

+   # General cleanup
    cleaned_text = re.sub(r'^<\/?s?>', '', cleaned_text).strip()
    cleaned_text = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', cleaned_text).strip()

+   # If after all this, it's empty, maybe return original generated_text?
+   # Or log a warning and return the cleaned version.
+   if not cleaned_text and generated_text:
+       print("Warning: Parsing resulted in empty string, returning original generation.")
+       return generated_text  # Return original if cleaning failed
+
+   # If input prompt wasn't found, assume the model outputted only the response (ideal case)
+   # or the whole thing (fallback case). The cleaning helps for the latter.
+   if last_input_content and last_occurrence_index == -1:
+       print("Warning: Could not find last input prompt in LLM output. Returning cleaned full output.")

+   return cleaned_text
+
+@GPU(memory=40)  # Decorator needed for Spaces resource allocation during calls
def generate_enhanced_persona(model, name, bio_text, context=""):
    """Use the LLM to enhance the persona profile."""
    print(f"Generating enhanced persona for {name}...")
+   if model is None: raise ValueError("Model is not loaded.")

    enhancement_prompt = [
+       {"role": "system", "content": """You are an expert AI character developer. Your task is to synthesize information into a detailed and coherent character profile. Focus on personality, potential interests, speaking style, and mannerisms based ONLY on the provided text. If the text indicates the character is a child, ensure the profile reflects age-appropriate traits. Output ONLY the enhanced character profile description. Do not include conversational introductions, explanations, apologies for limited info, or markdown formatting like headers (e.g., ### Personality). Start directly with the profile text."""},
+       {"role": "user", "content": f"""Synthesize the following information about '{name}' into a character profile. Context: {context} Information Found:\n{bio_text}\n\nCreate the profile based *only* on the text above."""}
    ]

    try:
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = model(enhancement_prompt, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9)
        parsed_output = parse_llm_output(outputs, enhancement_prompt)
        print("Enhanced persona generated.")
+       return parsed_output if parsed_output else f"Could not generate profile based on:\n{bio_text}"
    except Exception as e:
        error_msg = f"Error generating enhanced persona: {str(e)}"
        print(error_msg)
        return f"Error enhancing profile: {str(e)}\n\nUsing basic info:\n{bio_text}"

+@GPU(memory=40)  # Decorator needed for Spaces resource allocation during calls
def generate_system_prompt_with_llm(model, name, enhanced_profile, context=""):
    """Generate an optimized system prompt for the persona."""
    print(f"Generating system prompt for {name}...")
+   if model is None: raise ValueError("Model is not loaded.")

+   fallback_prompt = f"""You are simulating the character '{name}'. Act and respond according to this profile:\n{enhanced_profile}\nAdditional context for the simulation: {context}\n---\nMaintain this persona consistently. Respond naturally based on the profile. Do not mention that you are an AI or a simulation. If asked about details not in the profile, you can be evasive or state you don't know/remember, consistent with the persona."""

    prompt = [
+       {"role": "system", "content": """You are an expert AI prompt engineer specializing in character simulation. Your task is to create a concise and effective system prompt for an LLM that will simulate a character based on a provided profile. The system prompt should instruct the LLM to embody the character, covering: 1. Core personality, attitude, and speaking style (based on the profile). 2. Key interests or knowledge areas (if mentioned in the profile). 3. How to handle questions outside its knowledge (e.g., be evasive, admit ignorance naturally). 4. Explicitly state it should *not* break character or mention being an AI. 5. Incorporate age-appropriateness if the profile suggests a specific age group. Output ONLY the system prompt itself. Do not add any explanation or introductory text."""},
+       {"role": "user", "content": f"""Create a system prompt for an AI to simulate the character '{name}'. Context for simulation: {context} Character Profile:\n{enhanced_profile}\n\nGenerate the system prompt based *only* on the profile and context provided."""}
    ]

    try:
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+           outputs = model(prompt, max_new_tokens=300, do_sample=True, temperature=0.6)
        parsed_output = parse_llm_output(outputs, prompt)
        print("System prompt generated.")
        return parsed_output if parsed_output else fallback_prompt
    except Exception as e:
        error_msg = f"Error generating system prompt: {str(e)}"
        print(error_msg)
        return fallback_prompt

+@GPU(memory=40)  # Decorator needed for Spaces resource allocation during calls
def generate_response(model, messages):
    """Generate a response using the LLM."""
    print("Generating response...")
+   if model is None: raise ValueError("Model is not loaded.")
+   if not messages: return "Error: No message history provided."

    try:
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+           # Ensure pad_token_id is set correctly if needed, especially for batching or specific models
            outputs = model(
                messages,
+               max_new_tokens=512,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                use_cache=True,
+               # Check if EOS token is needed for this model/pipeline setup
+               pad_token_id=model.tokenizer.eos_token_id if model.tokenizer.eos_token_id else None
            )
        parsed_output = parse_llm_output(outputs, messages)
        print("Response generated.")
+       return parsed_output if parsed_output else "..."
    except Exception as e:
        error_msg = f"Error during response generation: {str(e)}"
        print(error_msg)
+       # Consider if the specific error should be shown to the user
+       return f"Sorry, I encountered an error trying to respond."

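For reference, parse_llm_output above assumes the text-generation pipeline returns a list whose first element carries the full "generated_text" (prompt content plus continuation), and it recovers the reply by slicing off everything up to the last input message. A small self-contained illustration of that assumption, with the pipeline output mocked rather than produced by the real model:

# Hypothetical, self-contained illustration; the pipeline output is mocked and the values are invented.
messages = [
    {"role": "system", "content": "You are simulating the character 'Ada'."},
    {"role": "user", "content": "What do you enjoy doing?"},
]

# Shape this code expects from a text-generation pipeline call (mocked here):
mock_outputs = [{"generated_text": messages[-1]["content"] + " I love tinkering with little machines."}]

generated = mock_outputs[0]["generated_text"]
# Strip everything up to and including the last input message, as the parser does.
last = messages[-1]["content"]
reply = generated[generated.rfind(last) + len(last):].strip()
print(reply)  # -> "I love tinkering with little machines."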
# --- Persona Chat Class ---
        self.persona_context = ""
        self.messages = []
        self.enhanced_profile = ""
+       self.model_loaded = False

+   # No @GPU decorator needed here typically, as it calls functions that ARE decorated
    def load_model_if_needed(self):
        """Loads the model if it hasn't been loaded successfully."""
+       if not self.model_loaded or self.model is None:  # Check self.model too
+           print("Model not loaded or instance lost. Attempting to load...")
+           # Call the @GPU decorated load_model function
+           self.model = load_model()  # This function IS decorated
            if self.model is None:
+               # load_model now raises error, but double-check here
                raise RuntimeError("Failed to load the language model. Cannot proceed.")
            else:
                self.model_loaded = True
                print("Model loaded successfully within PersonaChat instance.")
+       # else: print("Model already loaded.") # Reduce log noise

+   # No @GPU decorator needed here typically
    def set_persona(self, name, context=""):
        """Orchestrates persona creation: search, enhance, generate prompt."""
+       # This method calls other functions that have @GPU decorators
        try:
+           self.load_model_if_needed()  # Ensures model is ready

            self.persona_name = name
            self.persona_context = context
+           self.messages = []
+           self.enhanced_profile = ""

            status = f"Searching for information about {name}..."
+           yield status, "", "", [{"role": "system", "content": "Initializing persona creation..."}]  # Added empty profile yield

            search_results = search_person(name, context)
            if isinstance(search_results, str) and search_results.startswith("Error"):
                error_msg = f"Failed to set persona: {search_results}"
+               yield error_msg, "", "", [{"role": "system", "content": error_msg}]
+               return

            bio_text = extract_text_from_search_results(search_results)
            if bio_text.startswith("Could not extract text"):
+               yield f"Warning: {bio_text}", "", "", [{"role": "system", "content": bio_text}]

            status = f"Creating enhanced profile for {name}..."
+           yield status, "", bio_text, [{"role": "system", "content": status}]  # Show basic bio while enhancing

+           # Call the @GPU decorated function
            self.enhanced_profile = generate_enhanced_persona(self.model, name, bio_text, context)
+           profile_for_prompt = self.enhanced_profile
            if self.enhanced_profile.startswith("Error enhancing profile"):
+               yield f"Warning: Could not enhance profile. Using basic info.", "", self.enhanced_profile, [{"role": "system", "content": self.enhanced_profile}]
+               profile_for_prompt = bio_text  # Fallback

            status = f"Generating optimal system prompt for {name}..."
+           # Yield the enhanced profile while generating prompt
+           yield status, self.enhanced_profile, self.enhanced_profile, [{"role": "system", "content": status}]

+           # Call the @GPU decorated function
            self.system_prompt = generate_system_prompt_with_llm(self.model, name, profile_for_prompt, context)
            self.messages = [{"role": "system", "content": self.system_prompt}]

+           yield f"Persona set to '{name}'. Ready to chat!", self.system_prompt, self.enhanced_profile, self.messages

        except RuntimeError as e:
            error_msg = f"Critical Error: {str(e)}"
            print(error_msg)
+           yield error_msg, "", "", [{"role": "system", "content": error_msg}]
        except Exception as e:
            error_msg = f"An unexpected error occurred during persona setup: {str(e)}"
            print(error_msg)
+           # Attempt to yield current state even on error
+           yield error_msg, self.system_prompt, self.enhanced_profile, [{"role": "system", "content": error_msg}]

+   # No @GPU decorator needed here typically
    def chat(self, user_message):
        """Processes a user message and returns the AI's response."""
+       # This method calls generate_response which has the @GPU decorator
        try:
+           self.load_model_if_needed()

            if not self.messages:
                print("Error: Chat called before persona was set.")
                return "Please set a persona first using the controls above."

            print(f"User message: {user_message}")
            formatted_message = {"role": "user", "content": user_message}
+           # Keep internal history, pass copy to model if needed, but pipeline usually handles state
            self.messages.append(formatted_message)

+           # Call the @GPU decorated function
            response = generate_response(self.model, self.messages)

+           # Append assistant response IF generation succeeded
+           if not response.startswith("Sorry, I encountered an error"):
+               assistant_message = {"role": "assistant", "content": response}
+               self.messages.append(assistant_message)
+               print(f"Assistant response: {response}")
+           else:
+               print(f"Assistant error response: {response}")
+               # Do not add the error message itself to the persistent history
+               # Let the UI show the error, but don't make the bot repeat it next turn.

            return response

        except RuntimeError as e:
            error_msg = f"Critical Error: {str(e)}. Cannot generate response."
            print(error_msg)
            return error_msg
        except Exception as e:
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
+           return f"Sorry, I encountered an error: {str(e)}"

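A minimal sketch of driving this class outside the Gradio UI, assuming the Spaces GPU environment is available and the model loads; the sample inputs are invented and the yield order follows the comments in set_persona_flow below (status, system prompt, enhanced profile, message list):

# Hypothetical headless usage of PersonaChat; requires the GPU/Spaces setup shown above.
chat_session = PersonaChat()

# set_persona is a generator that streams progress updates.
for status, system_prompt, profile, messages in chat_session.set_persona("Ada Lovelace", "Victorian mathematician"):
    print(status)

# Once the persona is set, chat() returns the assistant's reply as a string.
print(chat_session.chat("What are you working on these days?"))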

# --- Gradio Interface ---
def create_interface():
+   persona_chat = PersonaChat()  # Instantiate the handler class

    css = """
    .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
    .main-container { max-width: 1200px; margin: auto; padding: 0; }
    .persona-button { background-color: #4ca1af !important; color: white !important; }
    .system-prompt-display { background-color: #f5f5f5; border-radius: 8px; padding: 15px; margin-top: 15px; border: 1px solid #e0e0e0; font-family: monospace; white-space: pre-wrap; word-wrap: break-word; }
    .footer { text-align: center; margin-top: 20px; font-size: 0.9rem; color: #666; }
+   /* Use default chatbot message styling provided by type='messages' */
    .typing-indicator { color: #aaa; font-style: italic; }
    """

    with gr.Blocks(css=css, title="AI Persona Simulator") as interface:
        with gr.Row(elem_classes="main-container"):
            with gr.Column():
                with gr.Column(elem_classes="header"):
                    gr.Markdown("# AI Persona Simulator")
                    gr.Markdown("Create and interact with AI-driven character simulations")

                with gr.Column(elem_classes="setup-section"):
                    gr.Markdown("### 1. Create Your Persona")
+                   gr.Markdown("Enter a name and context. The AI will search, build a profile, and prepare for chat.")
                    with gr.Row():
+                       name_input = gr.Textbox(label="Character Name", placeholder="e.g., Sherlock Holmes, Erenalp, A curious 7th grader", elem_id="name_input")
+                       context_input = gr.Textbox(label="Character Context / Description", placeholder="e.g., Living in 221B Baker Street, London. OR 7th grade, loves math...", lines=2, elem_id="context_input")
+                   set_persona_button = gr.Button("Create Persona & Start Chat", variant="primary", elem_classes="persona-button")
+                   status_output = gr.Textbox(label="Status", value="Enter details above and click 'Create Persona'.", interactive=False, elem_classes="status-bar")
+                   with gr.Accordion("View Generated Details", open=False):
+                       enhanced_profile_display = gr.TextArea(label="Enhanced Profile (Generated by AI)", interactive=False, lines=10, elem_classes="system-prompt-display")
+                       system_prompt_display = gr.TextArea(label="System Prompt (Instructions for the AI)", interactive=False, lines=10, elem_classes="system-prompt-display")

                with gr.Column(elem_classes="chat-section"):
                    gr.Markdown("### 2. Chat with Your Character")
+                   character_name_display = gr.Markdown(value="*No persona created yet*", elem_id="character-name-display")
+                   # ***** FIX GRADIO WARNINGS *****
                    chatbot = gr.Chatbot(
                        label="Conversation",
                        height=450,
                        elem_classes="chat-container",
+                       # bubble_full_width=False, # Deprecated
+                       avatar_images=(None, "🤖"),  # User default, Bot emoji
+                       type="messages"  # ***** USE RECOMMENDED TYPE *****
                    )
                    with gr.Row():
+                       msg_input = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter...", elem_classes="message-input", scale=4)
+                       send_button = gr.Button("Send", variant="primary", elem_classes="send-button", scale=1)
+
                with gr.Column(elem_classes="footer"):
                    gr.Markdown(f"Powered by {MODEL_ID}")

        # --- Event Handlers ---
        def set_persona_flow(name, context):
            if not name:
                yield "Status: Please enter a character name.", "", "", "*No persona created yet*", []

            initial_character_display = f"### Preparing to chat with {name}..."
            initial_prompt = "System prompt will appear here..."
            initial_profile = "Enhanced profile will appear here..."
+           # Start with empty history for messages type
+           initial_history = []

            yield initial_status, initial_prompt, initial_profile, initial_character_display, initial_history

            final_status, final_prompt, final_profile = "Error", "", ""
            final_history = initial_history
            try:
+               # Use the PersonaChat instance's method generator
+               # Expected yield order: status, system_prompt, enhanced_profile, messages_list
+               for status_update, prompt_update, profile_update, history_update in persona_chat.set_persona(name, context):
+                   final_status, final_prompt, final_profile = status_update, prompt_update, profile_update
+                   if isinstance(history_update, list): final_history = history_update
+
                    character_display = f"### Preparing chat with {name}..."
+                   if "Ready to chat" in status_update:
                        character_display = f"### Chatting with {name}"
+                   elif "Error" in status_update:
                        character_display = f"### Error creating {name}"

+                   yield status_update, final_prompt, final_profile, character_display, final_history
+                   time.sleep(0.1)  # Small delay for UI update visibility

            except Exception as e:
+               error_msg = f"Failed to set persona (interface error): {str(e)}"
                print(error_msg)
+               # Try to yield error state
+               yield error_msg, final_prompt, final_profile, f"### Error creating {name}", final_history

        def send_message_flow(message, history):
+           # Ensure history is a list (for messages type)
+           if history is None: history = []
            if not message.strip():
+               return "", history

+           # Check if persona is ready (looks for system message in internal state)
            if not persona_chat.messages or persona_chat.messages[0]['role'] != 'system':
                history.append({"role": "user", "content": message})
+               history.append({"role": "assistant", "content": "Error: Please create a valid persona first."})
+               return "", history

+           # Append user message to UI history
            history.append({"role": "user", "content": message})
+           # Append placeholder for bot response (typing indicator)
+           history.append({"role": "assistant", "content": None})  # Use None for typing indicator with type='messages'

+           yield "", history  # Update UI to show user msg + typing

+           # Call chat method (uses internal state, returns string response)
+           response_text = persona_chat.chat(message)

+           # Update the placeholder in UI history with the actual response
+           history[-1]["content"] = response_text

+           yield "", history  # Update UI with final response

        set_persona_button.click(
            set_persona_flow,
            inputs=[name_input, context_input],
            outputs=[status_output, system_prompt_display, enhanced_profile_display, character_name_display, chatbot]
        )
        send_button.click(
            send_message_flow,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot]
        )
        msg_input.submit(
            send_message_flow,
            inputs=[msg_input, chatbot],

# --- Main Execution ---
if __name__ == "__main__":
+   print("Starting Gradio application for Hugging Face Spaces...")
    demo = create_interface()
+   demo.queue().launch(  # queue() is recommended for Spaces
+       server_name="0.0.0.0",
        server_port=7860,
+       # share=False is default and usually needed for Spaces deployment structure
+       show_error=True,  # Good for debugging in Spaces logs
+       debug=True  # More verbose logging
    )
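Based on the package list in the new header comments, a requirements.txt for this Space would look roughly like the following (unpinned; the spaces package itself is provided by the Spaces runtime and does not need to be listed):

gradio
transformers
torch
duckduckgo_search
huggingface_hub
accelerate
bitsandbytes
sentencepiece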