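"""Phi-2 Document Q&A.

Gradio app that fine-tunes microsoft/phi-2 on an uploaded document
(PDF, DOCX, or TXT) with LoRA and then answers questions about it.
"""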
import os
import subprocess
import sys
import warnings
import logging
from typing import List, Dict, Any, Optional
import tempfile
import re
import time
import gc
import spaces

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("debug.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Suppress warnings
warnings.filterwarnings("ignore")
def install_package(package: str, version: Optional[str] = None) -> None:
    """Install a Python package if it is not already installed."""
    package_spec = f"{package}=={version}" if version else package
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-cache-dir", package_spec])
        print(f"Successfully installed {package_spec}")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install {package_spec}: {e}")
        raise


# Required packages - install these before importing.
# Keys are pip package names; values are pinned versions (None = latest).
required_packages = {
    "torch": None,
    "gradio": None,  # no pin: the UI below uses gr.themes.Soft(), which needs a recent Gradio
    "transformers": None,
    "peft": None,
    "bitsandbytes": None,
    "PyPDF2": None,
    "python-docx": None,
    "accelerate": None,
    "sentencepiece": None,
}

# Some pip package names differ from their import names
IMPORT_NAMES = {
    "python-docx": "docx",
}

# Install required packages BEFORE importing them
for package, version in required_packages.items():
    import_name = IMPORT_NAMES.get(package, package)
    try:
        __import__(import_name)
        print(f"{package} is already installed.")
    except ImportError:
        print(f"Installing {package}...")
        install_package(package, version)
# Now we can safely import all required modules
import torch
import transformers
import gradio as gr
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, TrainerCallback,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
import PyPDF2
import docx
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset as TorchDataset

# Suppress transformers warnings
transformers.logging.set_verbosity_error()

# Check GPU availability
if torch.cuda.is_available():
    DEVICE = "cuda"
    print(f"GPU found: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
else:
    DEVICE = "cpu"
    print("No GPU found, using CPU. Fine-tuning will be much slower.")
    print("For better performance, use Google Colab with a GPU runtime (Runtime > Change runtime type > GPU)")

# Constants specific to Phi-2
MODEL_KEY = "microsoft/phi-2"
MAX_SEQ_LEN = 512  # Reduced from 1024 for much lighter memory usage
# LoRA target modules for Phi-2's attention blocks
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "dense"]

# Initialize model and tokenizer
model = None
tokenizer = None
fine_tuned_model = None
document_text = ""  # Store document content for context
def load_base_model() -> str:
    """Load Phi-2 with 8-bit quantization instead of 4-bit for faster training"""
    global model, tokenizer

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    try:
        # Use 8-bit quantization (faster to train than 4-bit)
        if DEVICE == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False
            )
        else:
            bnb_config = None

        # Load tokenizer with Phi-2 specific settings
        print("Loading Phi-2 tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_KEY,
            trust_remote_code=True,
            padding_side="right"
        )

        # Ensure pad token is properly set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with Phi-2 specific configuration
        print("Loading Phi-2 model... (this may take a few minutes)")
        if DEVICE == "cuda":
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_KEY,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_KEY,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            ).to(DEVICE)

        print("Phi-2 (2.7B) model loaded successfully!")
        return "Phi-2 (2.7B) model loaded successfully! Ready to process documents."

    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        return error_msg
def phi2_prompt_template(context: str, question: str) -> str:
    """
    Create a prompt optimized for Phi-2.
    Phi-2 responds well to clear instruction formatting.
    """
    return f"""Instruction: Answer the question accurately based on the context provided.
Context: {context}
Question: {question}
Answer:"""
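
# For reference, a rendered prompt (with an illustrative context and question)
# looks like this:
#
#   Instruction: Answer the question accurately based on the context provided.
#   Context: <document excerpt selected below>
#   Question: What is this document about?
#   Answer:
#
# During fine-tuning the answer text is appended after "Answer:"; during
# inference the model generates the continuation from that point.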
def process_pdf(file_path: str) -> str:
    """Extract text from a PDF file"""
    text = ""
    try:
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            total_pages = len(pdf_reader.pages)

            # Process at most 30 pages to avoid memory issues
            pages_to_process = min(total_pages, 30)
            for i in range(pages_to_process):
                page = pdf_reader.pages[i]
                page_text = page.extract_text() or ""
                text += page_text + "\n"

            if total_pages > pages_to_process:
                text += f"\n[Note: Only the first {pages_to_process} pages were processed due to size limitations.]"
    except Exception as e:
        print(f"Error processing PDF: {str(e)}")
    return text


def process_docx(file_path: str) -> str:
    """Extract text from a DOCX file"""
    try:
        doc = docx.Document(file_path)
        text = "\n".join([para.text for para in doc.paragraphs])
        return text
    except Exception as e:
        print(f"Error processing DOCX: {str(e)}")
        return ""


def process_txt(file_path: str) -> str:
    """Extract text from a TXT file"""
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
            text = file.read()
        return text
    except Exception as e:
        print(f"Error processing TXT: {str(e)}")
        return ""


def preprocess_text(text: str) -> str:
    """Clean and preprocess text"""
    if not text:
        return ""
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    # Remove special characters that may cause issues
    text = re.sub(r'[^\w\s.,;:!?\'\"()-]', '', text)
    return text.strip()
def get_semantic_chunks(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    """Split text into sentence-aligned chunks of roughly chunk_size words.

    Note: the overlap parameter is accepted for API compatibility but is not
    currently applied between chunks.
    """
    if not text:
        return []

    # Simple sentence splitting for speed
    sentences = re.split(r'(?<=[.!?])\s+', text)

    chunks = []
    current_chunk = []
    current_length = 0
    for sentence in sentences:
        words = sentence.split()
        if current_length + len(words) <= chunk_size:
            current_chunk.append(sentence)
            current_length += len(words)
        else:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = len(words)

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    # Limit to just 5 evenly spaced chunks for much faster processing
    if len(chunks) > 5:
        indices = np.linspace(0, len(chunks) - 1, 5, dtype=int)
        chunks = [chunks[i] for i in indices]

    return chunks
def create_qa_dataset(document_chunks: List[str]) -> List[Dict[str, str]]:
    """Create QA pairs from document chunks for fine-tuning"""
    qa_pairs = []

    # Document-level questions
    full_text = " ".join(document_chunks[:5])  # Use the beginning of the document for an overview
    qa_pairs.append({
        "question": "What is this document about?",
        "context": full_text,
        "answer": "Based on my analysis, this document discusses..."  # Open-ended template for the model to complete
    })
    qa_pairs.append({
        "question": "Summarize the key points of this document.",
        "context": full_text,
        "answer": "The key points of this document are..."
    })

    # Process each chunk for specific QA pairs
    for i, chunk in enumerate(document_chunks):
        if not chunk or len(chunk) < 100:  # Skip very short chunks
            continue

        chunk_index = i + 1  # 1-indexed for readability

        # Basic factual questions about chunk content
        qa_pairs.append({
            "question": f"What information is contained in section {chunk_index}?",
            "context": chunk,
            "answer": f"Section {chunk_index} contains information about..."
        })

        # Entity-based questions - find capitalized names and phrases
        entities = set(re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', chunk))
        # (technical_terms is collected but not currently used)
        technical_terms = set(re.findall(r'\b[A-Za-z]+-?[A-Za-z]+\b', chunk))

        # Filter to meaningful entities (longer than 3 chars), limit to 2 entity questions per chunk
        entities = [e for e in entities if len(e) > 3][:2]
        for entity in entities:
            qa_pairs.append({
                "question": f"What does the document say about {entity}?",
                "context": chunk,
                "answer": f"Regarding {entity}, the document states that..."
            })

        # Specific content questions built from substantive sentences
        sentences = re.split(r'(?<=[.!?])\s+', chunk)
        key_sentences = [s for s in sentences if len(s.split()) > 8][:2]
        for sentence in key_sentences:
            # Create a question from the sentence by identifying its subject
            subject_match = re.search(r'^(The|A|An|This|These|Those|Some|Any|Many|Few|All|Most)?\s*([A-Za-z\s]+?)\s+(is|are|was|were|has|have|had|can|could|will|would|may|might)', sentence, re.IGNORECASE)
            if subject_match:
                subject = subject_match.group(2).strip()
                if len(subject) > 2:
                    qa_pairs.append({
                        "question": f"What information is provided about {subject}?",
                        "context": chunk,
                        "answer": sentence
                    })

        # Add relationship questions between consecutive sections
        if i < len(document_chunks) - 1:
            next_chunk = document_chunks[i + 1]
            qa_pairs.append({
                "question": f"How does the information in section {chunk_index} relate to section {chunk_index + 1}?",
                "context": chunk + " " + next_chunk,
                "answer": f"Section {chunk_index} discusses... while section {chunk_index + 1} covers... The relationship between them is..."
            })

    # Limit to 5 examples max for lighter memory usage
    if len(qa_pairs) > 5:
        import random
        random.shuffle(qa_pairs)
        qa_pairs = qa_pairs[:5]

    return qa_pairs
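
# Note: most of the generated answers above are open-ended stems ("...discusses...")
# rather than gold answers, so the fine-tuning signal mainly teaches the model the
# response format and the document's vocabulary rather than verified facts.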
class QADataset(TorchDataset):
    """PyTorch dataset specialized for Phi-2 QA fine-tuning"""

    def __init__(self, qa_pairs: List[Dict[str, str]], tokenizer, max_length: int = MAX_SEQ_LEN):
        self.qa_pairs = qa_pairs
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Verify dataset structure
        self.validate_dataset()

    def validate_dataset(self):
        """Verify that the dataset has the expected structure"""
        if not self.qa_pairs:
            print("Warning: Empty dataset!")
            return

        required_keys = ["question", "context", "answer"]
        for i, item in enumerate(self.qa_pairs[:5]):  # Check the first 5 examples
            missing = [k for k in required_keys if k not in item]
            if missing:
                print(f"Warning: Example {i} missing keys: {missing}")
            # Check for empty values
            empty = [k for k in required_keys if k in item and not item[k]]
            if empty:
                print(f"Warning: Example {i} has empty values for: {empty}")

    def __len__(self):
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        qa_pair = self.qa_pairs[idx]

        # Format the prompt using the Phi-2 template
        context = qa_pair['context']
        question = qa_pair['question']
        answer = qa_pair['answer']

        # Build the Phi-2 specific prompt and append the answer
        prompt = phi2_prompt_template(context, question)
        sequence = f"{prompt} {answer}"

        try:
            # Tokenize with truncation and fixed-length padding
            encoded = self.tokenizer(
                sequence,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt"
            )

            # Extract tensors
            input_ids = encoded["input_ids"].squeeze(0)
            attention_mask = encoded["attention_mask"].squeeze(0)

            # Labels start as a copy of the inputs
            labels = input_ids.clone()

            # Calculate the prompt length accurately
            prompt_encoded = self.tokenizer(prompt, add_special_tokens=False)
            prompt_length = len(prompt_encoded["input_ids"])

            # Ensure prompt_length doesn't exceed the label length
            prompt_length = min(prompt_length, len(labels))

            # Mask the prompt portion with -100 (ignored in the loss) so the model
            # is only trained on the answer tokens
            labels[:prompt_length] = -100
            # Also mask padding positions so they don't contribute to the loss
            labels[attention_mask == 0] = -100

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels
            }
        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            # Return a dummy sample as a fallback (labels fully masked)
            return {
                "input_ids": torch.zeros(self.max_length, dtype=torch.long),
                "attention_mask": torch.zeros(self.max_length, dtype=torch.long),
                "labels": torch.full((self.max_length,), -100, dtype=torch.long)
            }
def clear_gpu_memory():
    """Clear GPU memory to prevent OOM errors"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
class ProgressCallback(TrainerCallback):
    """Report training progress to the Gradio progress bar and status box."""

    def __init__(self, progress, status_box=None):
        self.progress = progress
        self.status_box = status_box
        self.current_step = 0
        self.total_steps = 0

    def on_train_begin(self, args, state, control, **kwargs):
        self.total_steps = state.max_steps

    def on_step_end(self, args, state, control, **kwargs):
        self.current_step = state.global_step
        progress_percent = self.current_step / max(self.total_steps, 1)
        self.progress(
            0.4 + (0.5 * progress_percent),
            desc=f"Epoch {state.epoch}/{args.num_train_epochs} | Step {self.current_step}/{self.total_steps}"
        )
        if self.status_box:
            self.status_box.update(
                f"Training in progress: Epoch {state.epoch}/{args.num_train_epochs} | Step {self.current_step}/{self.total_steps}"
            )
def create_deepspeed_config():
    """Create a DeepSpeed config for faster training"""
    return {
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True
            },
            "allgather_partitions": True,
            "allgather_bucket_size": 5e8,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "overlap_comm": True,
            "contiguous_gradients": True
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": 2e-4,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": 0.01
            }
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 2e-4,
                "warmup_num_steps": 50
            }
        },
        "train_batch_size": 1,
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "gradient_clipping": 0.5,
        "steps_per_print": 10
    }
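
# Note: this config is only applied in finetune_model when the `deepspeed`
# package is actually installed, and it has to be passed to TrainingArguments
# at construction time (assigning training_args.deepspeed after the fact is not
# reliably picked up by the Trainer). Be aware that the learning rate and
# scheduler defined here (2e-4, WarmupLR) differ from the 1e-4 constant
# schedule set in TrainingArguments below.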
def finetune_model(qa_dataset, progress=gr.Progress(), status_box=None):
    """Fine-tune Phi-2 using optimized LoRA parameters"""
    global model, tokenizer, fine_tuned_model

    if model is None:
        return "Please load the base model first."
    if len(qa_dataset) == 0:
        return "No training data created. Please check your document."

    try:
        progress(0.1, desc="Preparing model for fine-tuning...")
        if status_box:
            status_box.update("Preparing model for fine-tuning...")

        # Clear GPU memory
        clear_gpu_memory()

        # Prepare the model for 8-bit training if using a GPU
        if DEVICE == "cuda":
            training_model = prepare_model_for_kbit_training(model)
        else:
            training_model = model

        # Make inputs require grads so gradient checkpointing works with LoRA
        training_model.enable_input_require_grads()

        # Configure LoRA for Phi-2
        peft_config = LoraConfig(
            r=2,                 # Reduced rank for lighter training
            lora_alpha=4,        # Reduced alpha
            lora_dropout=0.05,   # Small dropout for regularization
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=LORA_TARGET_MODULES  # Phi-2 attention modules
        )

        # Apply LoRA to the model
        lora_model = get_peft_model(training_model, peft_config)

        # Print trainable parameters
        trainable_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in lora_model.parameters())
        print(f"Trainable parameters: {trainable_params:,} ({trainable_params/all_params:.2%} of {all_params:,} total)")

        # Enable gradient checkpointing for memory efficiency
        if hasattr(lora_model, "gradient_checkpointing_enable"):
            lora_model.gradient_checkpointing_enable()
            print("Gradient checkpointing enabled")

        # Only enable DeepSpeed when it is actually installed; the config must be
        # passed to TrainingArguments at construction time to take effect
        import importlib.util
        use_deepspeed = DEVICE == "cuda" and importlib.util.find_spec("deepspeed") is not None

        # Create training arguments optimized for Phi-2
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=2,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=1,
            learning_rate=1e-4,              # Reduced from 2e-4 for stability
            lr_scheduler_type="constant",    # Simplified scheduler
            warmup_ratio=0.05,
            weight_decay=0.01,
            logging_steps=1,
            max_grad_norm=0.3,               # Reduced from 0.5 for better gradient stability
            save_strategy="no",
            report_to="none",
            remove_unused_columns=False,
            fp16=(DEVICE == "cuda"),
            no_cuda=(DEVICE == "cpu"),
            optim="adamw_torch",             # Standard optimizer instead of fused for stability
            gradient_checkpointing=True,
            deepspeed=create_deepspeed_config() if use_deepspeed else None
        )

        # Simple data collator that stacks the pre-tokenized tensors
        def collate_fn(features):
            batch = {}
            for key in features[0].keys():
                if key in ["input_ids", "attention_mask", "labels"]:
                    batch[key] = torch.stack([f[key] for f in features])
            return batch

        progress(0.3, desc="Setting up trainer...")
        if status_box:
            status_box.update("Setting up trainer...")

        # Create the trainer
        trainer = Trainer(
            model=lora_model,
            args=training_args,
            train_dataset=qa_dataset,
            data_collator=collate_fn,
            callbacks=[ProgressCallback(progress, status_box)]
        )

        # Start training
        progress(0.4, desc="Initializing training...")
        if status_box:
            status_box.update("Initializing training...")
        print("Starting training...")
        trainer.train()

        # Keep the fine-tuned model and switch it to evaluation mode
        fine_tuned_model = lora_model
        fine_tuned_model.eval()

        # Clear memory
        clear_gpu_memory()

        return "Fine-tuning completed successfully! You can now ask questions about your document."

    except Exception as e:
        error_msg = f"Error during fine-tuning: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        # Try to clean up memory
        try:
            clear_gpu_memory()
        except Exception:
            pass
        return error_msg
def process_document(file_obj, progress=gr.Progress(), status_box=None):
    """Process the uploaded document and fine-tune the model on it"""
    global model, tokenizer, document_text

    progress(0, desc="Processing document...")
    if status_box:
        status_box.update("Processing document...")

    if not file_obj:
        return "Please upload a document first."
    if model is None or tokenizer is None:
        return "Please load the Phi-2 model first (step 1)."

    try:
        temp_dir = None
        if isinstance(file_obj, str):
            # gr.File(type="filepath") hands us the path of the uploaded file
            temp_path = file_obj
            file_name = os.path.basename(temp_path)
        else:
            # Fallback: raw bytes or a file-like object; write it to a temp file
            temp_dir = tempfile.mkdtemp()
            file_name = getattr(file_obj, 'name', 'uploaded_file')
            if not isinstance(file_name, str):
                file_name = "uploaded_file.txt"  # Default name
            # Ensure the file has an extension
            if '.' not in file_name:
                file_name = file_name + '.txt'
            temp_path = os.path.join(temp_dir, file_name)

            file_content = file_obj.read() if hasattr(file_obj, 'read') else file_obj
            with open(temp_path, 'wb') as f:
                f.write(file_content)

        # Extract text based on the file extension
        file_extension = os.path.splitext(file_name)[1].lower()
        if file_extension == '.pdf':
            text = process_pdf(temp_path)
        elif file_extension in ['.docx', '.doc']:
            text = process_docx(temp_path)
        else:  # Default to plain text for unknown extensions
            text = process_txt(temp_path)

        # Check that enough text was extracted
        if not text or len(text) < 50:
            return "Could not extract sufficient text from the document. Please check the file."

        # Save the document text for use as context during inference
        document_text = text

        # Preprocess and chunk the document
        progress(0.3, desc="Preprocessing document...")
        if status_box:
            status_box.update("Preprocessing document...")
        text = preprocess_text(text)
        chunks = get_semantic_chunks(text)
        if not chunks:
            return "Could not extract meaningful text from the document."

        # Create QA pairs
        progress(0.5, desc="Creating QA dataset...")
        if status_box:
            status_box.update("Creating QA dataset...")
        qa_pairs = create_qa_dataset(chunks)
        print(f"Created {len(qa_pairs)} QA pairs for training")

        # Debug: print a sample QA pair to verify the format
        if qa_pairs:
            print("\nSample QA pair for validation:")
            sample = qa_pairs[0]
            print(f"Question: {sample['question']}")
            print(f"Context length: {len(sample['context'])} chars")
            print(f"Answer: {sample['answer'][:50]}...")

        # Create the dataset
        qa_dataset = QADataset(qa_pairs, tokenizer, max_length=MAX_SEQ_LEN)

        # Fine-tune the model
        progress(0.7, desc="Starting fine-tuning...")
        if status_box:
            status_box.update("Starting fine-tuning...")
        result = finetune_model(qa_dataset, progress, status_box)

        # Clean up the temporary copy if we created one
        if temp_dir is not None:
            try:
                os.remove(temp_path)
                os.rmdir(temp_dir)
            except OSError:
                pass

        return result

    except Exception as e:
        error_msg = f"Error processing document: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return error_msg
def generate_answer(question, status_box=None):
    """Generate an answer using the fine-tuned Phi-2 model"""
    global fine_tuned_model, tokenizer, document_text

    if fine_tuned_model is None:
        return "Please process a document first!"
    if not question.strip():
        return "Please enter a question."

    try:
        # Clear memory before generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # For better answers, use document context to help the model.
        # Find relevant context from the document (simple keyword matching for efficiency).
        keywords = re.findall(r'\b\w{5,}\b', question.lower())
        context = document_text

        # If the document is very long, try to find the most relevant sections
        if len(document_text) > 2000 and keywords:
            chunks = get_semantic_chunks(document_text, chunk_size=500, overlap=100)
            relevant_chunks = []
            for chunk in chunks:
                score = sum(1 for keyword in keywords if keyword.lower() in chunk.lower())
                if score > 0:
                    relevant_chunks.append((chunk, score))
            relevant_chunks.sort(key=lambda x: x[1], reverse=True)
            if relevant_chunks:
                # Use the top 2 most relevant chunks
                context = " ".join([chunk for chunk, _ in relevant_chunks[:2]])

        # Limit context length to leave room in the model's context window
        context = context[:1500]  # Limit to 1500 chars for prompt space

        # Create the Phi-2 optimized prompt
        prompt = phi2_prompt_template(context, question)

        # Ensure the model is in evaluation mode
        fine_tuned_model.eval()

        # Tokenize the input
        inputs = tokenizer(prompt, return_tensors="pt").to(fine_tuned_model.device)

        # Generation parameters tuned for Phi-2
        with torch.no_grad():
            outputs = fine_tuned_model.generate(
                **inputs,
                max_new_tokens=75,
                do_sample=True,
                temperature=0.7,
                top_k=40,
                top_p=0.85,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode the response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the generated answer part
        if "Answer:" in response:
            answer = response.split("Answer:")[-1].strip()
        else:
            answer = response

        # If the answer is too short or generic, retry with a higher temperature
        if len(answer.split()) < 10 or "I don't have enough information" in answer:
            with torch.no_grad():
                outputs = fine_tuned_model.generate(
                    **inputs,
                    max_new_tokens=75,
                    do_sample=True,
                    temperature=0.9,  # Higher temperature
                    top_k=40,
                    top_p=0.92,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id
                )

            # Decode the second attempt and extract the answer
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "Answer:" in response:
                answer = response.split("Answer:")[-1].strip()
            else:
                answer = response

        return answer

    except Exception as e:
        error_msg = f"Error generating answer: {str(e)}"
        print(error_msg)
        return error_msg
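
# Workflow exposed by the UI below: (1) load the Phi-2 base model, (2) upload a
# document and run "Process & Fine-tune" to train a small LoRA adapter on QA
# pairs generated from it, (3) ask questions about the document in the second tab.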
# Create Gradio interface
with gr.Blocks(title="Phi-2 Document QA", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Phi-2 Document Q&A System")
    gr.Markdown("Specialized system for fine-tuning Microsoft's Phi-2 model on your documents")

    with gr.Tab("Document Processing"):
        file_input = gr.File(
            label="Upload Document (PDF, DOCX, or TXT)",
            file_types=[".pdf", ".docx", ".txt"],
            type="filepath"  # Pass the path through so the file extension is preserved
        )
        with gr.Row():
            load_model_btn = gr.Button("1. Load Phi-2 Model", variant="secondary")
            process_btn = gr.Button("2. Process & Fine-tune Document", variant="primary")
        status = gr.Textbox(
            label="Status",
            placeholder="First load the model, then upload a document and click 'Process & Fine-tune'",
            lines=3
        )
        gr.Markdown("""
        ### Tips for Best Results
        - PDF, DOCX and TXT files are supported
        - Keep documents under 10 pages for best results
        - Processing time depends on document length and GPU availability
        - For GPU usage in Colab: Runtime > Change runtime type > GPU
        """)

    with gr.Tab("Ask Questions"):
        question_input = gr.Textbox(
            label="Your Question",
            placeholder="Ask about your document...",
            lines=2
        )
        ask_btn = gr.Button("Get Answer", variant="primary")
        answer_output = gr.Textbox(
            label="Phi-2's Response",
            placeholder="The answer will appear here after you ask a question",
            lines=8
        )
        gr.Markdown("""
        ### Example Questions
        - "What is this document about?"
        - "Summarize the key points in this document"
        - "What does the document say about [specific topic]?"
        - "Explain the relationship between [concept A] and [concept B]"
        """)
    # Set up events
    load_model_btn.click(
        fn=load_base_model,
        outputs=[status]
    )
    process_btn.click(
        fn=process_document,
        inputs=[file_input],
        outputs=[status]
    )
    ask_btn.click(
        fn=generate_answer,
        inputs=[question_input],
        outputs=[answer_output]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)