import os
import re
import io
import numpy as np
import networkx as nx
from sympy import symbols
from galgebra.ga import Ga
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
import tensorflow as tf
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline
)
import torch
from PyPDF2 import PdfReader
from concurrent.futures import ThreadPoolExecutor, as_completed
import streamlit as st
# Optionally, set environment variables to optimize CPU parallelism.
os.environ["OMP_NUM_THREADS"] = "4" # Adjust to your available cores.
os.environ["MKL_NUM_THREADS"] = "4"
# Setup IBM Granite model without 8-bit quantization (for CPU).
model_name = "ibm-granite/granite-3.1-2b-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="balanced",      # Using balanced CPU mapping.
    torch_dtype=torch.float16   # Use float16 if supported.
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
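# For comparison, 8-bit loading (skipped above because this runs on CPU) would look roughly like
# AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True); illustrative only, since it
# requires the bitsandbytes package and a CUDA GPU.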
DIM = 5000
# We use a lower max token count for faster generation.
DEFAULT_MAX_TOKENS = 1000
coords = symbols('e1 e2 e3')
ga = Ga('e1 e2 e3', g=[1, 1, 1])
# Module-level knowledge graph of contract risks (built once and reused).
KNOWLEDGE_GRAPH = nx.Graph()
KNOWLEDGE_GRAPH.add_edges_from([
    ("Ambiguous Terms", "Risk of Dispute"),
    ("Lack of Termination Clause", "Prolonged Obligations"),
    ("Non-compliance", "Legal Penalties"),
    ("Confidentiality Breaches", "Reputational Damage"),
    ("Inadequate Indemnification", "High Liability"),
    ("Unclear Jurisdiction", "Compliance Issues"),
    ("Force Majeure", "Risk Mitigation"),
    ("Data Privacy", "Regulatory Compliance"),
    ("Penalty Clauses", "Financial Risk"),
    ("Intellectual Property", "Contract Disputes")
])
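# Illustrative lookup (the graph is not queried elsewhere in this script):
# list(KNOWLEDGE_GRAPH.neighbors("Data Privacy")) -> ["Regulatory Compliance"]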
# Caches for file content and summaries.
FILE_CACHE = {}
SUMMARY_CACHE = {}
# Initialize a summarization pipeline on CPU (using a lightweight model).
summarizer = pipeline("summarization", model="t5-small", tokenizer="t5-small", device=-1)
def read_file(file_obj):
    """
    Reads content from a file. Supports both file paths (str) and Streamlit uploaded files.
    """
    if isinstance(file_obj, str):
        if file_obj in FILE_CACHE:
            return FILE_CACHE[file_obj]
        if not os.path.exists(file_obj):
            st.error(f"File not found: {file_obj}")
            return ""
        content = ""
        try:
            if file_obj.lower().endswith(".pdf"):
                reader = PdfReader(file_obj)
                for page in reader.pages:
                    # extract_text() can return None for image-only pages.
                    content += (page.extract_text() or "") + "\n"
            else:
                with open(file_obj, "r", encoding="utf-8") as f:
                    content = f.read() + "\n"
            FILE_CACHE[file_obj] = content
        except Exception as e:
            st.error(f"Error reading {file_obj}: {e}")
            content = ""
        return content
    else:
        # Assume it's an uploaded file (BytesIO).
        file_name = file_obj.name
        if file_name in FILE_CACHE:
            return FILE_CACHE[file_name]
        try:
            if file_name.lower().endswith(".pdf"):
                reader = PdfReader(io.BytesIO(file_obj.read()))
                content = ""
                for page in reader.pages:
                    content += (page.extract_text() or "") + "\n"
            else:
                content = file_obj.getvalue().decode("utf-8")
            FILE_CACHE[file_name] = content
            return content
        except Exception as e:
            st.error(f"Error reading uploaded file {file_name}: {e}")
            return ""
def summarize_text(text, chunk_size=2000):
    """
    Summarize text if it is longer than chunk_size.
    Uses parallel processing for multiple chunks.
    (Reducing chunk_size may speed up summarization on CPU.)
    """
    if len(text) <= chunk_size:
        return text
    chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
    summaries = [None] * len(chunks)
    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(summarizer, chunk, max_length=100, min_length=30, do_sample=False): idx
                   for idx, chunk in enumerate(chunks)}
        for future in as_completed(futures):
            # Store each summary at its original chunk position so the joined text keeps document order.
            summaries[futures[future]] = future.result()[0]["summary_text"]
    return " ".join(summaries)
def read_files(file_objs, max_length=3000):
    """
    Read and, if necessary, summarize file content from a list of file objects or file paths.
    """
    full_text = ""
    for file_obj in file_objs:
        text = read_file(file_obj)
        full_text += text + "\n"
    cache_key = hash(full_text)
    if cache_key in SUMMARY_CACHE:
        return SUMMARY_CACHE[cache_key]
    if len(full_text) > max_length:
        summarized = summarize_text(full_text, chunk_size=max_length)
    else:
        summarized = full_text
    SUMMARY_CACHE[cache_key] = summarized
    return summarized
def build_prompt(system_msg, document_content, user_prompt):
    """
    Build a unified prompt that explicitly delineates the system instructions,
    document content, and user prompt.
    """
    prompt_parts = []
    prompt_parts.append("SYSTEM PROMPT:\n" + system_msg.strip())
    if document_content:
        prompt_parts.append("\nDOCUMENT CONTENT:\n" + document_content.strip())
    prompt_parts.append("\nUSER PROMPT:\n" + user_prompt.strip())
    return "\n\n".join(prompt_parts)
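# The assembled prompt therefore has the layout:
#   SYSTEM PROMPT:
#   <instructions>
#
#   DOCUMENT CONTENT:
#   <summarized contract text, if any>
#
#   USER PROMPT:
#   <analysis request>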
def speculative_decode(input_text, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
    # Despite its name, this performs standard nucleus (top-p) sampling, not speculative decoding.
    model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens so the echoed prompt is not returned with the answer.
    generated_tokens = output[0][model_inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
def post_process(text):
    # Drop blank lines and exact duplicates, which sampling occasionally produces.
    lines = text.splitlines()
    unique_lines = []
    for line in lines:
        clean_line = line.strip()
        if clean_line and clean_line not in unique_lines:
            unique_lines.append(clean_line)
    return "\n".join(unique_lines)
def granite_analysis(user_prompt, file_objs=None, max_tokens=DEFAULT_MAX_TOKENS, top_p=0.9, temperature=0.7):
    # Read and summarize document content.
    document_content = read_files(file_objs) if file_objs else ""
    # Define a clear system prompt.
    system_prompt = (
        "You are IBM Granite, an enterprise legal and technical analysis assistant. "
        "Your task is to critically analyze the contract document provided below. "
        "Pay special attention to identifying dangerous provisions, legal pitfalls, and potential liabilities. "
        "Make sure to address both the overall contract structure and specific clauses where applicable."
    )
    # Build a unified prompt with explicit sections.
    unified_prompt = build_prompt(system_prompt, document_content, user_prompt)
    # Generate the analysis.
    response = speculative_decode(unified_prompt, max_tokens=max_tokens, top_p=top_p, temperature=temperature)
    final_response = post_process(response)
    return final_response
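# Example call outside the Streamlit flow (illustrative; "contract.pdf" is a hypothetical local path):
# granite_analysis("Summarize termination clauses", ["contract.pdf"]) would read the file,
# build the sectioned prompt, and return the post-processed analysis text.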
# --------- Streamlit App Interface ---------
st.title("IBM Granite - Contract Analysis Assistant")
st.markdown("Upload a contract document (PDF or text) and adjust the analysis prompt and generation parameters.")
# File uploader (allows drag & drop)
uploaded_files = st.file_uploader("Upload contract file(s)", type=["pdf", "txt"], accept_multiple_files=True)
# Editable prompt text area
default_prompt = (
    "Please analyze the attached contract document and highlight any clauses "
    "that represent potential dangers, liabilities, or legal pitfalls that may lead to future disputes or significant financial exposure."
)
user_prompt = st.text_area("Analysis Prompt", default_prompt, height=150)
# Sliders for generation parameters.
max_tokens_slider = st.slider("Maximum Tokens", min_value=100, max_value=2000, value=DEFAULT_MAX_TOKENS, step=100)
temperature_slider = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7, step=0.1)
top_p_slider = st.slider("Top-p", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
if st.button("Analyze Contract"):
    with st.spinner("Analyzing contract document..."):
        result = granite_analysis(user_prompt, uploaded_files, max_tokens=max_tokens_slider, top_p=top_p_slider, temperature=temperature_slider)
    st.success("Analysis complete!")
    st.markdown("### Analysis Output")
    # speculative_decode already strips the echoed prompt, so the result can be shown directly.
    st.text_area("Output", result, height=400)