shibam007 commited on
Commit
8954c92
·
verified ·
1 Parent(s): a066407

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -1
app.py CHANGED
@@ -1,2 +1,111 @@
1
  import streamlit as st
2
- st.title("hello")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import google.generativeai as genai
3
+ import chromadb
4
+ import os
5
+ import time
6
+ from transformers import pipeline
7
+ from dotenv import load_dotenv
8
+
9
+ # Load API key from .env
10
+ load_dotenv()
11
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
12
+
13
+ # Configure Gemini AI
14
+ genai.configure(api_key=GEMINI_API_KEY)
15
+ model = genai.GenerativeModel("gemini-1.5-pro")
16
+
17
+ # Initialize ChromaDB for RAG
18
+ client = chromadb.PersistentClient(path="./mental_health_memory")
19
+ collection = client.get_or_create_collection(name="chat_history")
20
+
21
+ # Load Sentiment Analysis Model
22
+ sentiment_pipeline = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion", return_all_scores=True)
23
+
24
+ # Function to analyze sentiment
25
+ def analyze_sentiment(text):
26
+ results = sentiment_pipeline(text)
27
+ emotions = {res["label"]: res["score"] for res in results[0]}
28
+ return max(emotions, key=emotions.get), emotions
29
+
30
+ # Function to store chat history in ChromaDB
31
+ def store_chat(user_input, bot_response):
32
+ collection.add(
33
+ documents=[user_input, bot_response],
34
+ metadatas=[{"role": "user"}, {"role": "bot"}],
35
+ ids=[str(len(collection.get())) + "_user", str(len(collection.get())) + "_bot"]
36
+ )
37
+
38
+ # Function to retrieve relevant past messages (RAG)
39
+ def retrieve_context():
40
+ history = collection.get()
41
+ if len(history["documents"]) > 3:
42
+ return history["documents"][-3:]
43
+ return history["documents"]
44
+
45
+ # Function to generate a response using LLM with past context
46
+ def get_gemini_response(user_input):
47
+ past_context = retrieve_context()
48
+ full_prompt = f"Previous Chat Context: {past_context}\nUser: {user_input}\nBot:"
49
+ try:
50
+ response = model.generate_content(full_prompt)
51
+ return response.text
52
+ except Exception as e:
53
+ return f"Sorry, I encountered an issue. Error: {str(e)}"
54
+
55
+ # Streamlit UI
56
+ st.title("\U0001F9E0 Mental Health Chatbot")
57
+
58
+ # Chat history session
59
+ if "chat_history" not in st.session_state:
60
+ st.session_state.chat_history = []
61
+
62
+ # Display chat history
63
+ for role, message in st.session_state.chat_history:
64
+ with st.chat_message(role):
65
+ st.write(message)
66
+
67
+ # User input
68
+ user_input = st.chat_input("Ask me anything about mental health...")
69
+
70
+ if user_input:
71
+ st.chat_message("user").write(user_input)
72
+
73
+ # Analyze sentiment
74
+ emotion, scores = analyze_sentiment(user_input)
75
+
76
+ # Generate chatbot response with RAG
77
+ with st.spinner("Thinking..."):
78
+ bot_response = get_gemini_response(user_input)
79
+
80
+ st.chat_message("assistant").write(bot_response)
81
+
82
+ # Store chat history
83
+ st.session_state.chat_history.append(("user", user_input))
84
+ st.session_state.chat_history.append(("assistant", bot_response))
85
+ store_chat(user_input, bot_response)
86
+
87
+ # Sidebar: Clear Chat & Instructions
88
+ with st.sidebar:
89
+ if st.button("Clear Chat"):
90
+ st.session_state.chat_history = []
91
+ collection.delete(ids=collection.get()["ids"])
92
+ st.success("Chat cleared! Refreshing...")
93
+ st.rerun()
94
+
95
+
96
+
97
+
98
+
99
+ st.markdown("---")
100
+ st.markdown("*Note: This chatbot is for informational purposes only and should not replace professional mental health advice.*")
101
+
102
+ # Sentiment Analysis Button
103
+ if st.button("Sentiment Analysis"):
104
+ chat_text = " ".join([msg for _, msg in st.session_state.chat_history])
105
+ if chat_text:
106
+ detected_emotion, emotion_scores = analyze_sentiment(chat_text)
107
+ st.subheader("Sentiment Analysis Result")
108
+ st.write(f"Detected Emotion: **{detected_emotion}**")
109
+ st.json(emotion_scores)
110
+ else:
111
+ st.warning("No chat history found for analysis.")