# Author: PhoenixDecim
# Commit 18d1c8f: Added penalty for reasoning and future prediction questions
"""SLM with RAG for financial statements"""
# Importing the dependencies
import logging
import os
import pickle
import re
import subprocess
import sys
import time
from functools import lru_cache

import numpy as np
import pandas as pd
import torch
import spacy
import pdfplumber
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import faiss
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder

from data_filters import (
    restricted_patterns,
    restricted_topics,
    FINANCIAL_DATA_PATTERNS,
    FINANCIAL_ENTITY_LABELS,
    GENERAL_KNOWLEDGE_PATTERNS,
    sensitive_terms,
    EXPLANATORY_PATTERNS,
    FINANCIAL_TERMS,
)
# Configure the root logger once at import time: INFO level with timestamped
# lines. With no filename/stream given, basicConfig logs to stderr; uncomment
# `filename` below to log to a file instead.
logging.basicConfig(
    # filename="app.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# getLogger() with no name returns the root logger configured above
logger = logging.getLogger()
# Directory holding the persisted FAISS index and BM25 pickle
# (data/faiss_index.bin and data/bm25_data.pkl)
os.makedirs("data", exist_ok=True)
# SLM: Microsoft Phi-2 is loaded.
# It has higher memory and compute requirements than TinyLlama and Falcon,
# but it gave the best results among the three.
DEVICE = "cpu"  # or "cuda"
# DEVICE = "cuda" # or cuda
# MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
# MODEL_NAME = "tiiuae/falcon-rw-1b"
MODEL_NAME = "microsoft/phi-2"
# MODEL_NAME = "google/gemma-3-1b-pt"
# Load the tokenizer that matches MODEL_NAME
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Context window size as reported by the tokenizer.
# NOTE(review): some tokenizers report a huge placeholder value when no limit is
# configured; confirm this is the real window before swapping MODEL_NAME.
MAX_TOKENS = tokenizer.model_max_length
CONTEXT_MULTIPLIER = 0.7
# MAX_CONTEXT_TOKENS limits how many tokens of retrieved chunks are placed in
# the prompt, leaving headroom for the rest of the prompt (instructions, query)
MAX_CONTEXT_TOKENS = int(MAX_TOKENS * CONTEXT_MULTIPLIER)
# Reuse EOS as the pad token when the tokenizer defines none, so that
# tokenizing the prompt with padding=True works
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# float32 because the model is hosted on a CPU instance;
# on GPU, float16 or bfloat16 can be used instead
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, trust_remote_code=True
).to(DEVICE)
# Inference only: switch off training-mode behaviour such as dropout
model.eval()
# Optional (disabled): dynamic int8 quantisation of Linear layers
# model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
logger.info("Model loaded successfully.")
# Sentence embedder (FAISS retrieval, chunk merging, guardrail similarity) and
# cross-encoder (re-ranking of retrieved chunks)
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
# spaCy English model for Named Entity Recognition (mainly for the guardrail).
# If it is missing, download it with the *running* interpreter so it lands in
# the same environment this process imports from ("python" on PATH may be a
# different interpreter, or absent). check=True surfaces a failed download
# immediately instead of a confusing second OSError from spacy.load.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True
    )
    nlp = spacy.load("en_core_web_sm")
# Extract the report year from the uploaded file's name, if any
def extract_year_from_filename(filename):
    """Extract the report year from a file name or path.

    Only the base name is inspected, so digits in directory components (for
    example upload/temp folders) are never mistaken for a year. A fiscal-year
    range such as ``2023-2024`` yields its first year; otherwise the first
    standalone ``19xx``/``20xx`` year is returned.

    Args:
        filename: File name or path (``str`` or path-like).

    Returns:
        The year as a 4-character string, or ``"Unknown"`` if none is found.
    """
    name = os.path.basename(filename)
    # (?<!\d) / (?!\d) prevent matches inside longer digit runs such as IDs
    match = re.search(r"(?<!\d)((?:19|20)\d{2})-(?:19|20)\d{2}(?!\d)", name)
    if match:
        return match.group(1)
    match = re.search(r"(?<!\d)((?:19|20)\d{2})(?!\d)", name)
    return match.group(1) if match else "Unknown"
