ErenalpCet commited on
Commit
c182ab9
·
verified ·
1 Parent(s): 5a3a162

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -29
app.py CHANGED
@@ -132,6 +132,8 @@ def parse_llm_output(full_output, input_prompt_list):
132
  if potential_response:
133
  potential_response = re.sub(r'^<\/?s?>', '', potential_response).strip()
134
  potential_response = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', potential_response).strip()
 
 
135
  if potential_response:
136
  return potential_response
137
  cleaned_text = generated_text
@@ -141,9 +143,12 @@ def parse_llm_output(full_output, input_prompt_list):
141
  pass
142
  cleaned_text = re.sub(r'^<\/?s?>', '', cleaned_text).strip()
143
  cleaned_text = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', cleaned_text).strip()
 
 
144
  if not cleaned_text and generated_text:
145
  print("Warning: Parsing resulted in empty string, returning original generation.")
146
- return generated_text
 
147
  if last_input_content and last_occurrence_index == -1:
148
  print("Warning: Could not find last input prompt in LLM output. Returning cleaned full output.")
149
  return cleaned_text
@@ -239,6 +244,8 @@ def generate_response(messages):
239
  pad_token_id=pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id else None
240
  )
241
  parsed_output = parse_llm_output(outputs, messages)
 
 
242
  print("Response generated.")
243
  return parsed_output if parsed_output else "..."
244
  except Exception as e:
@@ -326,7 +333,7 @@ def create_interface():
326
  .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
327
  .main-container { max-width: 1200px; margin: auto; padding: 0; }
