import os
import re
import tempfile
from datetime import datetime

import numpy as np
import streamlit as st
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PDFPlumberLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from sentence_transformers import CrossEncoder
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

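# Pipeline implemented below: upload PDFs -> chunk with a recursive splitter ->
# index into a BM25 + FAISS hybrid retriever -> re-rank hits with a cross-encoder ->
# generate the comparative analysis locally with a small Qwen instruct model.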
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# Zero-shot classifier used further down as a topical guardrail on user queries.
classifier = pipeline("zero-shot-classification",
                      model="typeform/distilbert-base-uncased-mnli")

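# Note: the classifier above and the Qwen tokenizer/model loaded later are re-created
# on every Streamlit rerun. Wrapping the loaders in functions decorated with
# @st.cache_resource is one way to avoid repeated loads; kept as-is here.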
st.set_page_config(page_title="Multi-File Financial Analyzer", layout="wide")
st.title("Financial Analysis System")

with st.sidebar:
    st.header("Configuration Panel")
    model_choice = st.selectbox("LLM Model",
                                [model_name],
                                help="Choose the core analysis engine")
    chunk_size = st.slider("Document Chunk Size", 500, 2000, 1000)
    rerank_threshold = st.slider("Re-ranking Threshold", 0.0, 1.0, 0.1)

uploaded_files = st.file_uploader("Upload Financial PDFs",
                                  type="pdf",
                                  accept_multiple_files=True)

if uploaded_files:
    all_docs = []
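    # Each upload is written to a temporary file so PDFPlumberLoader can read it from disk.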
    with st.spinner("Processing Multiple Financial Documents..."):
        for uploaded_file in uploaded_files:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(uploaded_file.getvalue())
                tmp_path = tmp.name

            loader = PDFPlumberLoader(tmp_path)
            docs = loader.load()
            all_docs.extend(docs)
            os.remove(tmp_path)  # temp file is no longer needed once the pages are loaded

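        # Split the combined pages into overlapping chunks; chunk_size comes from the
        # sidebar slider and the 200-character overlap preserves context across boundaries.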
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=200,
            # separators are matched literally (not as regex), so ". " rather than "\. "
            separators=["\n\n", "\n", ". ", "! ", "? ", " ", ""]
        )
        documents = text_splitter.split_documents(all_docs)

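        # Hybrid retrieval: BM25 keyword matching and dense FAISS similarity search are
        # combined (weighted 0.4 / 0.6) so that exact financial terms and paraphrased
        # queries both surface relevant chunks.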
        embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        vector_store = FAISS.from_documents(documents, embedder)
        bm25_retriever = BM25Retriever.from_documents(documents)
        bm25_retriever.k = 5
        faiss_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.4, 0.6]
        )

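        # Cross-encoder re-ranker scores each (query, chunk) pair jointly. Note that this
        # ms-marco model returns raw relevance scores that are not guaranteed to lie in [0, 1].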
        cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

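        # Load the Qwen generator locally; recent transformers releases support Qwen2.5
        # natively, so trust_remote_code=True is likely optional here.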
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
        )

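    # Analysis instructions live in a reusable PromptTemplate. Role markers are added later
    # by tokenizer.apply_chat_template, so the template body only carries the task description.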
    PROMPT_TEMPLATE = """
    You are a senior financial analyst. Analyze these financial reports:
    1. Compare key metrics between documents
    2. Identify trends across reporting periods
    3. Highlight differences/similarities
    4. Provide risk assessment
    5. Offer recommendations

    Format the response with clear sections and bullet points. Keep it under 300 words.

    Context: {context}
    Question: {question}
    """

    qa_prompt = PromptTemplate(
        template=PROMPT_TEMPLATE,
        input_variables=["context", "question"]
    )

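    # Query UI: a few canned comparative questions plus a free-text box; picking a sample
    # pre-fills the text input, which is what actually drives the analysis.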
    st.header("Cross-Document Financial Inquiry")

    comparative_questions = [
        "Analyze changes in debt structure across both reports",
        "Show expense ratio differences between the two years",
        "What are the main liquidity changes across both periods?",
    ]
    user_query = st.selectbox("Sample Financial Questions",
                              [""] + comparative_questions)
    user_input = st.text_input("Or enter custom financial query:",
                               value=user_query)

    if user_input:
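        # Topical guardrail: the query is zero-shot classified against "financial" vs "other";
        # it must come out as "financial" with at least 0.7 confidence to proceed.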
        classification = classifier(user_input,
                                    ["financial", "other"],
                                    multi_label=False)
        top_label = classification["labels"][0]
        top_score = classification["scores"][0]
        print(f"-- Guardrail check completed: top label '{top_label}' with probability {top_score:.3f}")

        # Check the label as well as the score: scores[0] always belongs to the
        # highest-scoring label, which may be "other".
        if top_label != "financial" or top_score < 0.7:
            st.error("Query does not appear to be finance-related. Please ask a financial question.")
            st.stop()

        with st.spinner("Performing Cross-Document Analysis..."):
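            # Stage 1: hybrid retrieval gathers candidate chunks. Stage 2 (below) re-ranks
            # them with the cross-encoder and keeps the top chunks above the threshold.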
            initial_docs = ensemble_retriever.get_relevant_documents(user_input)

            # Stage 2: score (query, chunk) pairs, sort descending, then filter by the
            # sidebar threshold and cap at 7 chunks. The scores must be reordered with the
            # same indices as the documents so each chunk is paired with its own score.
            doc_pairs = [(user_input, doc.page_content) for doc in initial_docs]
            rerank_scores = cross_encoder.predict(doc_pairs)
            sorted_indices = np.argsort(rerank_scores)[::-1]
            sorted_scores = rerank_scores[sorted_indices]
            ranked_docs = [initial_docs[i] for i in sorted_indices]
            filtered_docs = [d for d, s in zip(ranked_docs, sorted_scores)
                             if s > rerank_threshold][:7]
            print(f"-- Retrieved chunks: {filtered_docs}")

            # Confidence heuristic: mean of the top-3 re-rank scores scaled to a percentage;
            # raw cross-encoder scores can fall outside [0, 1], hence the clipping.
            confidence_score = np.mean(sorted_scores[:3]) * 100
            confidence_score = min(100, max(0, round(confidence_score, 1)))

            context = "\n".join([doc.page_content for doc in filtered_docs])
            print(f"-- Retrieved context: {context}")

            prompt = qa_prompt.format(context=context, question=user_input)

            messages = [
                {"role": "system", "content": "You are a financial assistant."},
                {"role": "user", "content": prompt}
            ]
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
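            # Tokenize the chat-formatted prompt, generate up to 512 new tokens, then strip the
            # prompt tokens from each sequence so only the newly generated answer is decoded.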
            model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

            print("-- Model invoking")
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=512
            )
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]

            analysis = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            print(f"Analysis result: {analysis}")

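            # Cleanup heuristics: drop stray <think> tags and runs of blank lines, put a space
            # between digits and letters, and insert thousands separators into long digit runs.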
            clean_analysis = re.sub(r"<think>|</think>|\n{3,}", "", analysis)
            clean_analysis = re.sub(r'(\d)([A-Za-z])', r'\1 \2', clean_analysis)
            clean_analysis = re.sub(r'(\d{1,3})(\d{3})', r'\1,\2', clean_analysis)

            st.subheader("User Query + Context Sent to the LLM")
            st.markdown(f"```\n{prompt}\n```")

            st.subheader("Integrated Financial Analysis")
            st.markdown(f"```\n{clean_analysis}\n```")
            st.progress(int(confidence_score) / 100)
            st.caption(f"Analysis Confidence: {confidence_score}%")

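            # Report export: bundles the query and the cleaned analysis into a plain-text
            # download. Clicking the button triggers a Streamlit rerun, so the analysis
            # above is recomputed before the download button appears.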
            if st.button("Generate Financial Analysis Report"):
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                export_content = f"COMPARATIVE QUERY: {user_input}\n\nANALYSIS:\n{clean_analysis}"
                st.download_button("Download Full Report", export_content,
                                   file_name=f"Comparative_Analysis_{timestamp}.txt",
                                   mime="text/plain")

else:
    st.info("Please upload PDF financial reports to begin the analysis")