import os
import sqlite3
import praw
import json
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory

load_dotenv()
# Initialize the LLM via LangChain (using Groq)
llm = ChatGroq(
    groq_api_key=os.getenv("GROQ_API_KEY"),
    model_name="meta-llama/llama-4-maverick-17b-128e-instruct",
    temperature=0.2
)

# Embedding model
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Reddit API setup
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
    user_agent=os.getenv("REDDIT_USER_AGENT")
)

# SQLite DB connection
def get_db_conn():
    return sqlite3.connect("reddit_data.db", check_same_thread=False)
# Set up the database schema
def setup_db():
    conn = get_db_conn()
    cur = conn.cursor()
    try:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS reddit_posts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                reddit_id TEXT UNIQUE,
                keyword TEXT,
                title TEXT,
                post_text TEXT,
                comments TEXT,
                created_at TEXT,
                embedding TEXT,
                metadata TEXT
            );
        """)
        conn.commit()
    except Exception as e:
        print("DB Setup Error:", e)
    finally:
        cur.close()
        conn.close()
# Keyword filter: check the post title, body, and all comments
def keyword_in_post_or_comments(post, keyword):
    keyword_lower = keyword.lower()
    combined_text = (post.title + " " + post.selftext).lower()
    if keyword_lower in combined_text:
        return True
    post.comments.replace_more(limit=None)
    for comment in post.comments.list():
        if keyword_lower in comment.body.lower():
            return True
    return False
# Fetch and process Reddit data: search, filter, embed, and store posts
def fetch_reddit_data(keyword, days=7, limit=None):
    end_time = datetime.utcnow()
    start_time = end_time - timedelta(days=days)
    subreddit = reddit.subreddit("all")
    # Results are sorted newest-first, so we can stop once a post falls outside the window
    posts_generator = subreddit.search(keyword, sort="new", time_filter="all", limit=limit)
    data = []
    for post in posts_generator:
        created = datetime.utcfromtimestamp(post.created_utc)
        if created < start_time:
            break
        if not keyword_in_post_or_comments(post, keyword):
            continue
        post.comments.replace_more(limit=None)
        comments = [comment.body for comment in post.comments.list()]
        combined_text = f"{post.title}\n{post.selftext}\n{' '.join(comments)}"
        embedding = embedder.encode(combined_text).tolist()
        metadata = {
            "url": post.url,
            "subreddit": post.subreddit.display_name,
            "comments_count": len(comments)
        }
        data.append({
            "reddit_id": post.id,
            "keyword": keyword,
            "title": post.title,
            "post_text": post.selftext,
            "comments": comments,
            "created_at": created.isoformat(),
            "embedding": embedding,
            "metadata": metadata
        })
    save_to_db(data)
# Save data into SQLite (duplicates are skipped via the UNIQUE reddit_id constraint)
def save_to_db(posts):
    conn = get_db_conn()
    cur = conn.cursor()
    for post in posts:
        try:
            cur.execute("""
                INSERT OR IGNORE INTO reddit_posts
                (reddit_id, keyword, title, post_text, comments, created_at, embedding, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?);
            """, (
                post["reddit_id"],
                post["keyword"],
                post["title"],
                post["post_text"],
                json.dumps(post["comments"]),
                post["created_at"],
                json.dumps(post["embedding"]),
                json.dumps(post["metadata"])
            ))
        except Exception as e:
            print("Insert Error:", e)
    conn.commit()
    cur.close()
    conn.close()
# Retrieve relevant context from the DB
def retrieve_context(question, keyword, reddit_id=None, top_k=10):
    lower_q = question.lower()
    # Pull more rows for broad questions such as summaries or overviews
    requested_top_k = 50 if any(word in lower_q for word in ["summarize", "overview", "all posts"]) else top_k
    conn = get_db_conn()
    cur = conn.cursor()
    if reddit_id:
        cur.execute("""
            SELECT title, post_text, comments FROM reddit_posts
            WHERE reddit_id = ?;
        """, (reddit_id,))
    else:
        cur.execute("""
            SELECT title, post_text, comments FROM reddit_posts
            WHERE keyword = ? ORDER BY datetime(created_at) DESC LIMIT ?;
        """, (keyword, requested_top_k))
    results = cur.fetchall()
    cur.close()
    conn.close()
    return results
# Summarizer chain for long contexts
summarize_prompt = ChatPromptTemplate.from_template("""
You are a summarizer. Summarize the following context from Reddit posts into a concise summary that preserves the key insights. Do not add extra commentary.

Context:
{context}

Summary:
""")
summarize_chain = LLMChain(llm=llm, prompt=summarize_prompt)
# Chatbot memory and prompt
memory = ConversationBufferMemory(memory_key="chat_history")
chat_prompt = ChatPromptTemplate.from_template("""
Chat History:
{chat_history}

Context from Reddit and User Question:
{input}

Act as a professional assistant in an incremental chat. Give your reasoning and answer clearly based on the context and chat history. Keep the response accurate, concise, and relevant.
""")
chat_chain = LLMChain(
    llm=llm,
    prompt=chat_prompt,
    memory=memory,
    verbose=True
)
# Chatbot response grounded in the retrieved Reddit context
def get_chatbot_response(question, keyword, reddit_id=None):
    context_posts = retrieve_context(question, keyword, reddit_id)
    context = "\n\n".join([f"{p[0]}:\n{p[1]}" for p in context_posts])
    # Summarize long contexts before handing them to the chat chain
    if len(context) > 3000:
        context = summarize_chain.run({"context": context})
    combined_input = f"Context:\n{context}\n\nUser Question: {question}"
    response = chat_chain.run({"input": combined_input})
    return response, context_posts
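
# Minimal usage sketch (illustrative only, not part of the original app):
# the keyword "langchain" and the question below are hypothetical examples,
# and the required API credentials are assumed to be set in the environment.
if __name__ == "__main__":
    setup_db()
    fetch_reddit_data("langchain", days=7, limit=50)
    answer, sources = get_chatbot_response("What are people saying about LangChain?", "langchain")
    print(answer)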