# Web page header captured along with this file (kept for provenance):
# abc / app.py
# shibam007's picture
# Update app.py
# d1d46f8 verified
# raw
# history blame contribute delete
# 3.64 kB
import streamlit as st
import google.generativeai as genai
import chromadb
import os
import time
from transformers import pipeline
from dotenv import load_dotenv
# Load API key from .env
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Configure Gemini AI.
# NOTE(review): GEMINI_API_KEY is None if neither .env nor the environment
# defines it; nothing in this file checks for that before configuring.
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-1.5-pro")


@st.cache_resource
def _load_chat_memory():
    """Open the persistent ChromaDB store and its chat-history collection.

    Streamlit re-executes this script top to bottom on every user interaction;
    caching keeps a single client per server process instead of reopening the
    on-disk store on each rerun.

    Returns:
        ``(client, collection)`` for path ``./mental_health_memory`` and
        collection ``chat_history``.
    """
    chroma_client = chromadb.PersistentClient(path="./mental_health_memory")
    return chroma_client, chroma_client.get_or_create_collection(name="chat_history")


@st.cache_resource
def _load_sentiment_pipeline():
    """Build the emotion classifier once per server process instead of per rerun.

    ``return_all_scores=True`` makes the pipeline return a score for every
    label. analyze_sentiment() indexes ``results[0]`` and so depends on the
    nested per-input list shape this produces. Keep the two in sync if you
    switch to the newer ``top_k=None`` option, which may return a flat list.
    """
    return pipeline(
        "text-classification",
        model="bhadresh-savani/distilbert-base-uncased-emotion",
        return_all_scores=True,
    )


# Initialize ChromaDB for RAG (cached across Streamlit reruns)
client, collection = _load_chat_memory()
# Load Sentiment Analysis Model (cached across Streamlit reruns)
sentiment_pipeline = _load_sentiment_pipeline()
def analyze_sentiment(text):
    """Classify the emotion expressed in *text* with ``sentiment_pipeline``.

    Args:
        text: Free text to classify. It may be long, e.g. the whole chat joined
            together by the "Sentiment Analysis" button.

    Returns:
        A ``(top_label, scores)`` tuple. ``scores`` maps every label the
        pipeline returned to its score; ``top_label`` is the highest-scoring one.
    """
    # truncation=True clips input to the model's maximum length. Without it,
    # long text (such as the full chat history) makes the model call raise.
    results = sentiment_pipeline(text, truncation=True)
    # Assumes the pipeline was built with return_all_scores=True, so a single
    # string yields [[{"label": ..., "score": ...}, ...]].
    emotions = {res["label"]: res["score"] for res in results[0]}
    return max(emotions, key=emotions.get), emotions
def store_chat(user_input, bot_response):
    """Persist one user/bot exchange to the ChromaDB ``collection``.

    Adds two records: the user's message tagged ``{"role": "user"}`` and the
    reply tagged ``{"role": "bot"}``.

    Args:
        user_input: The user's message.
        bot_response: The reply that was shown to the user.
    """
    # count() is the number of stored records. len(collection.get()) would be
    # the number of keys in the result dict, which is constant, so every
    # exchange would reuse the same ids. Colliding ids are rejected or ignored
    # by Chroma, so later exchanges would never be stored.
    # Each exchange adds two records, so "<n>_user"/"<n>_bot" stays unique as
    # long as records are only removed all at once (as "Clear Chat" does).
    sequence = collection.count()
    collection.add(
        documents=[user_input, bot_response],
        metadatas=[{"role": "user"}, {"role": "bot"}],
        ids=[f"{sequence}_user", f"{sequence}_bot"],
    )
