# AI Persona Simulator - Final Optimized Version
import gradio as gr
import torch
from duckduckgo_search import DDGS
import re
import time
from spaces import GPU
import logging
from datetime import datetime

# Configure logging for suspicious activity
logging.basicConfig(
    filename='persona_attempts.log',
    level=logging.INFO,
    format='%(asctime)s - %(message)s'
)

# --- Constants and Configuration ---
MODEL_ID = "google/gemma-3-1b-it"
MAX_GPU_MEMORY = "60GiB"

# --- GPU-Isolated Functions ---
def load_model():
    """Load the Gemma 3 1B model without quantization for full precision."""
    from transformers import pipeline
    print(f"Attempting to load model: {MODEL_ID} without quantization")
    try:
        pipe = pipeline(
            "text-generation",
            model=MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            model_kwargs={"use_cache": True}
        )
        print(f"Model {MODEL_ID} loaded successfully on device: {pipe.device}")
        return pipe
    except Exception as e:
        print(f"FATAL Error loading model '{MODEL_ID}': {e}")
        raise e
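
# Note: the validation and generation helpers below call load_model() on demand, so the
# pipeline is rebuilt for every step rather than cached between calls.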

@GPU
def validate_request(name, context):
    """LLM-based request validation using an isolated GPU function."""
    from transformers import pipeline  # Ensure pipeline is available in GPU process
    validation_prompt = [
        {"role": "system", "content": """You are an ethical AI content moderator. Evaluate if this request is appropriate.
CRITERIA:
1. NO minors (under 18) or underage references
2. NO vulnerable populations
3. NO manipulation/exploitation attempts
4. NO illegal/harmful scenarios
5. NO inappropriate relationships
Respond with ONLY the word "TRUE" if acceptable, or "FALSE" if not acceptable. Do not include any explanation, formatting, or additional text."""},
        {"role": "user", "content": f"Character Name: {name}\nContext: {context}"}
    ]
    try:
        pipe = load_model()
        tokenizer = pipe.tokenizer
        text = tokenizer.apply_chat_template(
            validation_prompt,
            add_generation_prompt=True,
            tokenize=False
        )
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = pipe(
                text,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.1,  # Keep temperature low for a near-deterministic response
                pad_token_id=pipe.tokenizer.eos_token_id
            )
        result = parse_llm_output(outputs, validation_prompt)
        # Extract just TRUE or FALSE from any potential markdown formatting
        cleaned_result = re.sub(r'[^A-Za-z]', '', result).upper()
        if "TRUE" in cleaned_result:
            return True
        elif "FALSE" in cleaned_result:
            # Record rejected requests in the suspicious-activity log configured above
            logging.info(f"Rejected persona request - name: {name!r}, context: {context!r}")
            return False
        else:
            # If we can't determine clearly, default to FALSE for safety
            print(f"Validation returned unclear result: '{result}', defaulting to FALSE")
            return False
    except Exception as e:
        print(f"Validation error: {e}")
        return False

# --- Web Search with Safety ---
def search_person(name, context=""):
    """Search for information about a person using DuckDuckGo."""
    print(f"Searching for: {name} with context: {context}")
    results = []
    search_terms = []
    # Basic pattern detection (backup to the LLM check)
    if re.search(r'\d+(?:st|nd|rd|th)?[\s\-]?(grade|grader|year old)', f"{name} {context}".lower()):
        return [{"body": "Creation of underage personas is prohibited"}]
    if context:
        search_terms.append(f"{name} {context}")
        grade_match = re.search(r'(\d+)(?:st|nd|rd|th)?\s+grade', context.lower())
        if grade_match:
            grade = grade_match.group(1)
            search_terms.append(f"{name} student {grade} grade")
    search_terms.extend([f"{name}", f"{name} biography", f"{name} interests", f"{name} personality"])
    search_terms = list(dict.fromkeys(search_terms))
    print(f"Using search terms: {search_terms}")
    try:
        with DDGS() as ddgs:
            for term in search_terms:
                print(f"Searching DDG for: '{term}'")
                search_results = list(ddgs.text(term, max_results=2))
                results.extend(search_results)
                time.sleep(0.2)
    except Exception as e:
        error_msg = f"Error during DuckDuckGo search: {str(e)}"
        print(error_msg)
        return error_msg
    if not results:
        print(f"No search results found for {name}. Creating synthetic profile.")
        return create_synthetic_profile(name, context)
    print(f"Found {len(results)} potential search results.")
    return results

def create_synthetic_profile(name, context):
    """Create a synthetic profile when search returns no results."""
    profile = {
        "title": f"Synthetic Profile for {name}",
        "href": "",
        "body": f"{name} is a person described with the context: '{context}'. "
    }
    if "grade" in context.lower():
        grade_match = re.search(r'(\d+)(?:st|nd|rd|th)?\s+grade', context.lower())
        if grade_match:
            try:
                grade = int(grade_match.group(1))
                age = 5 + grade
                profile["body"] += f"Based on being in grade {grade}, {name} is likely around {age} years old. "
                profile["body"] += "Typical interests for this age might include friends, hobbies, school subjects, and developing independence. "
            except ValueError:
                profile["body"] += f"The grade mentioned ('{grade_match.group(1)}') could not be parsed to estimate age. "
    profile["body"] += "Since no public information was found, this profile is based solely on the provided context."
    return [profile]
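
# Example (illustrative inputs): create_synthetic_profile("Ada", "retired engineer who loves chess")
# returns a one-element list whose 'body' restates the provided context, since there is no
# public information to draw on.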

def extract_text_from_search_results(search_results):
    """Extract relevant text from search results."""
    if isinstance(search_results, str):
        return f"Could not extract text due to search error: {search_results}"
    combined_text = ""
    seen_bodies = set()
    count = 0
    max_results_to_process = 5
    for result in search_results:
        if count >= max_results_to_process:
            break
        if isinstance(result, dict) and 'body' in result and result['body']:
            body = result['body'].strip()
            if body not in seen_bodies:
                combined_text += body + "\n"
                seen_bodies.add(body)
                count += 1
    if not combined_text:
        return "No relevant text found in search results."
    combined_text = re.sub(r'\s+', ' ', combined_text).strip()
    max_length = 2000
    return combined_text[:max_length] + "..." if len(combined_text) > max_length else combined_text

# --- Text Processing Functions ---
def parse_llm_output(full_output, input_prompt_list):
    """Attempts to parse only the newly generated text from the LLM output."""
    if isinstance(full_output, list) and len(full_output) > 0:
        if isinstance(full_output[0], dict) and "generated_text" in full_output[0]:
            generated_text = full_output[0]["generated_text"]
        else:
            return str(full_output)
    elif isinstance(full_output, str):
        generated_text = full_output
    else:
        return str(full_output)
    last_input_content = ""
    if isinstance(input_prompt_list, list) and input_prompt_list:
        last_input_content = input_prompt_list[-1].get("content", "")
    if last_input_content:
        last_occurrence_index = generated_text.rfind(last_input_content)
        if last_occurrence_index != -1:
            potential_response = generated_text[last_occurrence_index + len(last_input_content):].strip()
            if potential_response:
                potential_response = re.sub(r'^<\/?s?>', '', potential_response).strip()
                potential_response = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', potential_response).strip()
                potential_response = re.sub(r'<end_of_turn>|<start_of_turn>model', '', potential_response).strip()
                if potential_response:
                    return potential_response
    cleaned_text = generated_text
    cleaned_text = re.sub(r'^<\/?s?>', '', cleaned_text).strip()
    cleaned_text = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', cleaned_text).strip()
    cleaned_text = re.sub(r'<end_of_turn>|<start_of_turn>model', '', cleaned_text).strip()
    if not cleaned_text and generated_text:
        print("Warning: Parsing resulted in empty string, returning original generation.")
        return re.sub(r'<end_of_turn>|<start_of_turn>model', '', generated_text).strip()
    return cleaned_text
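
# Example (illustrative): given a pipeline output such as
#   [{"generated_text": "<prompt echo><start_of_turn>model Elementary, my dear Watson."}]
# and the original prompt list, parse_llm_output should strip the prompt echo and Gemma's
# turn markers and return "Elementary, my dear Watson.".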

# --- LLM Generation Functions ---
@GPU
def generate_enhanced_persona(name, bio_text, context=""):
    """Use the LLM to enhance the persona profile."""
    from transformers import pipeline
    pipe = load_model()
    print(f"Generating enhanced persona for {name}...")
    enhancement_prompt = [
        {"role": "system", "content": """You are an expert AI character developer. Your task is to synthesize information into a detailed and coherent character profile. Focus on personality, potential interests, speaking style, and mannerisms based ONLY on the provided text. Output ONLY the enhanced character profile description. Do not include conversational introductions, explanations, or markdown formatting like headers."""},
        {"role": "user", "content": f"""Synthesize the following information about '{name}' into a character profile. Context: {context} Information Found:
{bio_text}
Create the profile based *only* on the text above."""}
    ]
    try:
        tokenizer = pipe.tokenizer
        text = tokenizer.apply_chat_template(
            enhancement_prompt,
            add_generation_prompt=True,
            tokenize=False
        )
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = pipe(
                text,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.8,
                pad_token_id=pipe.tokenizer.eos_token_id
            )
        parsed_output = parse_llm_output(outputs, enhancement_prompt)
        print("Enhanced persona generated.")
        return parsed_output if parsed_output else f"Could not generate profile based on:\n{bio_text}"
    except Exception as e:
        error_msg = f"Error generating enhanced persona: {str(e)}"
        print(error_msg)
        return f"Error enhancing profile: {str(e)}\nUsing basic info:\n{bio_text}"

@GPU
def generate_system_prompt_with_llm(name, enhanced_profile, context=""):
    """Generate an optimized system prompt for the persona."""
    from transformers import pipeline
    pipe = load_model()
    print(f"Generating system prompt for {name}...")
    fallback_prompt = f"""You are simulating the character '{name}'. Act and respond according to this profile:
{enhanced_profile}
Additional context for the simulation: {context}
---
Maintain this persona consistently. Respond naturally based on the profile. Do not mention that you are an AI or a simulation. If asked about details not in the profile, you can be evasive or state you don't know/remember, consistent with the persona."""
    prompt = [
        {"role": "system", "content": """You are an expert AI prompt engineer specializing in character simulation. Create a concise system prompt that instructs the LLM to embody the character based on the profile. The prompt must: 1. Define core personality and speaking style. 2. Specify how to handle unknown topics. 3. Prohibit breaking character or mentioning AI nature. Output ONLY the system prompt itself."""},
        {"role": "user", "content": f"""Create a system prompt for an AI to simulate the character '{name}'. Context for simulation: {context} Character Profile:
{enhanced_profile}
Generate the system prompt based *only* on the profile and context provided."""}
    ]
    try:
        tokenizer = pipe.tokenizer
        text = tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=False
        )
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = pipe(
                text,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.7,
                top_p=0.8,
                pad_token_id=pipe.tokenizer.eos_token_id
            )
        parsed_output = parse_llm_output(outputs, prompt)
        print("System prompt generated.")
        return parsed_output if parsed_output else fallback_prompt
    except Exception as e:
        error_msg = f"Error generating system prompt: {str(e)}"
        print(error_msg)
        return fallback_prompt

@GPU
def generate_response(messages):
    """Generate a response using the LLM."""
    from transformers import pipeline
    pipe = load_model()
    print("Generating response...")
    if not messages:
        return "Error: No message history provided."
    try:
        tokenizer = pipe.tokenizer
        text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False
        )
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = pipe(
                text,
                max_new_tokens=512,
                do_sample=True,
                top_p=0.8,
                temperature=0.7,
                pad_token_id=pipe.tokenizer.eos_token_id
            )
        parsed_output = parse_llm_output(outputs, messages)
        parsed_output = re.sub(r'<end_of_turn>|<start_of_turn>model', '', parsed_output).strip()
        print("Response generated.")
        return parsed_output if parsed_output else "..."
    except Exception as e:
        error_msg = f"Error during response generation: {str(e)}"
        print(error_msg)
        return "Sorry, I encountered an error trying to respond."

# --- Persona Chat Class with Safety ---
class PersonaChat:
    def __init__(self):
        self.system_prompt = "You are a helpful assistant."
        self.persona_name = "Assistant"
        self.persona_context = ""
        self.messages = []
        self.enhanced_profile = ""

    def set_persona(self, name, context=""):
        """Orchestrates persona creation: validation, search, enhance, generate prompt."""
        try:
            # First validate the request with the LLM
            is_valid = validate_request(name, context)
            if not is_valid:
                warning = "This request has been flagged as inappropriate. We cannot create personas that involve minors, vulnerable individuals, or potentially harmful scenarios."
                yield warning, "", "", [{"role": "system", "content": warning}]
                return
            self.persona_name = name
            self.persona_context = context
            self.messages = []
            self.enhanced_profile = ""
            status = f"Searching for information about {name}..."
            print(f"set_persona: Yielding search status: {status}")
            yield status, "", "", []
            search_results = search_person(name, context)
            if isinstance(search_results, str) and search_results.startswith("Error"):
                error_msg = f"Failed to set persona: {search_results}"
                print(f"set_persona: Yielding error: {error_msg}")
                yield error_msg, "", "", [{"role": "system", "content": error_msg}]
                return
            bio_text = extract_text_from_search_results(search_results)
            if bio_text.startswith("Could not extract text"):
                print(f"set_persona: Yielding bio warning: {bio_text}")
                yield f"Warning: {bio_text}", "", "", [{"role": "system", "content": bio_text}]
            status = f"Creating enhanced profile for {name}..."
            print(f"set_persona: Yielding profile status: {status}")
            yield status, "", bio_text, []
            self.enhanced_profile = generate_enhanced_persona(name, bio_text, context)
            profile_for_prompt = self.enhanced_profile
            if self.enhanced_profile.startswith("Error enhancing profile"):
                print(f"set_persona: Yielding profile warning: {self.enhanced_profile}")
                yield "Warning: Could not enhance profile. Using basic info.", "", self.enhanced_profile, [{"role": "system", "content": self.enhanced_profile}]
                profile_for_prompt = bio_text
            status = f"Generating optimal system prompt for {name}..."
            print(f"set_persona: Yielding prompt status: {status}")
            yield status, self.enhanced_profile, self.enhanced_profile, []
            self.system_prompt = generate_system_prompt_with_llm(name, profile_for_prompt, context)
            self.system_prompt = re.sub(r'<\|im_tailored\|>|<\|im_start\|>|^assistant\s*', '', self.system_prompt).strip()
            self.messages = [{"role": "system", "content": self.system_prompt}]
            print(f"set_persona: Final yield with messages (not sent to Chatbot): {self.messages}")
            yield f"Persona set to '{name}'. Ready to chat!", self.system_prompt, self.enhanced_profile, []
        except Exception as e:
            error_msg = f"An unexpected error occurred during persona setup: {str(e)}"
            print(f"set_persona: Yielding exception: {error_msg}")
            yield error_msg, self.system_prompt, self.enhanced_profile, [{"role": "system", "content": error_msg}]

    def chat(self, user_message):
        """Processes a user message and returns the AI's response."""
        try:
            if not self.messages:
                print("Error: Chat called before persona was set.")
                return "Please set a persona first using the controls above."
            print(f"User message: {user_message}")
            self.messages.append({"role": "user", "content": user_message})
            response = generate_response(self.messages)
            if not response.startswith("Sorry, I encountered an error"):
                self.messages.append({"role": "assistant", "content": response})
                print(f"Assistant response: {response}")
            else:
                print(f"Assistant error response: {response}")
            return response
        except Exception as e:
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return f"Sorry, I encountered an error: {str(e)}"

# --- Gradio Interface with Enhanced UI ---
def create_interface():
    persona_chat = PersonaChat()
    # Mobile-optimized CSS with modern styling
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        background: linear-gradient(to right, #f8f9fa, #e9ecef);
    }
    .main-container {
        max-width: 1200px;
        margin: auto;
        padding: 20px;
        background: white;
        border-radius: 15px;
        box-shadow: 0 8px 24px rgba(0,0,0,0.08);
    }
    .header {
        background: linear-gradient(135deg, #1e3c72, #2a5298);
        color: white;
        padding: 25px;
        border-radius: 10px 10px 0 0;
        margin-bottom: 25px;
        text-align: center;
    }
    .setup-section {
        background-color: #f8f9fa;
        border-radius: 10px;
        padding: 20px;
        box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05);
        margin-bottom: 25px;
    }
    .chat-section {
        background-color: #ffffff;
        border-radius: 10px;
        padding: 20px;
        box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05);
    }
    .chat-container {
        border: 1px solid #eaeaea;
        border-radius: 10px;
        height: 450px !important;
        overflow-y: auto;
        background-color: #ffffff;
        padding: 10px;
    }
    .message-input {
        margin-top: 10px;
    }
    .send-button, .persona-button {
        background-color: #1e3c72 !important;
        color: white !important;
        border-radius: 8px;
        padding: 10px 20px;
        font-weight: bold;
    }
    .footer {
        text-align: center;
        margin-top: 30px;
        font-size: 0.9em;
        color: #666;
        padding: 15px;
        border-top: 1px solid #eee;
    }
    @media (max-width: 768px) {
        .chat-container { height: 300px !important; }
        .main-container { padding: 10px; }
    }
    """
    with gr.Blocks(css=css, title="AI Persona Simulator") as interface:
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    gr.Markdown("# 🤖 AI Persona Simulator")
                    gr.Markdown("Create and interact with ethical character simulations using advanced AI")
                with gr.Column():
                    gr.Markdown("### Create Your Persona")
                    gr.Markdown("Enter a name and context for your character")
                    name_input = gr.Textbox(label="Character Name", placeholder="e.g., Sherlock Holmes, Historical Figure")
                    context_input = gr.Textbox(label="Character Context", lines=2, placeholder="e.g., Victorian detective living in London, OR Tech entrepreneur focused on AI ethics")
                    set_persona_button = gr.Button("Create Persona & Start Chat", variant="primary")
                    status_output = gr.Textbox(label="Status", interactive=False)
                    with gr.Accordion("View Generated Details", open=False):
                        enhanced_profile_display = gr.TextArea(label="Enhanced Profile", lines=10)
                        system_prompt_display = gr.TextArea(label="System Prompt", lines=10)
            with gr.Column():
                gr.Markdown("### Chat with Character")
                character_name_display = gr.Markdown("*No persona created yet*")
                chatbot = gr.Chatbot(
                    height=450,
                    show_label=False,
                    type="messages",
                    avatar_images=("https://api.dicebear.com/6.x/bottts/svg?seed=user",
                                   "https://api.dicebear.com/6.x/bottts/svg?seed=bot")
                )
                msg_input = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter...")
                send_button = gr.Button("Send Message")
        gr.Markdown("Powered by Gemma 3 1B • Ethically Designed • Safe & Secure")

        def set_persona_flow(name, context):
            if not name:
                yield "Status: Please enter a character name.", "", "", "*No persona created yet*", []
                return
            initial_status = f"Creating persona for '{name}'..."
            initial_character_display = f"### Preparing to chat with {name}..."
            initial_history = []
            yield initial_status, "", "", initial_character_display, initial_history
            try:
                for status_update, prompt_update, profile_update, history_update in persona_chat.set_persona(name, context):
                    gradio_history = []
                    for i in range(0, len(history_update), 2):
                        if i + 1 < len(history_update):
                            user_msg = history_update[i].get("content", "")
                            bot_msg = history_update[i + 1].get("content", "")
                            gradio_history.append({"role": "user", "content": user_msg})
                            gradio_history.append({"role": "assistant", "content": bot_msg})
                    character_display = f"### Preparing chat with {name}..."
                    if "Ready to chat" in status_update:
                        character_display = f"### Chatting with {name}"
                    elif "Error" in status_update:
                        character_display = f"### Error creating {name}"
                    yield status_update, prompt_update, profile_update, character_display, gradio_history
                    time.sleep(0.1)
            except Exception as e:
                error_msg = f"Failed to set persona (interface error): {str(e)}"
                print(f"set_persona_flow: Exception: {error_msg}")
                yield error_msg, "", "", f"### Error creating {name}", []

        def send_message_flow(message, history):
            if not message.strip():
                return "", history
            if not persona_chat.messages or persona_chat.messages[0]['role'] != 'system':
                history.append({"role": "assistant", "content": "Error: Please create a valid persona first."})
                return "", history
            history.append({"role": "user", "content": message})
            response = persona_chat.chat(message)
            history.append({"role": "assistant", "content": response})
            return "", history

        set_persona_button.click(
            set_persona_flow,
            inputs=[name_input, context_input],
            outputs=[status_output, system_prompt_display, enhanced_profile_display, character_name_display, chatbot]
        )
        send_button.click(
            send_message_flow,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot]
        )
        msg_input.submit(
            send_message_flow,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot]
        )
    return interface

# --- Main Execution ---
if __name__ == "__main__":
    print("Starting secure AI Persona Simulator with LLM-based request validation...")
    demo = create_interface()
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        debug=True,
        ssr_mode=False
    )
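
# --- Illustrative usage sketch (not executed; assumes this file is saved as app.py and that
# the model weights plus a GPU runtime are available) ---
# chat = PersonaChat()
# for status, system_prompt, profile, history in chat.set_persona("Sherlock Holmes", "Victorian detective living in London"):
#     print(status)
# print(chat.chat("What do you deduce about me?"))
#
# To launch the Gradio app locally instead: `python app.py`, then open http://localhost:7860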