ErenalpCet commited on
Commit
be47882
·
verified ·
1 Parent(s): 60d5b7c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +557 -0
app.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import transformers
3
+ import torch
4
+ from transformers import pipeline
5
+ from duckduckgo_search import DDGS
6
+ import re
7
+ import time
8
+
9
def search_person(name, context=""):
    """Look up information about *name* on DuckDuckGo, guided by *context*.

    Returns a list of result dicts on success, a synthetic profile list when
    nothing is found but context is available, or an error string when the
    search itself fails.
    """
    queries = []

    # If the context mentions a school grade, add a targeted student query.
    if "grade" in context.lower():
        grade_match = re.search(r'(\d+)(st|nd|rd|th)?\s+grade', context.lower())
        if grade_match:
            queries.append(f"{name} student {grade_match.group(1)} grade")

    # Generic queries about the person.
    queries.extend([
        f"{name} {context}" if context else name,
        f"{name} interests",
        f"{name} personality",
    ])

    hits = []
    try:
        with DDGS() as ddgs:
            for query in queries:
                hits.extend(list(ddgs.text(query, max_results=3)))
    except Exception as exc:
        return f"Error during search: {str(exc)}"

    # Nothing found but we do have context: fabricate a minimal profile.
    if not hits and context:
        return create_synthetic_profile(name, context)

    return hits
41
+
42
def _ordinal(number):
    """Return *number* with its English ordinal suffix (1 -> '1st', 12 -> '12th')."""
    number = int(number)
    if 10 <= number % 100 <= 20:
        suffix = "th"  # 11th, 12th, 13th are irregular
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(number % 10, "th")
    return f"{number}{suffix}"


def create_synthetic_profile(name, context):
    """Create a synthetic profile when search returns no results.

    Returns a one-element list holding a dict with a 'body' key, mirroring the
    shape of DuckDuckGo search results so downstream parsing works unchanged.
    """
    profile = {
        "body": f"{name} is a person described as: {context}."
    }

    # Extract age/grade information from the free-text context.
    if "grade" in context.lower():
        grade_match = re.search(r'(\d+)(st|nd|rd|th)?\s+grade', context.lower())
        if grade_match:
            grade = grade_match.group(1)
            age = 5 + int(grade)  # Rough US convention: 1st grade is about age 6.
            # Bug fix: the original hard-coded "th", producing "1th"/"2th"/"3th".
            ordinal_grade = _ordinal(grade)
            profile["body"] += f" {name} is approximately {age} years old and in {ordinal_grade} grade."
            profile["body"] += f" Like most {ordinal_grade} graders, {name} is likely interested in friends, learning new things, and developing their own identity."

    return [profile]
58
+
59
def extract_text_from_search_results(search_results):
    """Concatenate the 'body' fields of search-result dicts into one string.

    Entries that are not dicts or lack a 'body' key are skipped; all runs of
    whitespace (including the separators) are collapsed to single spaces.
    """
    bodies = [
        item["body"] + "\n\n"
        for item in search_results
        if isinstance(item, dict) and "body" in item
    ]
    return re.sub(r'\s+', ' ', "".join(bodies))
