|
import os
import time
import uuid

import chromadb
import google.generativeai as genai
import streamlit as st
from dotenv import load_dotenv
from transformers import pipeline
|
|
|
|
|
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

genai.configure(api_key=GEMINI_API_KEY)

model = genai.GenerativeModel("gemini-1.5-pro")


# Streamlit re-executes this whole script on every user interaction. Without
# caching, the Chroma client and the transformer model below would be rebuilt
# (and the model weights reloaded) on every message and button click.
# st.cache_resource keeps one shared instance for the running server.
@st.cache_resource
def _open_chat_collection():
    """Open the on-disk Chroma store and return (client, "chat_history" collection)."""
    chroma_client = chromadb.PersistentClient(path="./mental_health_memory")
    return chroma_client, chroma_client.get_or_create_collection(name="chat_history")


client, collection = _open_chat_collection()


@st.cache_resource
def _load_sentiment_pipeline():
    """Load the emotion classifier used by analyze_sentiment()."""
    return pipeline(
        "text-classification",
        model="bhadresh-savani/distilbert-base-uncased-emotion",
        return_all_scores=True,
    )


sentiment_pipeline = _load_sentiment_pipeline()
|
|
|
|
|
def analyze_sentiment(text):
    """Classify the dominant emotion in *text*.

    Args:
        text: The string to classify.

    Returns:
        A ``(label, scores)`` tuple: ``scores`` maps every emotion label
        returned by ``sentiment_pipeline`` to its score, and ``label`` is the
        key with the highest score.
    """
    # truncation=True clips the input to the model's maximum sequence length;
    # without it, long inputs (e.g. a whole joined chat history) can make the
    # underlying model fail instead of returning a prediction.
    results = sentiment_pipeline(text, truncation=True)
    # The pipeline is built with return_all_scores=True, so a single string
    # yields a one-element list holding that input's {"label", "score"} dicts.
    emotions = {res["label"]: res["score"] for res in results[0]}
    return max(emotions, key=emotions.get), emotions
|
|
|
|
|
def store_chat(user_input, bot_response):
    """Persist one user/assistant exchange in the Chroma ``collection``.

    Both messages are added in a single call, tagged with ``role`` metadata
    ("user" / "bot") and sharing a freshly generated turn ID.

    Args:
        user_input: The user's message text.
        bot_response: The assistant's reply text.
    """
    # IDs must be unique per exchange. Deriving them from len(collection.get())
    # does not work: get() returns a result mapping (ids, documents,
    # metadatas, ...), so its length is the number of keys and never changes,
    # which made every exchange reuse the first exchange's IDs. A random UUID
    # per turn stays unique, including after "Clear Chat" resets.
    turn_id = uuid.uuid4().hex
    collection.add(
        documents=[user_input, bot_response],
        metadatas=[{"role": "user"}, {"role": "bot"}],
        ids=[f"{turn_id}_user", f"{turn_id}_bot"],
    )
|
|
|
|
|
|
|
def retrieve_context(limit=3):
    """Return up to *limit* stored chat messages from the end of ``collection``.

    Args:
        limit: Maximum number of messages to return. Defaults to 3, the
            original fixed window. Values <= 0 return an empty list.

    Returns:
        A list of document strings: the last ``limit`` entries of
        ``collection.get()``'s documents (presumably the most recent ones,
        assuming Chroma returns them in insertion order; TODO confirm).
    """
    if limit <= 0:
        # Guard explicitly: documents[-0:] would return the whole list.
        return []
    # Only the message text is needed, so skip fetching metadata.
    history = collection.get(include=["documents"])
    # Slicing already returns the whole list when it has fewer than `limit` items.
    return history["documents"][-limit:]
|
|
|
|
|
def get_gemini_response(user_input):
    """Ask the Gemini ``model`` for a reply to *user_input*.

    The messages returned by ``retrieve_context()`` are included in the prompt
    as prior context, one message per line.

    Args:
        user_input: The user's current message.

    Returns:
        The model's reply text, or an apology string containing the error
        message if the request (or reading ``response.text``) raises.
    """
    # Join the messages as plain lines. Interpolating the list directly would
    # embed its Python repr (brackets, quotes, escaped newline sequences) in
    # the prompt.
    past_context = "\n".join(retrieve_context())
    full_prompt = f"Previous Chat Context:\n{past_context}\nUser: {user_input}\nBot:"
    try:
        response = model.generate_content(full_prompt)
        return response.text
    except Exception as e:
        # UI boundary: turn any API failure into a message the chat can show.
        return f"Sorry, I encountered an issue. Error: {str(e)}"
|
|
|
|
|
st.title("\U0001F9E0 Mental Health Chatbot")

# Streamlit reruns this script on every interaction; session_state keeps the
# (role, message) pairs shown so far in this browser session.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Replay the conversation so it stays visible across reruns.
for speaker, text in st.session_state["chat_history"]:
    st.chat_message(speaker).write(text)
|
|
|
|
|
user_input = st.chat_input("Ask me anything about mental health...")

if user_input:
    st.chat_message("user").write(user_input)

    # (A per-message analyze_sentiment() call used to run here; its results
    # were never used, so it only added a model inference to every message.)
    with st.spinner("Thinking..."):
        bot_response = get_gemini_response(user_input)

    st.chat_message("assistant").write(bot_response)

    # Keep the exchange for on-screen replay in this session and persist it
    # to Chroma, where get_gemini_response() reads prior context from.
    st.session_state.chat_history.append(("user", user_input))
    st.session_state.chat_history.append(("assistant", bot_response))
    store_chat(user_input, bot_response)
|
|
|
|
|
with st.sidebar:
    if st.button("Clear Chat"):
        # Reset both the on-screen history and the persisted Chroma memory.
        st.session_state.chat_history = []
        stored_ids = collection.get()["ids"]
        # delete(ids=[]) is rejected or mis-handled by some Chroma versions,
        # so only call it when something is actually stored.
        if stored_ids:
            collection.delete(ids=stored_ids)
        st.success("Chat cleared! Refreshing...")
        st.rerun()
|
|
|
|
|
|
|
|
|
|
|
st.markdown("---")
st.markdown("*Note: This chatbot is for informational purposes only and should not replace professional mental health advice.*")

if st.button("Sentiment Analysis"):
    # Classify the whole visible conversation (user and assistant turns) as
    # one space-joined text.
    messages = [text for _role, text in st.session_state.chat_history]
    chat_text = " ".join(messages)
    if not chat_text:
        st.warning("No chat history found for analysis.")
    else:
        detected_emotion, emotion_scores = analyze_sentiment(chat_text)
        st.subheader("Sentiment Analysis Result")
        st.write(f"Detected Emotion: **{detected_emotion}**")
        st.json(emotion_scores)
|
|