def retrieve_context(limit=3):
    """Return the most recently stored chat messages from ``collection``.

    This is recency-based rather than a similarity search: it fetches every
    stored record and keeps the last *limit* of them.

    Args:
        limit: Maximum number of messages to return (default 3). Values <= 0
            return an empty list.

    Returns:
        A list of up to *limit* strings in the order ``collection.get()``
        returns them. Each string is prefixed with the speaker role stored in
        its metadata (``"user: ..."`` / ``"bot: ..."``) when one is present.
    """
    if limit <= 0:
        # Explicit guard: a [-0:] slice would return the whole history.
        return []
    history = collection.get()
    documents = history["documents"] or []
    metadatas = history["metadatas"] or [None] * len(documents)
    # NOTE(review): assumes collection.get() returns records in insertion
    # order; confirm for the chromadb version in use.
    recent = list(zip(documents, metadatas))[-limit:]
    # Label each message with its role so the LLM can tell its own earlier
    # replies apart from the user's messages.
    return [
        f"{meta['role']}: {doc}" if meta and "role" in meta else doc
        for doc, meta in recent
    ]
def get_gemini_response(user_input):
    """Ask Gemini for a reply to *user_input*, prefixed with recent stored messages.

    Args:
        user_input: The user's latest message.

    Returns:
        The model's reply text. If generating the reply or reading
        ``response.text`` raises, returns a human-readable error string
        instead; the UI displays whatever is returned.
    """
    past_context = retrieve_context()
    # One message per line. Interpolating the list itself would put its Python
    # repr (brackets, quotes, escaped newlines/apostrophes) into the prompt.
    context_text = "\n".join(str(message) for message in past_context)
    full_prompt = f"Previous Chat Context:\n{context_text}\nUser: {user_input}\nBot:"
    try:
        response = model.generate_content(full_prompt)
        return response.text
    except Exception as e:  # UI boundary: report any SDK/API failure as the reply
        return f"Sorry, I encountered an issue. Error: {str(e)}"
# Streamlit UI
st.title("\U0001F9E0 Mental Health Chatbot")

# Per-session transcript of (role, message) pairs, created on first run
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Re-render every earlier turn, since Streamlit redraws the page on each rerun
for speaker, text in st.session_state.chat_history:
    st.chat_message(speaker).write(text)
# User input
user_input = st.chat_input("Ask me anything about mental health...")
if user_input:
    st.chat_message("user").write(user_input)
    # No per-message sentiment call here: its result was never used, and it cost
    # a model forward pass per message.
    # Generate chatbot response using recent stored messages as context
    with st.spinner("Thinking..."):
        bot_response = get_gemini_response(user_input)
    st.chat_message("assistant").write(bot_response)
    # Store chat history. Session state drives the on-screen transcript;
    # ChromaDB (via store_chat) is the memory fed into future prompts.
    st.session_state.chat_history.append(("user", user_input))
    st.session_state.chat_history.append(("assistant", bot_response))
    store_chat(user_input, bot_response)
# Sidebar: Clear Chat & Instructions
with st.sidebar:
    if st.button("Clear Chat"):
        st.session_state.chat_history = []
        # Also wipe the persistent memory; otherwise old turns keep flowing into
        # prompts via retrieve_context(). Skip the call when nothing is stored:
        # depending on the chromadb version, an empty id list is either rejected
        # or treated as "no filter".
        stored_ids = collection.get()["ids"]
        if stored_ids:
            collection.delete(ids=stored_ids)
        # NOTE(review): st.rerun() runs immediately, so this message is likely
        # never seen by the user.
        st.success("Chat cleared! Refreshing...")
        st.rerun()
    st.markdown("---")
    st.markdown("*Note: This chatbot is for informational purposes only and should not replace professional mental health advice.*")
# Sentiment Analysis Button
if st.button("Sentiment Analysis"):
    # Classify the whole session transcript (user and assistant turns) as one text
    combined_text = " ".join(text for _speaker, text in st.session_state.chat_history)
    if not combined_text:
        st.warning("No chat history found for analysis.")
    else:
        top_emotion, all_scores = analyze_sentiment(combined_text)
        st.subheader("Sentiment Analysis Result")
        st.write(f"Detected Emotion: **{top_emotion}**")
        st.json(all_scores)