|
import os |
|
import openai |
|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
|
|
openai.api_key = os.getenv("OPENAI_API_KEY", "") |
|
|
|
|
|
|
|
model = SentenceTransformer('jhgan/ko-sroberta-multitask') |
|
|
|
|
|
df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv') |
|
df = df.dropna() |
|
|
|
if 'Unnamed: 3' in df.columns: |
|
df = df.drop(columns=['Unnamed: 3']) |
|
|
|
|
|
df['embedding'] = df['์ ์ '].map(lambda x: model.encode(str(x))) |
|
|
|
|
|
MAX_TURN = 5 |
|
|
|
def set_openai_model(): |
|
""" |
|
GPT-4 ๋์ 'gpt-4o' (์ค์ ๋ก ๋น์กด์ฌ ๋ชจ๋ธ) |
|
=> ์ค์ ๋ก๋ 'gpt-3.5-turbo' ๋ฑ์ผ๋ก ๊ต์ฒด ๊ถ์ฅ |
|
""" |
|
return "gpt-4o" |
|
|
|
EMPATHY_PROMPT = """\ |
|
๋น์ ์ ์น์ ํ ์ ์ ์ํ๊ณผ ์ ๋ฌธ์์ด๋ฉฐ ์ฌ๋ฆฌ์๋ด ์ ๋ฌธ๊ฐ์
๋๋ค. |
|
|
|
์ฌ์ฉ์์ ๋ฌธ์ฅ์ ๊ฑฐ์ ๊ทธ๋๋ก ์์ฝํ๋, ๋์ '๋๊ตฐ์.' ๊ฐ์ ๊ณต๊ฐ ์ด๋ฏธ๋ก ์์ฐ์ค๋ฝ๊ฒ ์๋ตํ๊ณ , |
|
๊ทธ ๋ค์ ์ค์ ํ ๋ฌธ์ฅ์ผ๋ก ์ง๋ฌธ์ ์์ฑํ์ธ์. |
|
|
|
์ ์์ฌํญ: |
|
1) ์ฒซ ๋ฌธ์ฅ์ ๊ณต๊ฐํ ์์ฝ (์: "์ํ์ ์๋๊ณ ๋ถ์ํด์ ๋ฉฐ์น ์งธ ์ ์ ๋ชป ์๊ณ ๊ณ์๋๊ตฐ์.") |
|
2) ๋ ๋ฒ์งธ ๋ฌธ์ฅ์ ํ์/์ ๋ ์ง๋ฌธ |
|
- ์: "์ด๋ค ๊ณ ๋ฏผ๋ค์ด ๋ฐค์ ๊ฐ์ฅ ๋ง์ด ๋ ์ค๋ฅด์๋์?" |
|
|
|
(์์) |
|
์ฌ์ฉ์: "์ํ์ ์๋๊ณ ๋ถ์ํด์ ๋ฉฐ์น ์งธ ์ ์ด ์ ์์." |
|
์ฑ๋ด: |
|
"์ํ์ ์๋๊ณ ๋ถ์ํด์ ๋ฉฐ์น ์งธ ์ ์ ๋ชป ์๊ณ ๊ณ์๋๊ตฐ์. |
|
์ํ ๊ธฐ๊ฐ์ด ๋ค๊ฐ์ฌ ๋ ๊ฐ์ฅ ํ๋์ ๋ถ๋ถ์ ๋ฌด์์ธ๊ฐ์?" |
|
|
|
์ด์ ์ฌ์ฉ์ ๋ฐํ๋ฅผ ์๋์ ์ฃผ๊ฒ ์ต๋๋ค: |
|
์ฌ์ฉ์ ๋ฐํ: "{sentence}" |
|
์ฑ๋ด: |
|
""" |
|
|
|
SOCRATIC_PROMPT = """\ |
|
๋น์ ์ ์ ์ ์ํ๊ณผ ์ ๋ฌธ์์ด๋ฉฐ Socratic CBT ๊ธฐ๋ฒ์ ์ฌ์ฉํ๋ ์ฌ๋ฆฌ์๋ด ์ ๋ฌธ๊ฐ์
๋๋ค. |
|
|
|
์๋ '๋ํ ํํธ'์๋ ์ฌ์ฉ์๊ฐ ์ง์ ๊น์ง ์ด์ผ๊ธฐํ ์ํฉ์ด๋ ๊ณ ๋ฏผ์ด ์์ฝ๋์ด ์๋ค๊ณ ๊ฐ์ ํฉ๋๋ค. |
|
์ด ๋ด์ฉ์ **๊ณต๊ฐ**์ ํ์ํ ๋ค, ๊ทธ ํ๋ฆ์ ์ด์ด๋ฐ์ **์์ฐ์ค๋ฝ๊ณ ๊ตฌ์ฒด์ ์ธ ํ์ ์ง๋ฌธ**์ ํ ๋ฌธ์ฅ์ผ๋ก ์์ฑํ์ธ์. |
|
|
|
**์ธ๋ถ ์ง์นจ**: |
|
1) ์ฒซ ๋ฌธ์ฅ์ ์ฌ์ฉ์์ ์ํฉ์ ๊ฐ๋จํ ๊ณต๊ฐํด ์ฃผ๋, ๋์ '๋๊ตฐ์.' ๋ฑ์ ์ด๋ฏธ๋ก ์์ฐ์ค๋ฝ๊ฒ ๋ง๋ฌด๋ฆฌํ์ธ์. |
|
- ์: "์ํ ๊ธฐ๊ฐ ๋์ ์ ๋ง ๋ง์ ๋ถ๋ด์ ๋๋ผ๊ณ ๊ณ์๋๊ตฐ์." |
|
2) ๋ ๋ฒ์งธ ๋ฌธ์ฅ์ ํ์/์ ๋ ์ง๋ฌธ์ ๋ฑ ํ ๋ฌธ์ฅ์ผ๋ก ์์ฑํ์ธ์. |
|
- '์ง๋ฌธ:' ๊ฐ์ ์ ๋์ด๋ ์ฐ์ง ๋ง๊ณ , ๋ฐ๋ก ๋ฌธ์ฅ์ผ๋ก ์์ํฉ๋๋ค. |
|
- ๋ฐ๋์ ๋ฌผ์ํ๋ก ๋๋์ผ ํฉ๋๋ค (์: "...์ด๋ค ๊ฒ๋ค์ด ๊ฐ์ฅ ํ๋์
จ๋์?"). |
|
3) ์ง๋ฌธ์ ์ฌ์ฉ์์ ํ์ฌ ๊ณ ๋ฏผ๊ณผ ์ง์ ์ ์ผ๋ก ์ฐ๊ฒฐ๋์ด, ์ฌ์ธต์ ์ธ ์๊ธฐ ํ์์ ์ ๋ํด์ผ ํฉ๋๋ค. |
|
- ์: "๋ฐค์ ๋ค๋ ค์ค๋ ์ด๋ค ์๊ฐ๋ค์ด ์ ์ ๋ ์ค์น๊ฒ ๋ง๋๋์ง ํน์ ๋ ์ค๋ฅด์๋์?" |
|
4) Bullet Point๋ ๋ชฉ๋ก ๋์ , ๊ฐ๋จํ ๋ ์ค(๊ณต๊ฐ + ์ง๋ฌธ) ๊ตฌ์กฐ๋ก ์์ฑํ๋, ๋๋ฌด ๊ธธ๊ฒ ์ฐ์ง ๋ง๊ณ ๋ถ๋๋ฌ์ด ํค์ ์ ์งํ์ธ์. |
|
|
|
(์์) |
|
์ฌ์ฉ์ ๋ฐํ: "๋จํธ์ด ๋นํธ์ฝ์ธ ํฌ์๋ก ์์ ์ฉ์ด๋ค" |
|
์ฑ๋ด: |
|
"๋จํธ๋ถ์ ํฌ์ ๋ฌธ์ ๋ก ์์ด ๋ง์ด ์ํ์๋๊ตฐ์. |
|
ํน์ ๊ทธ๋ก ์ธํด ๊ฐ์ฅ ํ๋ค๋ค๊ณ ๋๋ผ๋ ๋ถ๋ถ์ ๋ฌด์์ธ๊ฐ์?" |
|
|
|
์ด์ ์๋ '๋ํ ํํธ'๋ฅผ ์ฐธ์กฐํ์ฌ, 1์ค ๊ณต๊ฐ + 1์ค ์ง๋ฌธ ๋ ์ค๋ก ๋ต๋ณํด ์ฃผ์ธ์. |
|
|
|
๋ํ ํํธ: |
|
{context} |
|
""" |
|
|
|
ADVICE_PROMPT = """\ |
|
๋น์ ์ ์ ์ ์ํ๊ณผ ์ ๋ฌธ์์ด๋ฉฐ Socratic CBT ๊ธฐ๋ฒ์ ์ฌ์ฉํ๋ ์ฌ๋ฆฌ์๋ด ์ ๋ฌธ๊ฐ์
๋๋ค. |
|
|
|
์๋ ํํธ(๋ํ ์์ฝ)์ ํจ๊ป, ๋ค์์ ์ ์๋ 5๊ฐ์ง CBT ๊ธฐ๋ฒ์ ์ ์ ํ ์ฐธ๊ณ ํ์ฌ, |
|
์ฌ์ฉ์ ๋ง์ถคํ์ผ๋ก ๊ตฌ์ฒด์ ์ด๊ณ ๊ณต๊ฐ ์ด๋ฆฐ ์กฐ์ธ์ ํ๊ตญ์ด๋ก ์์ฑํ์ธ์: |
|
|
|
(1) ์๋ฉด ์ ํ ์๋ฒ (Sleep Restriction): |
|
"์๋ฉด ์ ํ ์๋ฒ์ ์นจ๋์ ๋จธ๋ฌด๋ ์๊ฐ์ ์๋์ ์ผ๋ก ์ค์ฌ, ์นจ๋์ ์๋ฉด ์ฌ์ด์ ์ฌ๋ฐ๋ฅธ ์ฐ๊ฒฐ๊ณ ๋ฆฌ๋ฅผ ์ฌ๊ตฌ์ถํ๋ ๋ฐฉ๋ฒ์
๋๋ค. |
|
์๋ฅผ ๋ค์ด, ์นจ๋์ 10์๊ฐ ๋จธ๋ฌผ์ง๋ง ์ค์ ์๋ฉด ์๊ฐ์ด 5์๊ฐ์ธ ๊ฒฝ์ฐ, ์ฒ์์๋ 5์๊ฐ๋ง ์นจ๋์์ ์๊ณ ์ ์ฐจ ์๊ฐ์ ๋๋ ค๊ฐ๋ฉฐ |
|
'์นจ๋๋ ์๋ฉด์ ์ํ ์ฅ์'๋ก ์ธ์ํ๋๋ก ๋์ต๋๋ค." |
|
|
|
(2) ์๊ทน ์กฐ์ ์๋ฒ (Stimulus Control): |
|
"์๊ทน ์กฐ์ ์๋ฒ์ ์นจ๋์ ์๋ฉด์ ํ๊ฒฝ์ ์ฌ์ ๋ฆฝํ๋ฉฐ, ์นจ๋๋ฅผ ์ค์ง ์๋ฉด๋ง์ ์ํ ์ฅ์๋ก ์ธ์ํ๊ฒ ๋ง๋๋ ์น๋ฃ๋ฒ์
๋๋ค. |
|
์๋ฅผ ๋ค์ด, ์นจ๋์ ๋์ ์์ ๋๋ ์ฆ์ ์ ๋ค์ง ๋ชปํ๋๋ผ๋, ์นจ๋์์๋ ์ค์ง ์๋ฉด์ ์ทจํ๋ ์ต๊ด์ ๊ธฐ๋ฅด๋ ๊ฒ์ด ํต์ฌ์
๋๋ค." |
|
|
|
(3) ์๋ฉด ์์ ๊ต์ก (Sleep Hygiene): |
|
"์๋ฉด ์์ ๊ต์ก์ ๊ฑด๊ฐํ ์๋ฉด์ ์ํด ์ํ ์ต๊ด์ ๊ฐ์ ํ๋ ๋ฐฉ๋ฒ์
๋๋ค. |
|
์นดํ์ธ ์ญ์ทจ๋ฅผ ์ค์ด๊ฑฐ๋ ๋ฆ์ ์๊ฐ์ ์ ์๊ธฐ๊ธฐ ์ฌ์ฉยท๋ฐ์ ์กฐ๋ช
๋ฑ์ ํผํ๊ณ , ๋ฎ์๋ ๊ฐ๋ฒผ์ด ์ด๋์ ํด๋๋ ๋ฑ์ ์ต๊ด์ ํฌํจํฉ๋๋ค." |
|
|
|
(4) ์ด์ ๊ธฐ๋ฒ (Relaxation Techniques): |
|
"์ด์ ๊ธฐ๋ฒ์ ์ฌํธํก, ์ ์ง์ ๊ทผ์ก์ด์, ๋ช
์ ๊ฐ์ ๋ฐฉ๋ฒ์ ํตํด ์์ฐ์ค๋ฌ์ด ์๋ฉด์ ์ ๋ํ๋ ๋ฐฉ๋ฒ์
๋๋ค. |
|
๋ชธ์ ์ค์บํ๊ณ , ๊ฑฐ๋ถํ ์คํธ๋ ์นญ์ ํ๊ณ , ๊ทผ์ก ์ด์์ ์ฐ์ตํ๋ฉฐ ๊ธด์ฅ์ ๋ฎ์ถ๋ ๊ฒ์ด ์ฃผ๋ ๋ชฉํ์
๋๋ค." |
|
|
|
(5) ์ธ์ง ์ฌ๊ตฌ์ฑ (Cognitive Restructuring): |
|
"์ธ์ง ์ฌ๊ตฌ์ฑ์ โ์ฐ๋ฆฌ๊ฐ ์ํฉ์ ์ด๋ป๊ฒ ๋ฐ๋ผ๋ณด๋๋์ ๋ฐ๋ผ ๋ชธ์ ๋ฐ์๋ ๋ฌ๋ผ์ง ์ ์๋คโ๋ ๊ธ์ ์ ๊ด์ ์ผ๋ก ์ ํ์ํค๋ฉฐ, |
|
๊ฑฑ์ ์ด๋ ๋ถ์, ๋ถ์ ์ ์ธ ์ฌ๊ณ ํจํด์ ์ ๊ฒยท์กฐ์ ํ๋ ๊ธฐ๋ฒ์
๋๋ค. ์ด๋ฅผ ํตํด ์ฌ์ฉ์์ ๊ฑฑ์ ์ ์ํํ๊ณ |
|
์๊ธฐํจ๋ฅ๊ฐ์ ๋์ด๋๋ก ๋์ต๋๋ค." |
|
|
|
์๋ ์ฌํญ์ ๊ผญ ๋ฐ์ํด ์ฃผ์ธ์: |
|
- ๋ถ์์ ์ํํ๊ธฐ ์ํ ์ ๊ธฐ๋ฒ๋ค์ ์์ฐ์ค๋ฝ๊ฒ ๋
น์ด๋, ์ฌ์ฉ์์ ํ์ฌ ์ํฉ(ํํธ์ ๋ด๊ธด ๊ณ ๋ฏผ)๊ณผ ์ฐ๊ฒฐํด ์ด์ผ๊ธฐํ์ธ์. |
|
- ๋๋ฌด ๋ฑ๋ฑํ์ง ์๊ฒ, ๋ถ๋๋ฝ๊ณ ์น์ ํ ๋งํฌ๋ฅผ ์ฌ์ฉํ์ธ์. |
|
|
|
ํํธ: |
|
{hints} |
|
|
|
์กฐ์ธ: |
|
""" |
|
|
|
|
|
|
|
def call_empathy(user_input: str) -> str: |
|
""" ๊ณต๊ฐ ์์ฝ ์์ฑ """ |
|
prompt = EMPATHY_PROMPT.format(sentence=user_input) |
|
resp = openai.ChatCompletion.create( |
|
model=set_openai_model(), |
|
messages=[ |
|
{"role":"system","content":"๋น์ ์ ์น์ ํ ์ฌ๋ฆฌ์๋ด ์ ๋ฌธ๊ฐ์
๋๋ค."}, |
|
{"role":"user","content":prompt} |
|
], |
|
max_tokens=150, |
|
temperature=0.7 |
|
) |
|
return resp.choices[0].message.content.strip() |
|
|
|
def call_socratic_question(context: str) -> str: |
|
""" ์ํฌ๋ผํ
์ค ํ์์ง๋ฌธ 1๋ฌธ์ฅ ์์ฑ """ |
|
prompt = f"{SOCRATIC_PROMPT}\n\n๋ํ ํํธ:\n{context}" |
|
resp = openai.ChatCompletion.create( |
|
model=set_openai_model(), |
|
messages=[ |
|
{"role":"system","content":"๋น์ ์ Socratic CBT ์ ๋ฌธ๊ฐ์
๋๋ค."}, |
|
{"role":"user","content":prompt} |
|
], |
|
max_tokens=200, |
|
temperature=0.7 |
|
) |
|
return resp.choices[0].message.content.strip() |
|
|
|
def call_advice(hints: str) -> str: |
|
""" ์ต์ข
CBT ์กฐ์ธ """ |
|
final_prompt = ADVICE_PROMPT.format(hints=hints) |
|
resp = openai.ChatCompletion.create( |
|
model=set_openai_model(), |
|
messages=[ |
|
{"role":"system","content":"๋น์ ์ Socratic CBT ๊ธฐ๋ฒ ์ ๋ฌธ๊ฐ์
๋๋ค."}, |
|
{"role":"user","content":final_prompt} |
|
], |
|
max_tokens=700, |
|
temperature=0.8 |
|
) |
|
return resp.choices[0].message.content.strip() |
|
|
|
|
|
def predict(user_input: str, state: dict): |
|
history = state.get("history", []) |
|
stage = state.get("stage", "EMPATHY") |
|
turn = state.get("turn", 0) |
|
hints = state.get("hints", []) |
|
|
|
|
|
history.append(("User", user_input)) |
|
|
|
|
|
query_emb = model.encode(user_input) |
|
df["sim"] = df["embedding"].map(lambda emb: cosine_similarity([query_emb],[emb]).squeeze()) |
|
|
|
|
|
if df["sim"].count() == 0: |
|
|
|
kb_answer = "์ ํฉํ ์ง์๋ฒ ์ด์ค ์๋ต์ ์ฐพ์ง ๋ชปํ์ด์." |
|
else: |
|
kb_answer = df.loc[df["sim"].idxmax(), "์ฑ๋ด"] |
|
|
|
hints.append(f"[KB] {kb_answer}") |
|
|
|
|
|
if stage == "EMPATHY": |
|
empathic = call_empathy(user_input) |
|
history.append(("Chatbot", empathic)) |
|
hints.append(empathic) |
|
stage = "SQ" |
|
turn = 0 |
|
return history, {"history": history, "stage": stage, "turn": turn, "hints": hints} |
|
|
|
if stage == "SQ" and turn < MAX_TURN: |
|
|
|
context_text = "\n".join([f"{r}: {c}" for (r,c) in history]) + "\n" + "\n".join(hints) |
|
sq = call_socratic_question(context_text) |
|
history.append(("Chatbot", sq)) |
|
hints.append(sq) |
|
turn += 1 |
|
return history, {"history": history, "stage": stage, "turn": turn, "hints": hints} |
|
|
|
|
|
stage = "END" |
|
combined_hints = "\n".join(hints) |
|
advice = call_advice(combined_hints) |
|
history.append(("Chatbot", advice)) |
|
|
|
return history, {"history":history, "stage":stage, "turn":turn, "hints":hints} |
|
|
|
|
|
def gradio_predict(user_input, chat_state): |
|
new_history, new_state = predict(user_input, chat_state) |
|
|
|
|
|
display_history = [] |
|
for (role, txt) in new_history: |
|
if role == "User": |
|
display_history.append([txt, ""]) |
|
else: |
|
if len(display_history) == 0: |
|
display_history.append(["", txt]) |
|
else: |
|
display_history[-1][1] = txt |
|
|
|
|
|
return display_history, new_state, "" |
|
|
|
def create_app(): |
|
with gr.Blocks() as demo: |
|
chatbot = gr.Chatbot() |
|
chat_state = gr.State({"history": [], "stage": "EMPATHY", "turn": 0, "hints": []}) |
|
txt = gr.Textbox(show_label=False, placeholder="ํน์ ์ ์ ์ด๋ฃจ์ง ๋ชปํ๊ณ ๊ณ์ ๊ฐ์? ๋น์ ์ ์ด์ผ๊ธฐ๋ฅผ ๋ฃ๊ณ ์ถ์ด์! ๊ฑฑ์ ์ด ์์ผ์๋ฉด ํธํ๊ฒ ๋ง์ํด์ฃผ์ธ์ :D") |
|
|
|
|
|
txt.submit( |
|
fn=gradio_predict, |
|
inputs=[txt, chat_state], |
|
outputs=[chatbot, chat_state, txt] |
|
) |
|
|
|
return demo |
|
|
|
app = create_app() |
|
|
|
if __name__ == "__main__": |
|
|
|
app.launch(debug=True, share=True) |