ahmadgenus committed on
Commit ed2c8ad · 1 Parent(s): 318f035

Switched to SQLite for Hugging Face deployment
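In short, the commit replaces the psycopg2/pgvector database layer with a single local SQLite file, so the Hugging Face deployment no longer needs a DATABASE_URL; comments, embeddings, and metadata are now stored as JSON text, and retrieval orders rows by keyword and recency instead of pgvector similarity. The core of the connection change, condensed from the diff below (a sketch, not additional code in the commit):

    # before: PostgreSQL via psycopg2, configured through the DATABASE_URL env var
    # return psycopg2.connect(os.getenv("DATABASE_URL"))
    # after: a local SQLite file; check_same_thread=False allows the connection
    # to be used from threads other than the one that created it
    return sqlite3.connect("reddit_data.db", check_same_thread=False)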

Files changed (1)
  1. chatbot.py +53 -110
chatbot.py CHANGED
@@ -1,16 +1,6 @@
-
-
 from langchain.chains import LLMChain
-# chat_chain = LLMChain(
-#     llm=llm,
-#     prompt=chat_prompt,
-#     memory=memory,
-#     verbose=True  # Enable verbose logging for debugging
-# )
-
-
 import os
-import psycopg2
+import sqlite3
 import praw
 import json
 from datetime import datetime, timedelta
@@ -26,8 +16,7 @@ load_dotenv()
 # Initialize the LLM via LangChain (using Groq)
 llm = ChatGroq(
     groq_api_key=os.getenv("GROQ_API_KEY"),
-    # model_name=os.getenv("MODEL_NAME"),
-    model_name= "meta-llama/llama-4-maverick-17b-128e-instruct",
+    model_name="meta-llama/llama-4-maverick-17b-128e-instruct",
     temperature=0.2
 )
 
@@ -41,32 +30,27 @@ reddit = praw.Reddit(
     user_agent=os.getenv("REDDIT_USER_AGENT")
 )
 
-# Database connection function
-import psycopg2
-import os
-
+# SQLite DB Connection
 def get_db_conn():
-    return psycopg2.connect(os.getenv("DATABASE_URL"))
-
-# Set up the database schema: store raw post text, comments, computed embedding, and metadata.
+    return sqlite3.connect("reddit_data.db", check_same_thread=False)
 
+# Set up the database schema
 def setup_db():
     conn = get_db_conn()
     cur = conn.cursor()
     try:
-        cur.execute("""  -- remove EXTENSION line
+        cur.execute("""
             CREATE TABLE IF NOT EXISTS reddit_posts (
-                id SERIAL PRIMARY KEY,
-                reddit_id VARCHAR(50) UNIQUE,
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                reddit_id TEXT UNIQUE,
                 keyword TEXT,
                 title TEXT,
                 post_text TEXT,
-                comments JSONB,
-                created_at TIMESTAMP,
-                embedding VECTOR(384),
-                metadata JSONB
+                comments TEXT,
+                created_at TEXT,
+                embedding TEXT,
+                metadata TEXT
             );
-            CREATE INDEX IF NOT EXISTS idx_keyword_created_at ON reddit_posts(keyword, created_at DESC);
         """)
         conn.commit()
     except Exception as e:
@@ -74,29 +58,9 @@ def setup_db():
     finally:
         cur.close()
         conn.close()
-# def setup_db():
-#     conn = get_db_conn()
-#     cur = conn.cursor()
-#     cur.execute("""
-#         CREATE EXTENSION IF NOT EXISTS vector;
-#         CREATE TABLE IF NOT EXISTS reddit_posts (
-#             id SERIAL PRIMARY KEY,
-#             reddit_id VARCHAR(50) UNIQUE,
-#             keyword TEXT,
-#             title TEXT,
-#             post_text TEXT,
-#             comments JSONB,
-#             created_at TIMESTAMP,
-#             embedding VECTOR(384),
-#             metadata JSONB
-#         );
-#         CREATE INDEX IF NOT EXISTS idx_keyword_created_at ON reddit_posts(keyword, created_at DESC);
-#     """)
-#     conn.commit()
-#     cur.close()
-#     conn.close()
 
-# Utility: Check if the keyword appears in the post title, selftext, or any comment.
+# Keyword filter
+
 def keyword_in_post_or_comments(post, keyword):
     keyword_lower = keyword.lower()
     combined_text = (post.title + " " + post.selftext).lower()
@@ -108,19 +72,19 @@ def keyword_in_post_or_comments(post, keyword):
             return True
     return False
 
-# Fetch Reddit posts if the keyword is in the post or any comment.
-# This version iterates over posts until reaching posts older than the specified day range.
+# Fetch and process Reddit data
+
 def fetch_reddit_data(keyword, days=7, limit=None):
     end_time = datetime.utcnow()
     start_time = end_time - timedelta(days=days)
     subreddit = reddit.subreddit("all")
     posts_generator = subreddit.search(keyword, sort="new", time_filter="all", limit=limit)
-
+
     data = []
     for post in posts_generator:
         created = datetime.utcfromtimestamp(post.created_utc)
         if created < start_time:
-            break  # Since sorted by new, we break once older posts are encountered.
+            break
         if not keyword_in_post_or_comments(post, keyword):
             continue
 
@@ -139,81 +103,66 @@ def fetch_reddit_data(keyword, days=7, limit=None):
             "title": post.title,
             "post_text": post.selftext,
             "comments": comments,
-            "created_at": created,
+            "created_at": created.isoformat(),
             "embedding": embedding,
             "metadata": metadata
         })
     save_to_db(data)
 
-# Save posts data into PostgreSQL.
+# Save data into SQLite
+
 def save_to_db(posts):
     conn = get_db_conn()
     cur = conn.cursor()
     for post in posts:
-        cur.execute("""
-            INSERT INTO reddit_posts
-            (reddit_id, keyword, title, post_text, comments, created_at, embedding, metadata)
-            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
-            ON CONFLICT DO NOTHING;
-        """, (
-            post["reddit_id"],
-            post["keyword"],
-            post["title"],
-            post["post_text"],
-            json.dumps(post["comments"]),
-            post["created_at"],
-            post["embedding"],
-            json.dumps(post["metadata"])
-        ))
+        try:
+            cur.execute("""
+                INSERT OR IGNORE INTO reddit_posts
+                (reddit_id, keyword, title, post_text, comments, created_at, embedding, metadata)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?);
+            """, (
+                post["reddit_id"],
+                post["keyword"],
+                post["title"],
+                post["post_text"],
+                json.dumps(post["comments"]),
+                post["created_at"],
+                json.dumps(post["embedding"]),
+                json.dumps(post["metadata"])
+            ))
+        except Exception as e:
+            print("Insert Error:", e)
     conn.commit()
     cur.close()
     conn.close()
 
-# Retrieve context from the DB.
-# Updated retrieval: if summarization intent is detected, retrieve more posts.
+# Retrieve similar context from DB
+
 def retrieve_context(question, keyword, reddit_id=None, top_k=10):
     lower_q = question.lower()
-    # Check for summarization intent.
-    if any(word in lower_q for word in ["summarize", "overview", "all posts"]):
-        requested_top_k = 50
-    else:
-        requested_top_k = top_k
+    requested_top_k = 50 if any(word in lower_q for word in ["summarize", "overview", "all posts"]) else top_k
 
-    # Retrieve posts based on query embedding.
-    query_embedding = embedder.encode(question).tolist()
-    query_embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
-
     conn = get_db_conn()
     cur = conn.cursor()
+
     if reddit_id:
         cur.execute("""
             SELECT title, post_text, comments FROM reddit_posts
-            WHERE reddit_id = %s;
+            WHERE reddit_id = ?;
         """, (reddit_id,))
     else:
         cur.execute("""
             SELECT title, post_text, comments FROM reddit_posts
-            WHERE keyword = %s
-            ORDER BY embedding <-> %s::vector LIMIT %s;
-        """, (keyword, query_embedding_str, requested_top_k))
+            WHERE keyword = ? ORDER BY created_at DESC LIMIT ?;
+        """, (keyword, requested_top_k))
+
     results = cur.fetchall()
+    cur.close()
     conn.close()
-
-    # If there are fewer posts than requested and none were retrieved by vector search,
-    # fall back to retrieving all posts for that keyword.
-    if not results:
-        conn = get_db_conn()
-        cur = conn.cursor()
-        cur.execute("""
-            SELECT title, post_text, comments FROM reddit_posts
-            WHERE keyword = %s ORDER BY created_at DESC;
-        """, (keyword,))
-        results = cur.fetchall()
-        conn.close()
     return results
 
-# --- New Summarization Step for Handling Long Context ---
-# Create a summarization chain to compress the context if it exceeds a token/character threshold.
+# Summarizer
+
 summarize_prompt = ChatPromptTemplate.from_template("""
 You are a summarizer. Summarize the following context from Reddit posts into a concise summary that preserves the key insights. Do not add extra commentary.
 
@@ -224,10 +173,9 @@ Summary:
 """)
 summarize_chain = LLMChain(llm=llm, prompt=summarize_prompt)
 
+# Chatbot memory and prompt
 
-# Set up conversation memory and chain.
 memory = ConversationBufferMemory(memory_key="chat_history")
-# Updated prompt: we now expect a single input field "input"
 chat_prompt = ChatPromptTemplate.from_template("""
 Chat History:
 {chat_history}
@@ -236,27 +184,22 @@ Context from Reddit and User Question:
 {input}
 
 Act as an Professional Assistant as incremental chat agent and also give reasioning and Answer clearly based on context and chat history, your response should be valid and concise, and relavant .
-
 """)
 
 chat_chain = LLMChain(
     llm=llm,
     prompt=chat_prompt,
     memory=memory,
-    verbose=True  # Enable verbose logging for debugging
+    verbose=True
 )
 
-# Get chatbot response by merging context and question into a single input.
-# Updated get_chatbot_response to handle summarization if context is too long.
+# Chatbot response
 
 def get_chatbot_response(question, keyword, reddit_id=None):
    context_posts = retrieve_context(question, keyword, reddit_id)
    context = "\n\n".join([f"{p[0]}:\n{p[1]}" for p in context_posts])
-
-    # Set a threshold (e.g., 3000 characters); if context length exceeds it, compress the context.
    if len(context) > 3000:
        context = summarize_chain.run({"context": context})
-
    combined_input = f"Context:\n{context}\n\nUser Question: {question}"
    response = chat_chain.run({"input": combined_input})
-    return response, context_posts
+    return response, context_posts
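For anyone reviewing the change, a minimal smoke test of the new SQLite layer might look like the sketch below. It is not part of this commit: the module name chatbot, the sample row values, and the assumption that GROQ_API_KEY and the Reddit credentials are set (chatbot.py reads them at import time) are all illustrative.

    # smoke_test.py -- illustrative sketch only, not included in the commit
    import json
    import chatbot

    chatbot.setup_db()  # creates reddit_data.db and the reddit_posts table if missing

    # Insert one made-up row the same way save_to_db() does.
    conn = chatbot.get_db_conn()
    cur = conn.cursor()
    cur.execute("""
        INSERT OR IGNORE INTO reddit_posts
        (reddit_id, keyword, title, post_text, comments, created_at, embedding, metadata)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?);
    """, (
        "abc123", "langchain", "Example title", "Example body",
        json.dumps(["first comment"]), "2024-01-01T00:00:00",
        json.dumps([0.0] * 384), json.dumps({"num_comments": 1}),
    ))
    conn.commit()
    cur.close()
    conn.close()

    # Retrieval now filters by keyword and orders by created_at instead of pgvector distance.
    rows = chatbot.retrieve_context("overview of recent posts", keyword="langchain")
    print(len(rows), "rows retrieved")

Because embeddings are now stored as JSON text rather than a pgvector column, any similarity ranking would have to be done in Python after fetching rows; the committed retrieve_context() relies on keyword match and recency only.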