69
+
70
def load_model():
    """Build the text-generation pipeline (weights are downloaded on first call).

    NOTE(review): bfloat16 + device_map="auto" presumes an accelerator-capable
    host — confirm the deployment target supports this.
    """
    return pipeline(
        "text-generation",
        model="nvidia/Llama-3.1-Nemotron-8B-UltraLong-4M-Instruct",
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
80
+
81
def generate_enhanced_persona(model, name, bio_text, context=""):
    """Ask the LLM to expand sparse search text into a rich character profile.

    Returns the model's reply text when it can be parsed out of the pipeline
    output; falls back to *bio_text* unchanged when parsing fails or the model
    call raises.
    """
    enhancement_prompt = [
        {"role": "system", "content": """You are an expert AI character developer.
Your task is to create a detailed character profile based on limited information.
Output ONLY the enhanced profile with no additional explanations or formatting."""},
        {"role": "user", "content": f"""Here's some information I found about {name}:

{bio_text}

Additional context: {context}

Based on this information, create a detailed, rich character profile for {name}.
Include personality traits, speaking style, interests, background details, quirks, and mannerisms.
If this is a child in school, include age-appropriate details about school life, friends, family dynamics, and interests.
Be creative but make the profile coherent with the known facts.
Focus on what makes {name} unique and authentic.
Structure your response as a character profile that could be used by an actor playing this role.

DO NOT prefix your response with anything. Start directly with the character profile.
DO NOT include any disclaimers, explanations, or notes in your response.
DO NOT use bullet points or section titles."""}
    ]

    try:
        outputs = model(enhancement_prompt, max_new_tokens=1024)

        # Chat pipelines return the whole conversation; the reply is the last turn.
        if isinstance(outputs, list) and outputs:
            head = outputs[0]
            if isinstance(head, dict) and "generated_text" in head:
                generated = head["generated_text"]
                if isinstance(generated, str):
                    return generated
                if isinstance(generated, list) and generated:
                    final = generated[-1]
                    if isinstance(final, dict) and "content" in final:
                        return final["content"]

        # Fallback for plain-string outputs.
        if isinstance(outputs, str):
            return outputs
        return bio_text  # Fall back to the raw bio when parsing fails.
    except Exception as exc:
        print(f"Error generating enhanced persona: {str(exc)}")
        return bio_text
124
+
125
def generate_system_prompt_with_llm(model, name, enhanced_profile, context=""):
    """Ask the LLM to write an optimized role-play system prompt for *name*.

    Returns the generated prompt when it can be parsed from the model output;
    otherwise (parse failure or exception) returns a basic hand-built prompt
    that embeds the profile and context directly.
    """
    # Basic prompt used whenever the LLM path fails; built once up front.
    fallback_prompt = f"""You are now simulating {name}. Use the following information to respond as if you were {name}:

{enhanced_profile}

{context}

Always stay in character as {name} and respond directly as {name} would respond."""

    prompt_generation_message = [
        {"role": "system", "content": """You are an expert AI prompt engineer.
Your task is to create an optimal system prompt that will make an LLM simulate a specific person accurately.
Output ONLY the system prompt with no additional explanations."""},
        {"role": "user", "content": f"""Here is a detailed profile for {name}:

{enhanced_profile}

Additional context: {context}

Create a comprehensive and effective system prompt that would make an LLM perfectly simulate {name}.
The prompt should:
1. Include key personality traits and communication style
2. Specify how to handle questions about unknown topics
3. Provide guidance on maintaining consistent character behavior
4. Include instructions for age-appropriate responses if this is a child
5. Give specific examples of phrases or expressions this person might use

Format the system prompt for direct use - don't include any explanations outside the prompt itself."""}
    ]

    try:
        outputs = model(prompt_generation_message, max_new_tokens=1024)

        # Chat pipelines return the whole conversation; the reply is the last turn.
        if isinstance(outputs, list) and outputs:
            head = outputs[0]
            if isinstance(head, dict) and "generated_text" in head:
                generated = head["generated_text"]
                if isinstance(generated, str):
                    return generated
                if isinstance(generated, list) and generated:
                    final = generated[-1]
                    if isinstance(final, dict) and "content" in final:
                        return final["content"]

        # Fallback for plain-string outputs.
        if isinstance(outputs, str):
            return outputs

        # All parsing failed: use the hand-built prompt.
        return fallback_prompt
    except Exception as exc:
        print(f"Error generating system prompt: {str(exc)}")
        return fallback_prompt
182
+
183
def generate_response(model, messages):
    """Run the chat *messages* through *model* and pull out the reply text.

    Exceptions are deliberately not caught here; the caller handles them.
    """
    outputs = model(messages, max_new_tokens=512)

    # Chat pipelines return the whole conversation; the reply is the last turn.
    if isinstance(outputs, list) and outputs:
        head = outputs[0]
        if isinstance(head, dict) and "generated_text" in head:
            generated = head["generated_text"]
            if isinstance(generated, str):
                return generated
            if isinstance(generated, list) and generated:
                final = generated[-1]
                if isinstance(final, dict) and "content" in final:
                    return final["content"]

    # Fallback parsing for different output formats
    if isinstance(outputs, str):
        return outputs
    return "I couldn't generate a proper response. Please try again."
200
+
201
class PersonaChat:
    """Holds the lazily-loaded model, the active persona, and the conversation."""

    def __init__(self):
        # Model is loaded on first use so constructing the object stays cheap.
        self.model = None
        self.system_prompt = "You are a helpful assistant."
        self.persona_name = "Assistant"
        self.persona_context = ""
        self.messages = []
        self.enhanced_profile = ""

    def load_model_if_needed(self):
        """Load the LLM pipeline exactly once."""
        if self.model is None:
            self.model = load_model()

    def set_persona(self, name, context=""):
        """Research *name*, build a profile, and install the system prompt.

        Generator yielding (status, system_prompt, messages) tuples so the UI
        can display progress while each stage runs.
        """
        self.load_model_if_needed()
        self.persona_name = name
        self.persona_context = context

        yield (
            f"Searching for information about {name}...",
            "",
            [{"role": "system", "content": "Starting persona creation..."}],
        )

        found = search_person(name, context)
        # search_person signals failure by returning an "Error..." string.
        if isinstance(found, str) and found.startswith("Error"):
            yield f"Error: {found}", "", [{"role": "system", "content": f"Error: {found}"}]
            return

        bio_text = extract_text_from_search_results(found)

        stage = f"Creating enhanced profile for {name}..."
        yield stage, "", [{"role": "system", "content": stage}]
        self.enhanced_profile = generate_enhanced_persona(self.model, name, bio_text, context)

        stage = f"Generating optimal system prompt for {name}..."
        yield stage, "", [{"role": "system", "content": stage}]
        self.system_prompt = generate_system_prompt_with_llm(self.model, name, self.enhanced_profile, context)
        self.messages = [{"role": "system", "content": self.system_prompt}]

        yield f"Persona set to {name}. Ready to chat!", self.system_prompt, self.messages

    def chat(self, user_message):
        """Append *user_message* to the history, query the model, return the reply."""
        self.load_model_if_needed()

        try:
            # Plain strings are wrapped as user turns; dicts pass through as-is.
            if isinstance(user_message, str):
                turn = {"role": "user", "content": user_message}
            else:
                turn = user_message
            self.messages.append(turn)

            reply = generate_response(self.model, self.messages)
            self.messages.append({"role": "assistant", "content": reply})
            return reply
        except Exception as exc:
            error_msg = f"Error generating response: {str(exc)}"
            print(error_msg)
            return error_msg
273
+
274
def create_interface():
    """Build and return the Gradio Blocks UI for the persona simulator.

    One shared PersonaChat instance backs every session of this interface.
    """
    persona_chat = PersonaChat()

    # Custom CSS for better UI
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }

    .main-container {
        max-width: 1200px;
        margin: auto;
        padding: 0;
    }

    .header {
        background: linear-gradient(90deg, #2c3e50, #4ca1af);
        color: white;
        padding: 20px;
        border-radius: 10px 10px 0 0;
        margin-bottom: 20px;
        text-align: center;
    }

    .setup-section {
        background-color: #f9f9f9;
        border-radius: 10px;
        padding: 20px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 20px;
    }

    .chat-section {
        background-color: white;
        border-radius: 10px;
        padding: 20px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }

    .status-bar {
        background: #f0f0f0;
        padding: 10px 15px;
        border-radius: 5px;
        margin: 15px 0;
        font-weight: 500;
    }

    .chat-container {
        border: 1px solid #eaeaea;
        border-radius: 10px;
        height: 500px !important;
        overflow-y: auto;
        background-color: #f9f9f9;
    }

    .message-input {
        margin-top: 10px;
    }

    .send-button {
        background-color: #2c3e50 !important;
    }

    .persona-button {
        background-color: #4ca1af !important;
    }

    .system-prompt {
        background-color: #f5f5f5;
        border-radius: 8px;
        padding: 10px;
        margin-top: 15px;
        border: 1px solid #e0e0e0;
    }

    .footer {
        text-align: center;
        margin-top: 20px;
        font-size: 0.9rem;
        color: #666;
    }

    /* Avatar styling */
    .user-message {
        background-color: #e1f5fe;
        border-radius: 15px 15px 0 15px;
        padding: 10px 15px;
        margin: 8px 0;
        max-width: 80%;
        float: right;
        clear: both;
    }

    .bot-message {
        background-color: #f0f0f0;
        border-radius: 15px 15px 15px 0;
        padding: 10px 15px;
        margin: 8px 0;
        max-width: 80%;
        float: left;
        clear: both;
    }

    /* Loading animation */
    @keyframes pulse {
        0% { opacity: 0.6; }
        50% { opacity: 1; }
        100% { opacity: 0.6; }
    }

    .loading {
        animation: pulse 1.5s infinite;
        padding: 10px;
        background-color: #eee;
        border-radius: 5px;
        display: inline-block;
    }
    """

    with gr.Blocks(css=css, title="AI Persona Simulator") as interface:
        with gr.Row(elem_classes="main-container"):
            with gr.Column():
                # Header
                with gr.Column(elem_classes="header"):
                    gr.Markdown("# AI Persona Simulator")
                    gr.Markdown("Create lifelike character simulations with advanced AI")

                # Setup Section
                with gr.Column(elem_classes="setup-section"):
                    gr.Markdown("### Create Your Persona")
                    gr.Markdown("Enter details about the character you want to simulate")

                    with gr.Row():
                        name_input = gr.Textbox(
                            label="Character Name",
                            placeholder="e.g. Erenalp",
                            elem_classes="input-field"
                        )

                    with gr.Row():
                        context_input = gr.Textbox(
                            label="Character Context",
                            placeholder="e.g. in 7th grade, loves math and video games, has a pet cat",
                            lines=2,
                            elem_classes="input-field"
                        )

                    with gr.Row():
                        set_persona_button = gr.Button(
                            "Create Persona",
                            variant="primary",
                            elem_classes="persona-button"
                        )

                    status_output = gr.Textbox(
                        label="Status",
                        interactive=False,
                        elem_classes="status-bar"
                    )

                    with gr.Accordion("Character System Prompt", open=False, elem_classes="system-prompt-section"):
                        system_prompt_display = gr.TextArea(
                            label="",
                            interactive=False,
                            lines=10,
                            elem_classes="system-prompt"
                        )

                # Chat Section
                with gr.Column(elem_classes="chat-section"):
                    gr.Markdown("### Chat with Your Character")

                    # Display character name dynamically
                    character_name_display = gr.Markdown(
                        elem_id="character-name",
                        value="Start by creating a persona above"
                    )

                    # NOTE(review): Gradio's avatar_images normally expects
                    # image paths/URLs — confirm emoji strings render as intended.
                    chatbot = gr.Chatbot(
                        label="",
                        height=450,
                        elem_classes="chat-container",
                        avatar_images=("👤", "🤖"),
                        type="messages"
                    )

                    with gr.Row():
                        msg_input = gr.Textbox(
                            label="Your message",
                            placeholder="Type your message here...",
                            elem_classes="message-input"
                        )
                        send_button = gr.Button(
                            "Send",
                            variant="primary",
                            elem_classes="send-button"
                        )

                # Footer
                with gr.Column(elem_classes="footer"):
                    gr.Markdown("Powered by Llama-3.1-Nemotron-8B-UltraLong-4M-Instruct")

        # Functions
        def update_character_name(name):
            # Live header update while the user types a name.
            if name:
                return f"### Chatting with {name}"
            return "### Start by creating a persona above"

        def set_persona_generator(name, context):
            # Streams progress updates from PersonaChat.set_persona to the UI.
            initial_status = f"Creating persona for {name}..."
            initial_character_display = f"### Creating persona for {name}..."
            initial_prompt = ""
            initial_history = [{"role": "system", "content": "Initializing..."}]

            # Initial yield
            yield initial_status, initial_prompt, initial_history, initial_character_display

            # Process persona creation
            for status, prompt, history in persona_chat.set_persona(name, context):
                character_display = f"### Creating persona for {name}..."
                if "Ready to chat" in status:
                    character_display = f"### Chatting with {name}"
                yield status, prompt, history, character_display

        def send_message(message, history):
            # NOTE(review): this function contains `yield`, so it is a generator;
            # the bare `return "", history` branches end it WITHOUT emitting any
            # output, meaning the UI is not updated for empty messages or a
            # missing persona — likely should be `yield ...; return`. Confirm.
            if not message.strip():
                return "", history

            if not persona_chat.messages:
                new_history = list(history) if history else []
                new_history.append({"role": "user", "content": message})
                new_history.append({"role": "assistant", "content": "Please create a persona first using the form above."})
                return "", new_history

            try:
                # Show typing indicator
                new_history = list(history) if history else []
                new_history.append({"role": "user", "content": message})
                new_history.append({"role": "assistant", "content": "..."})
                yield "", new_history

                # Generate actual response
                response = persona_chat.chat(message)
                new_history[-1]["content"] = response
                yield "", new_history

            except Exception as e:
                # NOTE(review): if the exception fires before new_history is
                # assigned (e.g. list(history) raises), this references an
                # unbound local — verify.
                print(f"Error in send_message: {str(e)}")
                new_history[-1]["content"] = "Sorry, there was an error processing your message."
                yield "", new_history

        # Event handlers
        set_persona_button.click(
            set_persona_generator,
            inputs=[name_input, context_input],
            outputs=[status_output, system_prompt_display, chatbot, character_name_display]
        )

        name_input.change(
            update_character_name,
            inputs=[name_input],
            outputs=[character_name_display]
        )

        send_button.click(
            send_message,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot]
        )

        msg_input.submit(
            send_message,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot]
        )

    return interface
551
+
552
# Install required packages if not already installed
# !pip install gradio transformers torch duckduckgo_search

# Build the interface at module scope — Hugging Face Spaces expects a
# module-level Blocks object named `demo`.
demo = create_interface()

# Launch only when executed as a script, so importing this module stays
# side-effect free.
if __name__ == "__main__":
    demo.launch(share=True)