ahmadgenus committed · ed2c8ad
Parent: 318f035

Switched to SQLite for Hugging Face deployment

chatbot.py: +53 −110
CHANGED
@@ -1,16 +1,6 @@
-
-
 from langchain.chains import LLMChain
-# chat_chain = LLMChain(
-#     llm=llm,
-#     prompt=chat_prompt,
-#     memory=memory,
-#     verbose=True  # Enable verbose logging for debugging
-# )
-
-
 import os
-import psycopg2
+import sqlite3
 import praw
 import json
 from datetime import datetime, timedelta
@@ -26,8 +16,7 @@ load_dotenv()
 # Initialize the LLM via LangChain (using Groq)
 llm = ChatGroq(
     groq_api_key=os.getenv("GROQ_API_KEY"),
-
-    model_name= "meta-llama/llama-4-maverick-17b-128e-instruct",
+    model_name="meta-llama/llama-4-maverick-17b-128e-instruct",
     temperature=0.2
 )
 
@@ -41,32 +30,27 @@ reddit = praw.Reddit(
     user_agent=os.getenv("REDDIT_USER_AGENT")
 )
 
-#
-import psycopg2
-import os
-
+# SQLite DB Connection
 def get_db_conn():
-    return psycopg2.connect(...)
-
-# Set up the database schema: store raw post text, comments, computed embedding, and metadata.
+    return sqlite3.connect("reddit_data.db", check_same_thread=False)
 
+# Set up the database schema
 def setup_db():
     conn = get_db_conn()
     cur = conn.cursor()
     try:
-        cur.execute("""
+        cur.execute("""
             CREATE TABLE IF NOT EXISTS reddit_posts (
-                id SERIAL PRIMARY KEY,
-                reddit_id VARCHAR(50) UNIQUE,
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                reddit_id TEXT UNIQUE,
                 keyword TEXT,
                 title TEXT,
                 post_text TEXT,
-                comments JSONB,
-                created_at TIMESTAMP,
-                embedding VECTOR(384),
-                metadata JSONB
+                comments TEXT,
+                created_at TEXT,
+                embedding TEXT,
+                metadata TEXT
             );
-            CREATE INDEX IF NOT EXISTS idx_keyword_created_at ON reddit_posts(keyword, created_at DESC);
         """)
         conn.commit()
     except Exception as e:
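SQLite has no JSONB or VECTOR column types, so the new schema stores comments, embeddings, and metadata as JSON-serialized TEXT. A minimal round-trip sketch (table and file names as in the diff; `json` is already imported in chatbot.py):

```python
import json
import sqlite3

# Read one stored row back and deserialize the TEXT columns.
conn = sqlite3.connect("reddit_data.db", check_same_thread=False)
cur = conn.cursor()
cur.execute("SELECT comments, embedding, metadata FROM reddit_posts LIMIT 1;")
row = cur.fetchone()
if row:
    comments = json.loads(row[0])   # list of comments, as saved by save_to_db
    embedding = json.loads(row[1])  # list of floats
    metadata = json.loads(row[2])   # dict
conn.close()
```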
@@ -74,29 +58,9 @@ def setup_db():
     finally:
         cur.close()
         conn.close()
-# def setup_db():
-#     conn = get_db_conn()
-#     cur = conn.cursor()
-#     cur.execute("""
-#         CREATE EXTENSION IF NOT EXISTS vector;
-#         CREATE TABLE IF NOT EXISTS reddit_posts (
-#             id SERIAL PRIMARY KEY,
-#             reddit_id VARCHAR(50) UNIQUE,
-#             keyword TEXT,
-#             title TEXT,
-#             post_text TEXT,
-#             comments JSONB,
-#             created_at TIMESTAMP,
-#             embedding VECTOR(384),
-#             metadata JSONB
-#         );
-#         CREATE INDEX IF NOT EXISTS idx_keyword_created_at ON reddit_posts(keyword, created_at DESC);
-#     """)
-#     conn.commit()
-#     cur.close()
-#     conn.close()
 
-#
+# Keyword filter
+
 def keyword_in_post_or_comments(post, keyword):
     keyword_lower = keyword.lower()
     combined_text = (post.title + " " + post.selftext).lower()
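keyword_in_post_or_comments is untouched by this commit, so most of its body is elided from the diff. A plausible sketch of the full function, assuming the hidden middle scans top-level comments via praw (the replace_more/list calls are an assumption, not shown in this commit):

```python
def keyword_in_post_or_comments(post, keyword):
    keyword_lower = keyword.lower()
    combined_text = (post.title + " " + post.selftext).lower()
    if keyword_lower in combined_text:
        return True
    # Assumed: the elided lines check the comments as well.
    post.comments.replace_more(limit=0)  # flatten "load more comments" stubs
    for comment in post.comments.list():
        if keyword_lower in comment.body.lower():
            return True
    return False
```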
@@ -108,19 +72,19 @@ def keyword_in_post_or_comments(post, keyword):
             return True
     return False
 
-# Fetch
-
+# Fetch and process Reddit data
+
 def fetch_reddit_data(keyword, days=7, limit=None):
     end_time = datetime.utcnow()
     start_time = end_time - timedelta(days=days)
     subreddit = reddit.subreddit("all")
     posts_generator = subreddit.search(keyword, sort="new", time_filter="all", limit=limit)
-
+
     data = []
     for post in posts_generator:
         created = datetime.utcfromtimestamp(post.created_utc)
         if created < start_time:
-            break
+            break
         if not keyword_in_post_or_comments(post, keyword):
             continue
 
@@ -139,81 +103,66 @@ def fetch_reddit_data(keyword, days=7, limit=None):
             "title": post.title,
             "post_text": post.selftext,
             "comments": comments,
-            "created_at": created,
+            "created_at": created.isoformat(),
             "embedding": embedding,
             "metadata": metadata
         })
     save_to_db(data)
 
-# Save
+# Save data into SQLite
+
 def save_to_db(posts):
     conn = get_db_conn()
     cur = conn.cursor()
     for post in posts:
-        [15 removed lines not preserved on this page: the old psycopg2 INSERT statement]
+        try:
+            cur.execute("""
+                INSERT OR IGNORE INTO reddit_posts
+                (reddit_id, keyword, title, post_text, comments, created_at, embedding, metadata)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?);
+            """, (
+                post["reddit_id"],
+                post["keyword"],
+                post["title"],
+                post["post_text"],
+                json.dumps(post["comments"]),
+                post["created_at"],
+                json.dumps(post["embedding"]),
+                json.dumps(post["metadata"])
+            ))
+        except Exception as e:
+            print("Insert Error:", e)
     conn.commit()
     cur.close()
     conn.close()
 
-# Retrieve context from
-
+# Retrieve similar context from DB
+
 def retrieve_context(question, keyword, reddit_id=None, top_k=10):
     lower_q = question.lower()
-
-    if any(word in lower_q for word in ["summarize", "overview", "all posts"]):
-        requested_top_k = 50
-    else:
-        requested_top_k = top_k
+    requested_top_k = 50 if any(word in lower_q for word in ["summarize", "overview", "all posts"]) else top_k
 
-    # Retrieve posts based on query embedding.
-    query_embedding = embedder.encode(question).tolist()
-    query_embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
-
     conn = get_db_conn()
     cur = conn.cursor()
+
     if reddit_id:
         cur.execute("""
             SELECT title, post_text, comments FROM reddit_posts
-            WHERE reddit_id = %s;
+            WHERE reddit_id = ?;
         """, (reddit_id,))
     else:
         cur.execute("""
             SELECT title, post_text, comments FROM reddit_posts
-            WHERE keyword = %s
-        [2 removed lines not preserved: the vector-similarity ORDER BY/LIMIT and query parameters]
+            WHERE keyword = ? ORDER BY created_at DESC LIMIT ?;
+        """, (keyword, requested_top_k))
+
     results = cur.fetchall()
+    cur.close()
     conn.close()
-
-    # If there are fewer posts than requested and none were retrieved by vector search,
-    # fall back to retrieving all posts for that keyword.
-    if not results:
-        conn = get_db_conn()
-        cur = conn.cursor()
-        cur.execute("""
-            SELECT title, post_text, comments FROM reddit_posts
-            WHERE keyword = %s ORDER BY created_at DESC;
-        """, (keyword,))
-        results = cur.fetchall()
-        conn.close()
     return results
 
-#
-
+# Summarizer
+
 summarize_prompt = ChatPromptTemplate.from_template("""
 You are a summarizer. Summarize the following context from Reddit posts into a concise summary that preserves the key insights. Do not add extra commentary.
 
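Note what this hunk trades away: the old Postgres path ranked posts by embedding similarity (pgvector), while the SQLite version only orders by recency. Since embeddings are still stored as JSON text, similarity ranking could be recovered client-side; a hypothetical helper, assuming `embedder` is the sentence-transformers model the old code used:

```python
import json
import numpy as np

def retrieve_context_by_similarity(question, keyword, top_k=10):
    # Hypothetical: rank stored posts by cosine similarity in Python
    # instead of in SQL, since SQLite has no vector operators.
    query_vec = np.asarray(embedder.encode(question), dtype=float)
    conn = get_db_conn()
    cur = conn.cursor()
    cur.execute(
        "SELECT title, post_text, comments, embedding FROM reddit_posts WHERE keyword = ?;",
        (keyword,),
    )
    rows = cur.fetchall()
    conn.close()

    def cosine(stored):
        v = np.asarray(json.loads(stored), dtype=float)
        return float(np.dot(query_vec, v) /
                     (np.linalg.norm(query_vec) * np.linalg.norm(v) + 1e-9))

    rows.sort(key=lambda r: cosine(r[3]), reverse=True)
    return [r[:3] for r in rows[:top_k]]
```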
@@ -224,10 +173,9 @@ Summary:
 """)
 summarize_chain = LLMChain(llm=llm, prompt=summarize_prompt)
 
+# Chatbot memory and prompt
 
-# Set up conversation memory and chain.
 memory = ConversationBufferMemory(memory_key="chat_history")
-# Updated prompt: we now expect a single input field "input"
 chat_prompt = ChatPromptTemplate.from_template("""
 Chat History:
 {chat_history}
@@ -236,27 +184,22 @@ Context from Reddit and User Question:
 {input}
 
 Act as a professional assistant and incremental chat agent. Give reasoning, and answer clearly based on the context and chat history; your response should be valid, concise, and relevant.
-
 """)
 
 chat_chain = LLMChain(
     llm=llm,
     prompt=chat_prompt,
     memory=memory,
-    verbose=True
+    verbose=True
 )
 
-#
-# Updated get_chatbot_response to handle summarization if context is too long.
+# Chatbot response
 
 def get_chatbot_response(question, keyword, reddit_id=None):
     context_posts = retrieve_context(question, keyword, reddit_id)
     context = "\n\n".join([f"{p[0]}:\n{p[1]}" for p in context_posts])
-
-    # Set a threshold (e.g., 3000 characters); if context length exceeds it, compress the context.
     if len(context) > 3000:
         context = summarize_chain.run({"context": context})
-
    combined_input = f"Context:\n{context}\n\nUser Question: {question}"
     response = chat_chain.run({"input": combined_input})
-    return response, context_posts
+    return response, context_posts
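For reference, a minimal end-to-end run of the updated module might look like the following (the keyword and question are illustrative; assumes GROQ_API_KEY and the REDDIT_* credentials are set in the environment):

```python
# Hypothetical driver for the new SQLite-backed chatbot.py
from chatbot import setup_db, fetch_reddit_data, get_chatbot_response

setup_db()                                         # creates reddit_data.db and the table
fetch_reddit_data("langchain", days=7, limit=100)  # fetch, embed, and store posts
answer, posts = get_chatbot_response("Summarize this week's discussion.", "langchain")
print(answer)
```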