# Use pdfplumber to extract the tables from the uploaded file,
# tag each table with the report year and stack them into one DataFrame
def extract_tables_from_pdf(pdf_path):
    """Extract every table in a PDF into a single DataFrame.

    Each non-empty table becomes a DataFrame built from its rows (so columns
    get positional integer labels) plus a ``year`` column derived from the
    file name via ``extract_year_from_filename``.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The concatenated tables, or an empty DataFrame if the PDF has none.
    """
    report_year = extract_year_from_filename(pdf_path)
    all_tables = []
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            for table in page.extract_tables():
                if not table:
                    # Nothing to index; avoids adding a year-only empty frame
                    continue
                df = pd.DataFrame(table)
                df["year"] = report_year
                all_tables.append(df)
    return pd.concat(all_tables, ignore_index=True) if all_tables else pd.DataFrame()
# Load CSV files directly into a DataFrame using pandas
def load_csv(file_path):
    """Load a CSV file into a DataFrame tagged with the report year.

    Args:
        file_path: Path to the CSV file.

    Returns:
        The DataFrame with an added ``year`` column (from the file name), or
        ``None`` if the file could not be read. Failures are logged rather than
        raised so one bad upload does not abort processing of the others.
    """
    try:
        df = pd.read_csv(file_path)
    except Exception as e:  # upload boundary: report any unreadable CSV and carry on
        logger.error(f"Error loading CSV {file_path}: {e}")
        return None
    df["year"] = extract_year_from_filename(file_path)
    return df
# Preprocess the dataframe: replace null values and build one text row per
# record, suitable for chunking
def clean_dataframe_text(df):
    """Convert PDF/CSV rows into comma-separated text lines.

    Each row becomes ``"Year: <year>, <v1>, <v2>, ..."``: the year prefix appears
    only when a ``year`` column exists, and the year value itself is not repeated
    among the values. Blank/null cells are skipped, and rows with no data
    besides the year are dropped so they do not become empty ``"Year: X"``
    chunks. The input DataFrame is not modified.

    Args:
        df: DataFrame of extracted table or CSV rows.

    Returns:
        A DataFrame with a single ``text`` column containing the non-empty rows.
    """
    df = df.fillna("")
    has_year = "year" in df.columns
    # Positional mask so duplicate/mixed column labels are handled safely
    is_year_column = [column == "year" for column in df.columns]
    text_data = []
    for _, row in df.iterrows():
        values = [
            str(val).strip()
            for val, is_year in zip(row, is_year_column)
            if not is_year
        ]
        values = [val for val in values if val]
        if not values:
            text_data.append("")  # dropped below
            continue
        parts = [f"Year: {row['year']}"] if has_year else []
        parts.extend(values)
        text_data.append(", ".join(parts))
    df["text"] = text_data
    return df[["text"]].replace("", np.nan).dropna()
# Chunk the text for retrieval.
# Chunk sizes of 256, 512, 1024 and 2048 were tried; 512 worked best for financial RAG.
def chunk_text(text, chunk_size=512):
    """Split text on whitespace into chunks of at most ``chunk_size`` characters.

    Words are never split. A chunk is closed as soon as adding the next word
    (plus a joining space) would exceed ``chunk_size``. A single word longer
    than ``chunk_size`` becomes its own oversized chunk; no empty chunks are
    produced.

    Args:
        text: Input text.
        chunk_size: Maximum chunk length in characters (not tokens).

    Returns:
        List of chunk strings (empty for blank input).
    """
    chunks = []
    current = []
    current_len = 0  # length of " ".join(current)
    for word in text.split():
        candidate_len = current_len + len(word) + (1 if current else 0)
        if current and candidate_len > chunk_size:
            chunks.append(" ".join(current))
            current, current_len = [word], len(word)
        else:
            current.append(word)
            current_len = candidate_len
    if current:
        chunks.append(" ".join(current))
    return chunks
# Regex check used to keep financial rows from being merged with other chunks
def is_financial_text(text):
    """Return True if ``text`` matches FINANCIAL_DATA_PATTERNS (case-insensitive).

    NOTE(review): FINANCIAL_DATA_PATTERNS is passed straight to ``re.search``,
    so it is presumably a single pattern string; confirm in data_filters.
    """
    found = re.search(FINANCIAL_DATA_PATTERNS, text, re.IGNORECASE)
    return found is not None
