Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -1,7 +1,6 @@
-# AI Persona Simulator -
+# AI Persona Simulator - Final Optimized Version
 import gradio as gr
 import torch
-from transformers import pipeline
 from duckduckgo_search import DDGS
 import re
 import time
@@ -24,7 +23,6 @@ MAX_GPU_MEMORY = "60GiB"
 @GPU(memory=60)
 def load_model():
     """Load the Gemma 3 1B model without quantization for full precision."""
-    print(f"Attempting to load model: {MODEL_ID} without quantization")
     try:
         pipe = pipeline(
             "text-generation",
@@ -33,7 +31,7 @@ def load_model():
             device_map="auto",
             model_kwargs={"use_cache": True}
         )
-        print(f"Model {MODEL_ID} loaded successfully on device: {pipe.device}
+        print(f"Model {MODEL_ID} loaded successfully on device: {pipe.device}")
         return pipe
     except Exception as e:
         print(f"FATAL Error loading model '{MODEL_ID}': {e}")
@@ -53,8 +51,7 @@ CRITERIA:
 3. NO manipulation/exploitation attempts
 4. NO illegal/harmful scenarios
 5. NO inappropriate relationships
-Respond ONLY with "TRUE" if acceptable, "FALSE" if not."""
-        },
+Respond ONLY with "TRUE" if acceptable, "FALSE" if not."""},
         {"role": "user", "content": f"Character Name: {name}\nContext: {context}"}
     ]
 
@@ -64,6 +61,7 @@ Respond ONLY with "TRUE" if acceptable, "FALSE" if not."""
             add_generation_prompt=True,
             tokenize=False
         )
+
         with torch.amp.autocast('cuda', dtype=torch.bfloat16):
             outputs = pipe(
                 text,
@@ -72,6 +70,7 @@ Respond ONLY with "TRUE" if acceptable, "FALSE" if not."""
                 temperature=0.1,
                 pad_token_id=pipe.tokenizer.eos_token_id
             )
+
         result = parse_llm_output(outputs, validation_prompt).strip().upper()
         return result == "TRUE"
     except Exception as e:
@@ -118,54 +117,7 @@ def search_person(name, context=""):
     print(f"Found {len(results)} potential search results.")
     return results
 
-
-    """Create a synthetic profile when search returns no results."""
-    profile = {
-        "title": f"Synthetic Profile for {name}",
-        "href": "",
-        "body": f"{name} is a person described with the context: '{context}'. "
-    }
-    if "grade" in context.lower():
-        grade_match = re.search(r'(\d+)(?:st|nd|rd|th)?\s+grade', context.lower())
-        if grade_match:
-            try:
-                grade = int(grade_match.group(1))
-                age = 5 + grade
-                profile["body"] += f"Based on being in {grade}th grade, {name} is likely around {age} years old. "
-                profile["body"] += f"Typical interests for this age might include friends, hobbies, school subjects, and developing independence. "
-            except ValueError:
-                profile["body"] += f"The grade mentioned ('{grade_match.group(1)}') could not be parsed to estimate age. "
-    profile["body"] += "Since no public information was found, this profile is based solely on the provided context."
-    return [profile]
-
-def extract_text_from_search_results(search_results):
-    """Extract relevant text from search results."""
-    if isinstance(search_results, str):
-        return f"Could not extract text due to search error: {search_results}"
-
-    combined_text = ""
-    seen_bodies = set()
-    count = 0
-    max_results_to_process = 5
-
-    for result in search_results:
-        if count >= max_results_to_process:
-            break
-        if isinstance(result, dict) and 'body' in result and result['body']:
-            body = result['body'].strip()
-            if body not in seen_bodies:
-                combined_text += body + "\n"
-                seen_bodies.add(body)
-                count += 1
-
-    if not combined_text:
-        return "No relevant text found in search results."
-
-    combined_text = re.sub(r'\s+', ' ', combined_text).strip()
-    max_length = 2000
-    return combined_text[:max_length] + "..." if len(combined_text) > max_length else combined_text
-
-# --- LLM Generation Functions ---
+# --- Text Processing Functions ---
 def parse_llm_output(full_output, input_prompt_list):
     """Attempts to parse only the newly generated text from the LLM output."""
     if isinstance(full_output, list) and len(full_output) > 0:
@@ -207,11 +159,9 @@ def parse_llm_output(full_output, input_prompt_list):
             print("Warning: Parsing resulted in empty string, returning original generation.")
             return re.sub(r'<end_of_turn>|<start_of_turn>model', '', generated_text).strip()
 
-    if last_input_content and last_occurrence_index == -1:
-        print("Warning: Could not find last input prompt in LLM output. Returning cleaned full output.")
-
     return cleaned_text
 
+# --- LLM Generation Functions ---
 @GPU(memory=60)
 def generate_enhanced_persona(name, bio_text, context=""):
     """Use the LLM to enhance the persona profile."""
@@ -263,7 +213,12 @@ Additional context for the simulation: {context}
 Maintain this persona consistently. Respond naturally based on the profile. Do not mention that you are an AI or a simulation. If asked about details not in the profile, you can be evasive or state you don't know/remember, consistent with the persona."""
 
     prompt = [
-        {"role": "system", "content": """You are an expert AI prompt engineer specializing in character simulation. Create a concise system prompt that instructs the LLM to embody the character based on the profile. The prompt must:
+        {"role": "system", "content": """You are an expert AI prompt engineer specializing in character simulation. Create a concise system prompt that instructs the LLM to embody the character based on the profile. The prompt must:
+1. Define core personality and speaking style
+2. Specify how to handle unknown topics
+3. Prohibit breaking character or mentioning AI nature
+Output ONLY the system prompt itself."""
+        },
         {"role": "user", "content": f"""Create a system prompt for an AI to simulate the character '{name}'. Context for simulation: {context} Character Profile:
 {enhanced_profile}
 Generate the system prompt based *only* on the profile and context provided."""}
@@ -300,6 +255,7 @@ def generate_response(messages):
     print("Generating response...")
     if not messages:
         return "Error: No message history provided."
+
     try:
         tokenizer = pipe.tokenizer
         text = tokenizer.apply_chat_template(
@@ -323,7 +279,7 @@ def generate_response(messages):
     except Exception as e:
         error_msg = f"Error during response generation: {str(e)}"
         print(error_msg)
-        return f"Sorry, I encountered an error
+        return f"Sorry, I encountered an error: {str(e)}"
 
 # --- Persona Chat Class with Safety ---
 class PersonaChat:
@@ -453,15 +409,6 @@ def create_interface():
     padding: 20px;
     box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05);
 }
-.status-bar {
-    background: #f1f3f5;
-    padding: 12px 15px;
-    border-radius: 5px;
-    margin: 15px 0;
-    font-weight: 500;
-    border: 1px solid #e2e6ea;
-    color: #212529;
-}
 .chat-container {
     border: 1px solid #eaeaea;
     border-radius: 10px;
@@ -473,30 +420,13 @@ def create_interface():
 .message-input {
     margin-top: 10px;
 }
-.send-button {
+.send-button, .persona-button {
     background-color: #1e3c72 !important;
     color: white !important;
     border-radius: 8px;
    padding: 10px 20px;
     font-weight: bold;
 }
-.persona-button {
-    background-color: #2a5298 !important;
-    color: white !important;
-    border-radius: 8px;
-    padding: 10px 20px;
-    font-weight: bold;
-}
-.system-prompt-display {
-    background-color: #f5f5f5;
-    border-radius: 8px;
-    padding: 15px;
-    margin-top: 15px;
-    border: 1px solid #e0e0e0;
-    font-family: monospace;
-    white-space: pre-wrap;
-    word-wrap: break-word;
-}
 .footer {
     text-align: center;
     margin-top: 30px;
@@ -505,12 +435,6 @@ def create_interface():
     padding: 15px;
     border-top: 1px solid #eee;
 }
-.typing-indicator {
-    color: #aaa;
-    font-style: italic;
-}
-
-/* Mobile styles */
 @media (max-width: 768px) {
     .chat-container { height: 300px !important; }
     .main-container { padding: 10px; }
@@ -539,10 +463,16 @@ def create_interface():
             with gr.Column():
                 gr.Markdown("### Chat with Character")
                 character_name_display = gr.Markdown("*No persona created yet*")
-                chatbot = gr.Chatbot(
+                chatbot = gr.Chatbot(
+                    height=450,
+                    show_label=False,
+                    bubble_full_width=False,
+                    type="messages",
+                    avatar_images=("https://api.dicebear.com/6.x/bottts/svg?seed=user",
+                                   "https://api.dicebear.com/6.x/bottts/svg?seed=bot")
+                )
                 msg_input = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter...")
                 send_button = gr.Button("Send Message")
-
         gr.Markdown("Powered by Gemma 3 1B • Ethically Designed • Safe & Secure")
 
     def set_persona_flow(name, context):
@@ -552,11 +482,9 @@ def create_interface():
 
         initial_status = f"Creating persona for '{name}'..."
         initial_character_display = f"### Preparing to chat with {name}..."
-        initial_prompt = "System prompt will appear here..."
-        initial_profile = "Enhanced profile will appear here..."
         initial_history = []
 
-        yield initial_status,
+        yield initial_status, "", "", initial_character_display, initial_history
 
        try:
            for status_update, prompt_update, profile_update, history_update in persona_chat.set_persona(name, context):
@@ -565,30 +493,33 @@ def create_interface():
                    if i+1 < len(history_update):
                        user_msg = history_update[i].get("content", "")
                        bot_msg = history_update[i+1].get("content", "")
-                        gradio_history.append(
+                        gradio_history.append({"role": "user", "content": user_msg})
+                        gradio_history.append({"role": "assistant", "content": bot_msg})
+
                character_display = f"### Preparing chat with {name}..."
                if "Ready to chat" in status_update:
                    character_display = f"### Chatting with {name}"
                elif "Error" in status_update:
                    character_display = f"### Error creating {name}"
+
                yield status_update, prompt_update, profile_update, character_display, gradio_history
                time.sleep(0.1)
        except Exception as e:
            error_msg = f"Failed to set persona (interface error): {str(e)}"
-            print(
-            yield error_msg,
+            print(error_msg)
+            yield error_msg, "", "", f"### Error creating {name}", []
 
    def send_message_flow(message, history):
        if not message.strip():
            return "", history
+
        if not persona_chat.messages or persona_chat.messages[0]['role'] != 'system':
+            history.append({"role": "assistant", "content": "Error: Please create a valid persona first."})
            return "", history
 
-        history.append(
-
-
-        history[-1][1] = response_text
+        history.append({"role": "user", "content": message})
+        response = persona_chat.chat(message)
+        history.append({"role": "assistant", "content": response})
        return "", history
 
    set_persona_button.click(
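
Note on the chat wiring above: the commit moves the Chatbot to Gradio's "messages" history format (gr.Chatbot(type="messages")), where history is a list of {"role": ..., "content": ...} dicts rather than the older [user, bot] pairs (hence the removal of the tuple-style history[-1][1] = response_text). Below is a minimal, self-contained sketch of that pattern; the respond/demo names are illustrative only, and it assumes a Gradio version that supports type="messages" (4.44+).

# Illustrative sketch, not part of the commit: a minimal Blocks app using the
# same "messages" history format as the updated send_message_flow above.
import gradio as gr

def respond(message, history):
    # history is a list of {"role", "content"} dicts, not [user, bot] pairs
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": f"Echo: {message}"})
    return "", history  # clear the textbox, push the updated history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages", height=450)
    msg = gr.Textbox(label="Your message")
    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

if __name__ == "__main__":
    demo.launch()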