328
  .header { background: linear-gradient(90deg, #2c3e50, #4ca1af); color: white; padding: 20px; border-radius: 10px 10px 0 0; margin-bottom: 20px; text-align: center; }
329
- .setup-section { tôackground-color: #f9f9f9; border-radius: 10px; padding: 20px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); margin-bottom: 20px; }
330
  .chat-section { background-color: white; border-radius: 10px; padding: 20px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); }
331
  .status-bar { background: #e9ecef; padding: 10px 15px; border-radius: 5px; margin: 15px 0; font-weight: 500; border: 1px solid #ced4da; }
332
  .chat-container { border: 1px solid #eaeaea; border-radius: 10px; height: 500px !important; overflow-y: auto; background-color: #ffffff; padding: 10px; }
@@ -337,6 +344,11 @@ def create_interface():
337
  .footer { text-align: center; margin-top: 20px; font-size: 0.9em; color: #666; }
338
  .typing-indicator { color: #aaa; font-style: italic; }
339
  """
 
 
 
 
 
340
  with gr.Blocks(css=css, title="AI Persona Simulator") as interface:
341
  with gr.Row(elem_classes="main-container"):
342
  with gr.Column():
@@ -361,14 +373,15 @@ def create_interface():
361
  label="Conversation",
362
  height=450,
363
  elem_classes="chat-container",
364
- avatar_images=(None, "🤖"),
365
- type="messages"
366
  )
367
  with gr.Row():
368
  msg_input = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter...", elem_classes="message-input", scale=4)
369
  send_button = gr.Button("Send", variant="primary", elem_classes="send-button", scale=1)
370
  with gr.Column(elem_classes="footer"):
371
  gr.Markdown(f"Powered by {MODEL_ID}")
 
372
  def set_persona_flow(name, context):
373
  if not name:
374
  yield "Status: Please enter a character name.", "", "", "*No persona created yet*", []
@@ -385,50 +398,53 @@ def create_interface():
385
  final_status, final_prompt, final_profile, final_character_display, final_history = "Error", "", "", f"### Error creating {name}", []
386
  try:
387
  for status_update, prompt_update, profile_update, history_update in persona_chat.set_persona(name, context):
388
- # Strict validation of history
389
- filtered_history = [
390
- msg for msg in (history_update or [])
391
- if isinstance(msg, dict) and "content" in msg and isinstance(msg["content"], str) and msg["content"]
392
- ]
 
 
 
393
  # Log chatbot state
394
  current_chatbot_state = chatbot.value if hasattr(chatbot, 'value') else []
395
  print(f"set_persona_flow: Current chatbot state: {current_chatbot_state}")
396
- # Clean and validate chatbot state
397
- cleaned_chatbot_state = [
398
- msg for msg in current_chatbot_state
399
- if isinstance(msg, dict) and "content" in msg and isinstance(msg["content"], str) and msg["content"]
400
- ]
401
- # Combine histories
402
- combined_history = cleaned_chatbot_state + filtered_history
403
- print(f"set_persona_flow: Processing yield - status: {status_update}, filtered_history: {filtered_history}, combined_history: {combined_history}")
404
- final_status = status_update
405
- final_prompt = prompt_update
406
- final_profile = profile_update
407
- final_history = combined_history
408
  character_display = f"### Preparing chat with {name}..."
409
  if "Ready to chat" in status_update:
410
  character_display = f"### Chatting with {name}"
411
  elif "Error" in status_update:
412
  character_display = f"### Error creating {name}"
413
- yield status_update, final_prompt, final_profile, character_display, final_history
 
414
  time.sleep(0.1)
415
  except Exception as e:
416
  error_msg = f"Failed to set persona (interface error): {str(e)}"
417
  print(f"set_persona_flow: Exception: {error_msg}")
418
- yield error_msg, final_prompt, final_profile, f"### Error creating {name}", final_history
 
419
  def send_message_flow(message, history):
420
- if history is None:
421
- history = []
422
  if not message.strip():
423
  return "", history
 
424
  if not persona_chat.messages or persona_chat.messages[0]['role'] != 'system':
425
- history.append({"role": "user", "content": message})
426
- history.append({"role": "assistant", "content": "Error: Please create a valid persona first."})
427
  return "", history
428
- history.append({"role": "user", "content": message})
 
 
 
 
429
  response_text = persona_chat.chat(message)
430
- history.append({"role": "assistant", "content": response_text})
 
 
 
 
 
 
431
  return "", history
 
432
  set_persona_button.click(
433
  set_persona_flow,
434
  inputs=[name_input, context_input],
 
132
  if potential_response:
133
  potential_response = re.sub(r'^<\/?s?>', '', potential_response).strip()
134
  potential_response = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', potential_response).strip()
135
+ # Remove special tags
136
+ potential_response = re.sub(r'<end_of_turn>|<start_of_turn>model', '', potential_response).strip()
137
  if potential_response:
138
  return potential_response
139
  cleaned_text = generated_text
 
143
  pass
144
  cleaned_text = re.sub(r'^<\/?s?>', '', cleaned_text).strip()
145
  cleaned_text = re.sub(r'^(assistant|ASSISTANT|System|SYSTEM)[:\s]*', '', cleaned_text).strip()
146
+ # Remove special tags from full output
147
+ cleaned_text = re.sub(r'<end_of_turn>|<start_of_turn>model', '', cleaned_text).strip()
148
  if not cleaned_text and generated_text:
149
  print("Warning: Parsing resulted in empty string, returning original generation.")
150
+ # Still clean special tags from original
151
+ return re.sub(r'<end_of_turn>|<start_of_turn>model', '', generated_text).strip()
152
  if last_input_content and last_occurrence_index == -1:
153
  print("Warning: Could not find last input prompt in LLM output. Returning cleaned full output.")
154
  return cleaned_text
 
244
  pad_token_id=pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id else None
245
  )
246
  parsed_output = parse_llm_output(outputs, messages)
247
+ # Extra cleanup for specific model tags
248
+ parsed_output = re.sub(r'<end_of_turn>|<start_of_turn>model', '', parsed_output).strip()
249
  print("Response generated.")
250
  return parsed_output if parsed_output else "..."
251
  except Exception as e:
 
333
  .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
334
  .main-container { max-width: 1200px; margin: auto; padding: 0; }
335
  .header { background: linear-gradient(90deg, #2c3e50, #4ca1af); color: white; padding: 20px; border-radius: 10px 10px 0 0; margin-bottom: 20px; text-align: center; }
336
+ .setup-section { background-color: #f9f9f9; border-radius: 10px; padding: 20px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); margin-bottom: 20px; }
337
  .chat-section { background-color: white; border-radius: 10px; padding: 20px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); }
338
  .status-bar { background: #e9ecef; padding: 10px 15px; border-radius: 5px; margin: 15px 0; font-weight: 500; border: 1px solid #ced4da; }
339
  .chat-container { border: 1px solid #eaeaea; border-radius: 10px; height: 500px !important; overflow-y: auto; background-color: #ffffff; padding: 10px; }
 
344
  .footer { text-align: center; margin-top: 20px; font-size: 0.9em; color: #666; }
345
  .typing-indicator { color: #aaa; font-style: italic; }
346
  """
347
+
348
+ # Define avatar images with full URLs to ensure they work
349
+ user_avatar = "https://api.dicebear.com/6.x/bottts/svg?seed=user"
350
+ bot_avatar = "https://api.dicebear.com/6.x/bottts/svg?seed=bot"
351
+
352
  with gr.Blocks(css=css, title="AI Persona Simulator") as interface:
353
  with gr.Row(elem_classes="main-container"):
354
  with gr.Column():
 
373
  label="Conversation",
374
  height=450,
375
  elem_classes="chat-container",
376
+ avatar_images=(user_avatar, bot_avatar),
377
+ show_label=False
378
  )
379
  with gr.Row():
380
  msg_input = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter...", elem_classes="message-input", scale=4)
381
  send_button = gr.Button("Send", variant="primary", elem_classes="send-button", scale=1)
382
  with gr.Column(elem_classes="footer"):
383
  gr.Markdown(f"Powered by {MODEL_ID}")
384
+
385
  def set_persona_flow(name, context):
386
  if not name:
387
  yield "Status: Please enter a character name.", "", "", "*No persona created yet*", []
 
398
  final_status, final_prompt, final_profile, final_character_display, final_history = "Error", "", "", f"### Error creating {name}", []
399
  try:
400
  for status_update, prompt_update, profile_update, history_update in persona_chat.set_persona(name, context):
401
+ # For Gradio's Chatbot, convert to tuple list format
402
+ gradio_history = []
403
+ for i in range(0, len(history_update), 2):
404
+ if i+1 < len(history_update):
405
+ user_msg = history_update[i].get("content", "")
406
+ bot_msg = history_update[i+1].get("content", "")
407
+ gradio_history.append([user_msg, bot_msg])
408
+
409
  # Log chatbot state
410
  current_chatbot_state = chatbot.value if hasattr(chatbot, 'value') else []
411
  print(f"set_persona_flow: Current chatbot state: {current_chatbot_state}")
412
+
 
 
 
 
 
 
 
 
 
 
 
413
  character_display = f"### Preparing chat with {name}..."
414
  if "Ready to chat" in status_update:
415
  character_display = f"### Chatting with {name}"
416
  elif "Error" in status_update:
417
  character_display = f"### Error creating {name}"
418
+
419
+ yield status_update, prompt_update, profile_update, character_display, gradio_history
420
  time.sleep(0.1)
421
  except Exception as e:
422
  error_msg = f"Failed to set persona (interface error): {str(e)}"
423
  print(f"set_persona_flow: Exception: {error_msg}")
424
+ yield error_msg, final_prompt, final_profile, f"### Error creating {name}", []
425
+
426
  def send_message_flow(message, history):
 
 
427
  if not message.strip():
428
  return "", history
429
+
430
  if not persona_chat.messages or persona_chat.messages[0]['role'] != 'system':
431
+ history.append([message, "Error: Please create a valid persona first."])
 
432
  return "", history
433
+
434
+ # Add user message to history
435
+ history.append([message, None]) # Add placeholder for bot response
436
+
437
+ # Get response from AI
438
  response_text = persona_chat.chat(message)
439
+
440
+ # Clean any special tags that might still be in the response
441
+ response_text = re.sub(r'<end_of_turn>|<start_of_turn>model', '', response_text).strip()
442
+
443
+ # Update the last message with the actual response
444
+ history[-1][1] = response_text
445
+
446
  return "", history
447
+
448
  set_persona_button.click(
449
  set_persona_flow,
450
  inputs=[name_input, context_input],