# Embeds chunks with the "all-MiniLM-L6-v2" sentence transformer and uses a
# FAISS inner-product index (cosine similarity on normalised vectors) to merge
# only adjacent, highly similar (> threshold) non-financial chunks.
def merge_similar_chunks(chunks, similarity_threshold=0.85):
    """Merge adjacent, highly similar chunks while keeping financial rows intact.

    For each chunk, its nearest other chunk is looked up. The two are merged
    (in document order) only if the neighbour is directly adjacent, neither
    chunk matches ``is_financial_text``, the neighbour has not already been
    emitted, and their cosine similarity exceeds ``similarity_threshold``.
    Every other chunk is kept as-is; no chunk is ever dropped.

    Args:
        chunks: List of chunk strings in document order.
        similarity_threshold: Minimum cosine similarity required to merge.

    Returns:
        List of unique (possibly merged) chunk texts, in document order.
    """
    if not chunks:
        return []
    embeddings = np.array(
        embed_model.encode(chunks, normalize_embeddings=True), dtype="float32"
    )
    # Inner product on L2-normalised vectors equals cosine similarity
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    # Column 0 is normally the chunk itself, column 1 its nearest other chunk;
    # FAISS fills column 1 with -1 when only one vector is indexed
    _, indices = index.search(embeddings, 2)
    merged_chunks = {}
    for i, neighbour in enumerate(indices[:, 1]):
        if i in merged_chunks:
            # Already emitted as the second half of an earlier merge
            continue
        j = int(neighbour)
        can_merge = (
            0 <= j < len(chunks)
            and abs(i - j) == 1
            and j not in merged_chunks
            and not is_financial_text(chunks[i])
            and not is_financial_text(chunks[j])
            and float(np.dot(embeddings[i], embeddings[j])) > similarity_threshold
        )
        if can_merge:
            first, second = sorted((i, j))
            merged_chunks[i] = merged_chunks[j] = chunks[first] + " " + chunks[second]
        else:
            merged_chunks[i] = chunks[i]
    # dict.fromkeys de-duplicates (a merged pair is stored under both indices)
    # while keeping a deterministic, document-ordered result
    return list(dict.fromkeys(merged_chunks.values()))
# Handler for the "Process Files" button in the UI.
# Processes the uploaded files, generates embeddings and persists the FAISS
# index and the tokenized chunks used to rebuild BM25 at query time.
def process_files(files, chunk_size=512):
    """Extract, chunk and index uploaded PDF/CSV files.

    Writes ``data/faiss_index.bin`` and ``data/bm25_data.pkl``. Both are built
    from the same ``chunk_texts`` list, so FAISS ids and BM25 positions align.

    Args:
        files: Uploaded files from the Gradio ``File`` component. Each item may
            be a path string or an object exposing its path via ``.name``.
        chunk_size: Maximum characters per chunk (see ``chunk_text``).

    Returns:
        A status message for the UI.
    """
    if not files:
        logger.warning("No files uploaded!")
        return "Please upload at least one PDF or CSV file."
    paths = [getattr(file, "name", file) for file in files]
    # Case-insensitive so e.g. "REPORT.PDF" is not silently ignored
    pdf_paths = [path for path in paths if path.lower().endswith(".pdf")]
    csv_paths = [path for path in paths if path.lower().endswith(".csv")]
    logger.info(f"Processing {len(pdf_paths)} PDFs and {len(csv_paths)} CSVs")
    frames = [extract_tables_from_pdf(pdf) for pdf in pdf_paths]
    frames.extend(load_csv(csv) for csv in csv_paths)
    # load_csv returns None on failure and PDFs without tables give empty frames
    df_list = [frame for frame in frames if frame is not None and not frame.empty]
    if not df_list:
        logger.warning("No valid data found in the uploaded files")
        return "No valid data found in the uploaded files"
    df = pd.concat(df_list, ignore_index=True)
    df.dropna(how="all", inplace=True)
    logger.info("Data extracted from the files")
    df_cleaned = clean_dataframe_text(df)
    df_cleaned["chunks"] = df_cleaned["text"].apply(lambda x: chunk_text(x, chunk_size))
    # explode() turns an empty chunk list into NaN; those rows carry no text
    df_chunks = (
        df_cleaned.explode("chunks").dropna(subset=["chunks"]).reset_index(drop=True)
    )
    chunk_texts = merge_similar_chunks(df_chunks["chunks"].tolist())
    if not chunk_texts:
        logger.warning("No text chunks could be built from the uploaded files")
        return "No valid data found in the uploaded files"
    embeddings = np.array(
        embed_model.encode(chunk_texts, normalize_embeddings=True), dtype="float32"
    )
    # On normalised vectors, L2 distance ranks neighbours the same as cosine
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, "data/faiss_index.bin")
    logger.info("FAISS index created and saved.")
    # BM25 is rebuilt from these tokens at query time
    tokenized_chunks = [text.lower().split() for text in chunk_texts]
    bm25_data = {"tokenized_chunks": tokenized_chunks, "chunk_texts": chunk_texts}
    with open("data/bm25_data.pkl", "wb") as f:
        pickle.dump(bm25_data, f)
    logger.info("BM25 index created and saved.")
    return "Files processed successfully! You can now query."
