# legal_rag/main.py
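"""Build and query a local retrieval-augmented index of legal documents.

The script extracts text from local PDFs and crawled web pages, chunks it
with smart_chunk_text, embeds the chunks with an E5 sentence-transformer,
stores the vectors in a FAISS index, and answers questions by passing the
retrieved chunks to ask_llm_with_context (a LLaMA 3 backend, judging by the
answer banner printed in the main loop below).
"""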
import os
import requests
import pdfplumber
import trafilatura
from bs4 import BeautifulSoup
from urllib.parse import urljoin, urlparse
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pickle
import argparse
# from ollama_initial import start_ollama_model
from llama_query import ask_llm_with_context
from smart_chunk import smart_chunk_text # for semantic-aware chunking
# === Config ===
INDEX_FILE = "legal_index.faiss"
DOCS_FILE = "legal_chunks.pkl"
PDF_CACHE_FILE = "processed_pdfs.pkl"
URL_CACHE_FILE = "processed_urls.pkl"
EMBEDDING_MODEL = "intfloat/e5-base-v2"
ALLOWED_DOMAINS = ["gov", "org", "ca"]  # top-level domains the crawler is allowed to follow
PDF_FOLDER = "pdf"
URL_FILE = "urls.txt"
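# Expected local inputs (not created by this script): a "pdf" folder containing
# the source PDFs and a "urls.txt" file with one seed URL per line (blank lines
# are ignored by load_urls below).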
# === CLI args ===
parser = argparse.ArgumentParser()
parser.add_argument("--update", action="store_true", help="Update only new PDFs/URLs (uses cache)")
parser.add_argument("--updateall", action="store_true", help="Force complete reindexing of all documents (ignores cache)")
args = parser.parse_args()
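# Typical invocations (illustrative):
#   python main.py              # reuse the saved index if present
#   python main.py --update     # index only PDFs/URLs not seen before (uses cache)
#   python main.py --updateall  # rebuild the index from scratch, ignoring caches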
# === Embedding setup ===
model = SentenceTransformer(EMBEDDING_MODEL)
vector_index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())
documents = []
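# IndexFlatL2 performs exact L2 search over vectors sized to the model's output
# dimension; `documents` holds the chunk texts in the same insertion order as
# their vectors, so FAISS result id i maps back to documents[i].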
# === Cache handling ===
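# The caches are plain pickled sets of PDF filenames / crawled URLs, used to
# skip documents that have already been indexed in --update mode.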
def load_cache(file):
    if os.path.exists(file):
        with open(file, "rb") as f:
            return pickle.load(f)
    return set()


def save_cache(data, file):
    with open(file, "wb") as f:
        pickle.dump(data, f)
# === Index persistence ===
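# The FAISS index and the parallel list of chunk texts are saved and loaded as
# a pair; they must stay in sync for query_index to map result ids to text.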
def save_index():
    faiss.write_index(vector_index, INDEX_FILE)
    with open(DOCS_FILE, "wb") as f:
        pickle.dump(documents, f)
    print("✅ Vector index and chunks saved.")


def load_index():
    global vector_index, documents
    if os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE):
        print("📂 Found existing FAISS index and document chunks...")
        vector_index = faiss.read_index(INDEX_FILE)
        with open(DOCS_FILE, "rb") as f:
            documents = pickle.load(f)
        print(f"✅ Loaded {vector_index.ntotal} vectors and {len(documents)} text chunks.")
        return True
    else:
        print("❌ FAISS or document file not found.")
        return False
# === Chunk + embed ===
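# smart_chunk_text (the project's semantic-aware chunker, per the import note
# above) splits long text into chunks of at most ~128 tokens before embedding.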
def store_text_chunks(text):
    chunks = smart_chunk_text(text, max_tokens=128)
    chunks = [chunk.strip() for chunk in chunks if chunk.strip()]
    if not chunks:
        return
    # E5 models expect a "passage: " prefix on indexed text (queries use "query: ").
    vectors = model.encode([f"passage: {chunk}" for chunk in chunks], batch_size=16, show_progress_bar=True)
    vector_index.add(np.array(vectors))
    documents.extend(chunks)
# === Text extraction ===
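# Three extraction paths: local PDFs via pdfplumber, remote PDFs downloaded to
# a temporary file, and HTML pages stripped to their main text with trafilatura.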
def get_text_from_pdf_file(filepath):
    try:
        with pdfplumber.open(filepath) as pdf:
            return "\n".join(page.extract_text() or '' for page in pdf.pages)
    except Exception as e:
        print(f"[!] Failed to read PDF: {filepath} - {e}")
        return ""


def get_text_from_pdf_url(url):
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()  # fail early instead of saving an error page as a PDF
        filename = "temp.pdf"
        with open(filename, "wb") as f:
            f.write(response.content)
        text = get_text_from_pdf_file(filename)
        os.remove(filename)
        return text
    except Exception as e:
        print(f"[!] Failed to fetch PDF from URL: {url} - {e}")
        return ""


def get_text_from_html(url):
    try:
        html = requests.get(url, timeout=30).text
        return trafilatura.extract(html, include_comments=False, include_tables=False) or ""
    except Exception as e:
        print(f"[!] Failed HTML: {url} - {e}")
        return ""
def is_valid_link(link, base_url):
    full_url = urljoin(base_url, link)
    parsed = urlparse(full_url)
    # Only follow http(s) links whose host's top-level domain is in the allow list.
    host = parsed.netloc.split(":")[0]
    return parsed.scheme.startswith("http") and host.rsplit(".", 1)[-1] in ALLOWED_DOMAINS
# === Processing ===
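# Indexing walks the local PDF folder and then does a shallow crawl of each
# seed URL (depth=1 in the main flow), feeding every extracted text through
# store_text_chunks.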
def process_pdf_folder(folder_path=PDF_FOLDER, processed_files=None):
    if processed_files is None:
        processed_files = set()
    for filename in os.listdir(folder_path):
        if not filename.lower().endswith(".pdf"):
            continue
        if filename in processed_files:
            print(f"✅ Skipping already processed PDF: {filename}")
            continue
        full_path = os.path.join(folder_path, filename)
        print(f"📄 Reading new PDF: {full_path}")
        text = get_text_from_pdf_file(full_path)
        store_text_chunks(text)
        processed_files.add(filename)
def crawl_url(url, depth=1, processed_urls=None):
    if processed_urls is None:
        processed_urls = set()
    if url in processed_urls:
        print(f"✅ Skipping already crawled URL: {url}")
        return
    print(f"🔗 Crawling: {url}")
    visited = set()
    to_visit = [url]
    # `depth` acts as a page budget: each successfully fetched page decrements it.
    while to_visit and depth > 0:
        current = to_visit.pop()
        visited.add(current)
        if current.endswith(".pdf"):
            text = get_text_from_pdf_url(current)
        else:
            text = get_text_from_html(current)
        store_text_chunks(text)
        processed_urls.add(current)
        try:
            page = requests.get(current, timeout=30).text
            soup = BeautifulSoup(page, "html.parser")
            for a in soup.find_all("a", href=True):
                href = a["href"]
                full_url = urljoin(current, href)
                if full_url not in visited and is_valid_link(href, current):
                    to_visit.append(full_url)
        except Exception:
            continue
        depth -= 1
# === URL list ===
def load_urls(file_path=URL_FILE):
    with open(file_path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


# === Retrieval ===
def query_index(question, top_k=5):
    if not documents:
        return "No documents found in the index."
    query = f"query: {question}"  # E5 models expect the "query: " prefix at search time
    q_vector = model.encode(query)
    D, I = vector_index.search(np.array([q_vector]), top_k)
    # FAISS pads results with -1 when fewer than top_k vectors exist; skip those slots.
    return "\n---\n".join(documents[i] for i in I[0] if 0 <= i < len(documents))
# === Main Execution ===
if __name__ == "__main__":
    print("🚀 Starting BC Land Survey Legal Assistant")

    # Default behavior: load existing index
    update_mode = "none"  # can be "none", "update", or "updateall"
    if args.updateall:
        update_mode = "updateall"
    elif args.update:
        update_mode = "update"

    # Load caches for local PDF and URL tracking
    processed_pdfs = load_cache(PDF_CACHE_FILE)
    processed_urls = load_cache(URL_CACHE_FILE)
    if update_mode == "updateall":
        print("🔁 Rebuilding index from scratch...")
        processed_pdfs = set()
        processed_urls = set()

    index_loaded = load_index()
    if update_mode in ("update", "updateall") or not index_loaded:
        if not index_loaded:
            print("⚠️ Index not found - will rebuild from source.")
        print("🔄 Indexing content...")
        process_pdf_folder(processed_files=processed_pdfs)
        for url in load_urls():
            crawl_url(url, depth=1, processed_urls=processed_urls)
        save_index()
        save_cache(processed_pdfs, PDF_CACHE_FILE)
        save_cache(processed_urls, URL_CACHE_FILE)
    else:
        print(f"✅ Loaded FAISS index with {vector_index.ntotal} vectors.")
        print(f"✅ Loaded {len(documents)} legal chunks.")

    print("\n❓ Ready to query your legal database (type 'exit' to quit)")
    while True:
        question = input("\n🔎 Your question: ")
        if question.strip().lower() in ["exit", "quit", "q"]:
            print("👋 Exiting. See you next time!")
            break
        context = query_index(question)
        answer = ask_llm_with_context(question, context)
        print("\n🧠 LLaMA 3 Answer:")
        print(answer)
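# === Programmatic entry point ===
# initialize_index mirrors the CLI flow above so the module can be imported by
# another app (for example a web UI). A minimal sketch of that usage, assuming
# this file is importable as `main`:
#
#   from main import initialize_index, query_index
#   initialize_index(update_mode="update")   # or "none" / "updateall"
#   context = query_index("Who may conduct a legal survey in BC?")
#
# The example question is illustrative only.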
def initialize_index(update_mode="none"):
    global documents, vector_index
    processed_pdfs = load_cache(PDF_CACHE_FILE)
    processed_urls = load_cache(URL_CACHE_FILE)
    if update_mode == "updateall":
        processed_pdfs = set()
        processed_urls = set()

    index_loaded = load_index()
    if update_mode in ("update", "updateall") or not index_loaded:
        process_pdf_folder(processed_files=processed_pdfs)
        for url in load_urls():
            crawl_url(url, depth=1, processed_urls=processed_urls)
        save_index()
        save_cache(processed_pdfs, PDF_CACHE_FILE)
        save_cache(processed_urls, URL_CACHE_FILE)
    else:
        print(f"✅ FAISS index with {vector_index.ntotal} vectors loaded.")
        print(f"✅ Loaded {len(documents)} legal document chunks.")
# This version includes all 3 enhancements:
# - Smart chunking via smart_chunk.py
# - High-quality embedding model (E5)
# - Structured prompt with legal assistant context and disclaimer