def contains_financial_entities(query):
    """Return True if spaCy tags any entity in ``query`` with a label in FINANCIAL_ENTITY_LABELS."""
    return any(entity.label_ in FINANCIAL_ENTITY_LABELS for entity in nlp(query).ents)
def contains_geographical_entities(query):
    """Return True if spaCy finds a geopolitical entity (label ``GPE``) in ``query``."""
    for entity in nlp(query).ents:
        if entity.label_ == "GPE":
            return True
    return False
def contains_financial_terms(query):
    """Return True if any entry of FINANCIAL_TERMS occurs as a substring of the lower-cased query.

    NOTE(review): matching is against ``query.lower()``, so the terms are
    presumably stored lower-case; confirm in data_filters.
    """
    lowered = query.lower()
    for term in FINANCIAL_TERMS:
        if term in lowered:
            return True
    return False
def is_general_knowledge_query(query):
    """Return True if the lower-cased query matches any GENERAL_KNOWLEDGE_PATTERNS regex."""
    lowered = query.lower()
    return any(re.search(pattern, lowered) for pattern in GENERAL_KNOWLEDGE_PATTERNS)
def get_latest_available_year(retrieved_chunks, default_year=2024):
    """Return the latest ``20xx`` year mentioned in the retrieved chunks.

    Years must be standalone words (``\\b`` boundaries), so e.g. ``"FY2023"``
    or digits inside longer numbers are not counted.

    Args:
        retrieved_chunks: Iterable of chunk strings.
        default_year: Value returned when no year is found (previously a
            hard-coded 2024, kept as the default).

    Returns:
        The maximum year found, as an int, or ``default_year``.
    """
    year_pattern = re.compile(r"\b(20\d{2})\b")
    years = {
        int(year) for chunk in retrieved_chunks for year in year_pattern.findall(chunk)
    }
    return max(years) if years else default_year
def is_irrelevant_query(query):
    """Return True when the query is judged not finance-related.

    A query is irrelevant if it matches a general-knowledge pattern without
    mentioning any financial term, or if it names a geographical entity
    without naming any financial entity.
    """
    general_only = is_general_knowledge_query(query) and not contains_financial_terms(query)
    if general_only:
        return True
    return bool(
        contains_geographical_entities(query) and not contains_financial_entities(query)
    )
# Input guardrail implementation.
# NER + regex + term lists filter irrelevant or sensitive queries, and cosine
# similarity against embedded restricted topics catches confidential requests.
@lru_cache(maxsize=1)
def _restricted_topic_embeddings():
    """Encode ``restricted_topics`` once and reuse the result for every query.

    NOTE(review): assumes ``restricted_topics`` is not mutated at runtime;
    call ``_restricted_topic_embeddings.cache_clear()`` if it ever is.

    Returns:
        Normalised topic embeddings, or ``None`` when there are no topics.
    """
    topics = list(restricted_topics)
    if not topics:
        return None
    return embed_model.encode(topics, normalize_embeddings=True)


def is_query_allowed(query, similarity_threshold=0.85):
    """Check whether a query passes the input guardrails.

    Checks, in order:
      1. the query is finance-related (``is_irrelevant_query``);
      2. no ``restricted_patterns`` regex matches;
      3. no PERSON entity's subtree contains a ``sensitive_terms`` token;
      4. cosine similarity to every restricted topic is at most
         ``similarity_threshold``.

    Args:
        query: User query text.
        similarity_threshold: Similarity above which a query is treated as
            asking about a restricted topic.

    Returns:
        ``(True, None)`` if allowed, otherwise ``(False, reason)``.
    """
    if is_irrelevant_query(query):
        return False, "Query is not finance-related. Please ask a financial question."
    query_lower = query.lower()
    for pattern in restricted_patterns:
        if re.search(pattern, query_lower, re.IGNORECASE):
            return False, "This query requests sensitive or confidential information."
    # A PERSON entity combined with a sensitive term (e.g. a named salary request)
    doc = nlp(query)
    for ent in doc.ents:
        if ent.label_ != "PERSON":
            continue
        if any(token.text.lower() in sensitive_terms for token in ent.subtree):
            return (
                False,
                "Query contains personal salary information, which is restricted.",
            )
    topic_embeddings = _restricted_topic_embeddings()
    if topic_embeddings is not None:
        query_embedding = embed_model.encode(query, normalize_embeddings=True)
        # Dot product of normalised embeddings = cosine similarity per topic
        similarities = np.dot(topic_embeddings, query_embedding)
        if np.max(similarities) > similarity_threshold:
            return False, "This query requests sensitive or confidential information."
    return True, None
# Score booster used when fusing FAISS and BM25 candidates before re-ranking
def boost_score(text, base_score, boost_factor=1.2):
    """Return ``base_score * boost_factor`` if ``text`` contains a FINANCIAL_TERMS entry, else ``base_score``.

    Terms are matched as substrings of the lower-cased text.
    """
    lowered = text.lower()
    has_financial_term = any(term in lowered for term in FINANCIAL_TERMS)
    return base_score * boost_factor if has_financial_term else base_score
# Hybrid retrieval: FAISS finds semantically similar chunks, BM25 finds keyword
# matches; candidates are fused with lambda_faiss / (1 - lambda_faiss) weights
# and finally re-ranked by the ms-marco-MiniLM-L6-v2 cross-encoder.
def hybrid_retrieve(query, chunk_texts, index, bm25, top_k=5, lambda_faiss=0.7):
    """Retrieve the ``top_k`` most relevant chunks for ``query``.

    Up to ``top_k`` candidates come from each of FAISS and BM25. Each candidate
    gets a fused score: ``lambda_faiss`` for a FAISS hit plus
    ``1 - lambda_faiss`` for a BM25 hit, each boosted by ``boost_score``. The
    fused score only orders the candidates handed to the cross-encoder; the
    final order comes from the cross-encoder scores.

    Args:
        query: User query text.
        chunk_texts: Indexed chunk texts, aligned with FAISS ids and BM25 rows.
        index: FAISS index over the chunk embeddings.
        bm25: BM25 model over the tokenized chunks.
        top_k: Number of candidates per retriever and of results returned.
            Coerced to an int >= 1 (UI number inputs may arrive as floats).
        lambda_faiss: Weight of FAISS hits in the fused score.

    Returns:
        Up to ``top_k`` chunk texts, most relevant first.
    """
    if not chunk_texts:
        return []
    top_k = max(1, int(top_k))
    num_chunks = len(chunk_texts)
    # FAISS retrieval
    query_embedding = np.array(
        [embed_model.encode(query, normalize_embeddings=True)], dtype="float32"
    )
    _, faiss_indices = index.search(query_embedding, top_k)
    # FAISS pads with -1 when top_k exceeds the indexed vectors; -1 would
    # otherwise silently select the last chunk
    faiss_results = [
        chunk_texts[idx] for idx in faiss_indices[0] if 0 <= idx < num_chunks
    ]
    # BM25 retrieval
    bm25_scores = bm25.get_scores(query.lower().split())
    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
    bm25_results = [chunk_texts[idx] for idx in bm25_top_indices]
    # Fuse FAISS & BM25 candidates
    results = {}
    for entry in faiss_results:
        results[entry] = boost_score(entry, lambda_faiss)
    for entry in bm25_results:
        results[entry] = results.get(entry, 0) + boost_score(entry, 1 - lambda_faiss)
    retrieved_texts = [
        text for text, _ in sorted(results.items(), key=lambda x: x[1], reverse=True)
    ]
    if not retrieved_texts:
        return []
    # Cross-encoder re-ranking
    scores = cross_encoder.predict([[query, text] for text in retrieved_texts])
    ranked_indices = np.argsort(scores)[::-1]
    return [retrieved_texts[i] for i in ranked_indices[:top_k]]
def compute_entropy(logits):
    """Return the mean Shannon entropy (in nats) of softmax(logits).

    Entropy is computed over the last dimension and then averaged over all
    leading dimensions. A 1e-9 epsilon keeps log() finite for zero probabilities.
    """
    p = logits.softmax(dim=-1)
    per_distribution = (p * (p + 1e-9).log()).sum(dim=-1).neg()
    return per_distribution.mean().item()
def contains_future_year(query, retrieved_chunks):
    """Return True if the query mentions a ``20xx`` year later than any year in the retrieved data.

    Only standalone years count (``\\b`` boundaries), so forms such as
    ``"FY2025"`` are not detected.
    """
    latest_year = get_latest_available_year(retrieved_chunks)
    query_years = {int(year) for year in re.findall(r"\b(20\d{2})\b", query)}
    return bool(query_years) and max(query_years) > latest_year
def is_explanatory_query(query):
    """Return True if the lower-cased query matches any EXPLANATORY_PATTERNS regex (asks "why/how" rather than for data)."""
    lowered = query.lower()
    for pattern in EXPLANATORY_PATTERNS:
        if re.search(pattern, lowered):
            return True
    return False
# Confidence score for a generated response, combining:
#   - embedding similarity between the response and the retrieved chunks,
#   - a normalised BM25 score for the combined query + response tokens,
#   - the model's own confidence signal,
# as a weighted sum, minus penalties for future-year and explanatory queries.
def compute_response_confidence(
    query,
    response,
    retrieved_chunks,
    bm25,
    model_conf_signal,
    lambda_faiss=0.6,
    lambda_conf=0.3,
    lambda_bm25=1.0,
    future_penalty=-0.3,
    explanation_penalty=-0.2,
):
    """Calculate a confidence percentage for a model response.

    Args:
        query: User query text.
        response: Generated response text.
        retrieved_chunks: Chunks used as context.
        bm25: BM25 model over the indexed chunks.
        model_conf_signal: Model-derived confidence in [0, 1] (float or 0-dim
            tensor/array).
        lambda_faiss: Weight of the response/context similarity, mapped to [0, 1].
        lambda_conf: Weight of ``model_conf_signal``.
        lambda_bm25: Weight of the normalised BM25 score.
        future_penalty: Added when the query asks about a year beyond the data.
        explanation_penalty: Added when the query asks for an explanation.

    Returns:
        Confidence in percent, clamped to [0, 100] and rounded to 2 decimals;
        0.0 when nothing was retrieved.
    """
    if not retrieved_chunks:
        return 0.0
    # Cosine similarity between the whole context and the response
    retrieved_embedding = embed_model.encode(
        " ".join(retrieved_chunks), normalize_embeddings=True
    )
    response_embedding = embed_model.encode(response, normalize_embeddings=True)
    faiss_score = float(np.dot(retrieved_embedding, response_embedding))
    normalized_faiss = (faiss_score + 1) / 2  # [-1, 1] -> [0, 1]
    # Mean BM25 score for query + response, min-max normalised over the corpus
    tokenized_combined = (query + " " + response).lower().split()
    bm25_scores = bm25.get_scores(tokenized_combined)
    normalized_bm25 = 0.0
    if bm25_scores.size > 0:
        min_bm25, max_bm25 = np.min(bm25_scores), np.max(bm25_scores)
        if min_bm25 != max_bm25:
            bm25_score = np.mean(bm25_scores)
            normalized_bm25 = float(
                (bm25_score - min_bm25) / (max_bm25 - min_bm25 + 1e-6)
            )
            normalized_bm25 = max(0.0, min(1.0, normalized_bm25))
    # Penalise queries about years beyond the available data
    future_penalty_value = (
        future_penalty if contains_future_year(query, retrieved_chunks) else 0.0
    )
    # Penalise reasoning/explanation queries
    explanation_penalty_value = (
        explanation_penalty if is_explanatory_query(query) else 0.0
    )
    logger.info(
        f"Faiss score: {normalized_faiss}, BM25: {normalized_bm25}\n"
        f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}\n"
        f"Future penalty: {future_penalty_value}, Reasoning penalty: {explanation_penalty_value}"
    )
    confidence_score = (
        lambda_faiss * normalized_faiss
        + lambda_conf * float(model_conf_signal)
        + lambda_bm25 * normalized_bm25
        + future_penalty_value
        + explanation_penalty_value
    )
    return round(min(100.0, max(0.0, confidence_score * 100)), 2)
# UI handler for the "Submit Query" button:
# load the persisted FAISS index and BM25 data, apply the input guardrails,
# retrieve context, prompt the SLM and compute a confidence score.
def _build_context(retrieved_chunks):
    """Join ranked chunks (one per line), stopping before MAX_CONTEXT_TOKENS is reached."""
    context = ""
    token_count = 0
    for chunk in retrieved_chunks:
        chunk_tokens = tokenizer(chunk, return_tensors="pt")["input_ids"].shape[1]
        # Stop at the first chunk that does not fit, preserving rank order
        if token_count + chunk_tokens >= MAX_CONTEXT_TOKENS:
            break
        context += chunk + "\n"
        token_count += chunk_tokens
    return context


def _model_confidence_signal(step_scores):
    """Average the mean top-token probability and the entropy-based confidence (both in [0, 1])."""
    if not step_scores:
        # No tokens generated: no evidence of model confidence
        return 0.0
    num_steps = len(step_scores)
    # Mean probability of the most likely token at each generation step
    top_token_conf = (
        sum(torch.softmax(s, dim=-1).max().item() for s in step_scores) / num_steps
    )
    logger.info(f"Top Token Probability Mean: {top_token_conf:.4f}")
    # Normalise by the maximum possible entropy log(V), where V is the width of
    # the score vectors (which may differ from tokenizer.vocab_size)
    mean_entropy = sum(compute_entropy(s) for s in step_scores) / num_steps
    max_entropy = float(np.log(step_scores[0].shape[-1]))
    entropy_conf = 1.0 - mean_entropy / max_entropy if max_entropy > 0 else 1.0
    logger.info(f"Entropy-based Confidence: {entropy_conf:.4f}")
    # Both terms increase with model certainty (mean of top-token prob and 1-entropy)
    return (top_token_conf + entropy_conf) / 2


def query_model(
    query,
    top_k=10,
    lambda_faiss=0.5,
    repetition_penalty=1.5,
    max_new_tokens=100,
    use_extraction=False,
    unsure_threshold=30.0,
):
    """Answer a financial query with retrieval-augmented generation.

    Args:
        query: User query text.
        top_k: Chunks to retrieve (coerced to int).
        lambda_faiss: FAISS weight used when fusing retrieval candidates.
        repetition_penalty: Passed to ``model.generate``.
        max_new_tokens: Generation length limit (coerced to int).
        use_extraction: If True, return only the response, without the context.
        unsure_threshold: Confidence percentage at or below which the response
            is flagged as unsure.

    Returns:
        Tuple ``(response_text, status_text)`` for the two UI output boxes.
    """
    start_time = time.perf_counter()
    # Check if FAISS and BM25 indexes exist
    if not os.path.exists("data/faiss_index.bin") or not os.path.exists(
        "data/bm25_data.pkl"
    ):
        logger.error("No index found! Prompting user to upload PDFs.")
        return (
            "Index files not found! Please upload PDFs first to generate embeddings.",
            "Error",
        )
    if not query or not query.strip():
        return "Please enter a query.", "Warning"
    # UI number inputs may arrive as floats; FAISS, slicing and generate() need ints
    top_k = int(top_k)
    max_new_tokens = int(max_new_tokens)
    allowed, reason = is_query_allowed(query)
    if not allowed:
        logger.warning(f"Query Rejected: {reason}")
        return f"Query Rejected: {reason}", "Warning"
    logger.info(
        f"Received query: {query} | Top-K: {top_k}, "
        f"Lambda: {lambda_faiss}, Tokens: {max_new_tokens}"
    )
    # Load FAISS & BM25 data. The pickle is written by process_files in this
    # app; never point this path at untrusted files.
    index = faiss.read_index("data/faiss_index.bin")
    with open("data/bm25_data.pkl", "rb") as f:
        bm25_data = pickle.load(f)
    chunk_texts = bm25_data["chunk_texts"]
    bm25 = BM25Okapi(bm25_data["tokenized_chunks"])
    retrieved_chunks = hybrid_retrieve(
        query, chunk_texts, index, bm25, top_k=top_k, lambda_faiss=lambda_faiss
    )
    logger.info("Retrieved chunks")
    context = _build_context(retrieved_chunks)
    prompt = (
        "You are a financial analyst. Answer financial queries concisely using only the numerical data "
        "explicitly present in the provided financial context:\n\n"
        f"{context}\n\n"
        "Use only the given financial data—do not assume, infer, or generate missing values."
        " Retain the original format of financial figures without conversion."
        " If the requested information is unavailable, respond with 'No relevant financial data available.'"
        " Provide a single-sentence answer without explanations, additional text, or multiple responses."
        f"\nQuery: {query}"
    )
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
    inputs.pop("token_type_ids", None)
    logger.info("Generating output")
    input_len = inputs["input_ids"].shape[-1]
    logger.info(f"Input len: {input_len}")
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            repetition_penalty=repetition_penalty,
            output_scores=True,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens
    sequences = output["sequences"][0][input_len:]
    execution_time = time.perf_counter() - start_time
    logger.info(f"Query processed in {execution_time:.2f} seconds.")
    # output["scores"] holds one score tensor per generated token
    model_conf_signal = _model_confidence_signal(output["scores"])
    response = tokenizer.decode(sequences, skip_special_tokens=True)
    confidence_score = compute_response_confidence(
        query, response, retrieved_chunks, bm25, model_conf_signal
    )
    logger.info(f"Confidence: {confidence_score}%")
    # confidence_score is a percentage in [0, 100]
    if confidence_score <= unsure_threshold:
        logger.warning("The system is unsure about this response.")
        response += "\nThe system is unsure about this response."
    final_out = ""
    if not use_extraction:
        final_out += f"Context: {context}\nQuery: {query}\n"
    final_out += f"Response: {response}"
    return (
        final_out,
        f"Confidence: {confidence_score}%\nTime taken: {execution_time:.2f} seconds",
    )
# Gradio UI
with gr.Blocks(title="Financial Statement RAG with LLM") as ui:
    gr.Markdown("## Financial Statement RAG with LLM")
    # File upload section
    with gr.Group():
        gr.Markdown("### Upload & Process Annual Reports")
        file_input = gr.File(
            file_count="multiple",
            file_types=[".pdf", ".csv"],
            type="filepath",
            label="Upload Annual Reports (PDFs/CSVs)",
        )
        process_button = gr.Button("Process Files")
        process_output = gr.Textbox(label="Processing Status", interactive=False)
    # Query model section
    with gr.Group():
        gr.Markdown("### Ask a Financial Query")
        query_input = gr.Textbox(label="Enter Query")
        with gr.Row():
            # precision=0 makes Gradio deliver ints, as required by FAISS
            # search, list slicing and generate(max_new_tokens=...)
            top_k_input = gr.Number(
                value=15, precision=0, minimum=1, label="Top K (Default: 15)"
            )
            lambda_faiss_input = gr.Slider(0, 1, value=0.5, label="Lambda FAISS (0-1)")
            repetition_penalty = gr.Slider(
                1, 2, value=1.2, label="Repetition Penalty (1-2)"
            )
            max_tokens_input = gr.Number(
                value=100, precision=0, minimum=1, label="Max New Tokens"
            )
        use_extraction = gr.Checkbox(label="Retrieve only the answer", value=False)
        query_button = gr.Button("Submit Query")
        query_output = gr.Textbox(label="Query Response", interactive=False)
        time_output = gr.Textbox(label="Time Taken", interactive=False)
    # Button actions
    process_button.click(process_files, inputs=[file_input], outputs=process_output)
    query_button.click(
        query_model,
        inputs=[
            query_input,
            top_k_input,
            lambda_faiss_input,
            repetition_penalty,
            max_tokens_input,
            use_extraction,
        ],
        outputs=[query_output, time_output],
    )
# Application entry point: serve the UI on every network interface, port 7860,
# with progressive-web-app support enabled
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "pwa": True}
    logger.info("Starting Gradio server...")
    ui.launch(**launch_options)