#!/usr/bin/env python3
"""
Gradio application for inference with Phi-2 model using LoRA/QLoRA adapters.
Pre-loads the model and provides a simple chat interface.
"""
import os
import time
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# Define constants
DEFAULT_MODEL_PATH = "./adapters" # Path to the trained adapters
DEFAULT_BASE_MODEL = "microsoft/phi-2" # Base model name
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 50
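
# The sampling defaults above trade determinism for variety: temperature, top_p and top_k
# control how widely the model samples, while max_new_tokens caps the length of each reply
# (prompt tokens are not counted against it).
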
# Global variables to store the model and tokenizer
model = None
tokenizer = None


def load_model(
    model_path=DEFAULT_MODEL_PATH,
    base_model=DEFAULT_BASE_MODEL,
    use_qlora=True,
    device="cuda"
):
    """
    Load the base model and adapter weights.
    """
    global model, tokenizer
    print(f"Loading tokenizer from {base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Configure model loading parameters
    model_kwargs = {"trust_remote_code": True}
    # Set up quantization for QLoRA if enabled
    if use_qlora:
        print("Using 4-bit quantization (QLoRA)")
        compute_dtype = torch.float16
        # Guard the bf16 check so it does not assume a CUDA device is present
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            compute_dtype = torch.bfloat16
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True
        )
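        # NF4 ("4-bit NormalFloat") keeps the frozen base weights in 4 bits, double
        # quantization also compresses the quantization constants, and matmuls run in
        # the 16-bit compute dtype chosen above: the standard QLoRA recipe.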
model_kwargs["quantization_config"] = quantization_config
else:
model_kwargs["torch_dtype"] = torch.float16 if torch.cuda.is_available() else torch.float32
    # Check if adapter path exists
    if not os.path.exists(model_path):
        print(f"Warning: Model path '{model_path}' does not exist. Using base model only.")

    # Load base model
    print(f"Loading base model {base_model}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model,
        **model_kwargs
    )

    # Load adapter weights if available
    if os.path.exists(model_path) and os.path.exists(os.path.join(model_path, "adapter_config.json")):
        print(f"Loading {'QLoRA' if use_qlora else 'LoRA'} adapters from {model_path}...")
        model = PeftModel.from_pretrained(base_model, model_path)
        # Special handling for QLoRA - move norm layers to float32 for stability
        # and ensure model and adapter layers have consistent dtypes
        if use_qlora:
            print("Harmonizing model layer dtypes...")
            working_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            # First make sure important parts are in float16/32
            for name, module in model.named_modules():
                if any(x in name for x in ["lm_head", "embed_tokens"]):
                    module.to(working_dtype)
                elif "norm" in name:
                    module.to(torch.float32)  # Norms should be in fp32 for stability
    else:
        model = base_model
        print("Using base model without adapters")
    # Move model to device. 4-bit bitsandbytes models do not support .to(device)
    # (transformers raises on it), so leave them where from_pretrained placed them.
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    if not getattr(model, "is_quantized", False):
        model = model.to(device)
    model.eval()
    print(f"Model loaded successfully on {device}!")
    return model, tokenizer
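
# If ./adapters is missing (e.g. the Space was cloned without the trained weights),
# load_model() falls back to the plain Phi-2 base model (a warning is printed), so
# replies will look like raw text completion rather than instruction-tuned answers.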


def generate_response(prompt, chat_history):
    """
    Generate text response from the model.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        return chat_history + [(prompt, "Model not loaded yet. Please wait a moment.")]

    # Format prompt for Phi-2
    formatted_prompt = f"Instruct: {prompt}\nOutput:"
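    # Phi-2's model card documents this "Instruct: ...\nOutput:" QA format, so using it
    # at inference keeps the fine-tuned model in question-answering mode.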
    # Tokenize input prompt
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    # Generate text with robust error handling
    try:
        with torch.no_grad():
            # Token IDs and the attention mask must stay integer-typed
            input_ids = input_ids.to(torch.long)

            # First attempt with the configured sampling parameters
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=DEFAULT_TEMPERATURE,
                top_p=DEFAULT_TOP_P,
                top_k=DEFAULT_TOP_K,
                pad_token_id=tokenizer.pad_token_id,
            )
    except Exception as e:
        print(f"Generation error: {str(e)}")
        try:
            # Fallback: retry with greedy decoding under autocast
            print("Attempting fallback generation...")
            # CPU autocast does not support float32, so pick a supported dtype per device
            autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
            autocast_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
            with torch.autocast(device_type=autocast_device, dtype=autocast_dtype):
                generated_ids = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                    do_sample=False,  # Greedy decoding for more stability
                    pad_token_id=tokenizer.pad_token_id,
                )
        except Exception as e2:
            return chat_history + [(prompt, f"Error generating response: {str(e2)}")]
    # Decode the generated text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Extract just the output part
    output = generated_text.split("Output:")[1].strip() if "Output:" in generated_text else generated_text

    # Update chat history
    return chat_history + [(prompt, output)]
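
# Note: only the latest prompt is sent to the model; earlier turns are shown in the
# Chatbot component but are not fed back as conversational context.
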
# Example prompts to demonstrate model capabilities
examples = [
    ["Explain the concept of quantum computing in simple terms."],
    ["Write a short story about a robot that learns to paint."],
    ["What are some ethical considerations when developing AI systems?"],
    ["How can I improve my productivity while working from home?"],
    ["Create a meal plan for a vegetarian diet that provides sufficient protein."]
]

# Initialize the model at startup
print("Pre-loading the model...")
try:
    model, tokenizer = load_model()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    print("The app will still start, but you may need to check your model path.")
# Create the Gradio interface
with gr.Blocks(title="Supervised Fine Tuned (SFT) Phi-2 with QLoRA Adapters") as demo:
    gr.Markdown("# Supervised Fine Tuned (SFT) Phi-2 with QLoRA Adapters")
    gr.Markdown("- Base model (foundation model): Phi-2\n"
                "- The model is supervised fine-tuned (SFT) on the [OpenAssistant dataset](https://huggingface.co/datasets/OpenAssistant/oasst1?row=0)\n"
                "- QLoRA adapters are used so that only a small fraction of the parameters needs to be trained\n"
                "- Fine-tuning gives the model the ability to answer questions rather than just continue text\n"
                "- Chat with the SFT Phi-2 model (QLoRA adapters) below")

    chatbot = gr.Chatbot(height=500)
    with gr.Row():
        msg = gr.Textbox(
            label="Type your message here",
            placeholder="Ask me anything...",
            show_label=False,
            scale=9
        )
        send_btn = gr.Button("Send", scale=1)
    clear = gr.Button("Clear Chat")

    # Add examples section
    gr.Markdown("### Example Capabilities")
    gr.Examples(
        examples=examples,
        inputs=msg,
        outputs=chatbot,
        fn=generate_response,
        cache_examples=False,
        examples_per_page=5
    )
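    # With cache_examples=False, clicking an example only fills the textbox; the user
    # still presses Send (or Enter) to run generation.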

    # Set up event handlers
    send_btn.click(generate_response, [msg, chatbot], [chatbot]).then(
        lambda: "", None, msg  # Clear the input box after sending
    )
    msg.submit(generate_response, [msg, chatbot], [chatbot]).then(
        lambda: "", None, msg  # Clear the input box after sending
    )
    clear.click(lambda: [], None, chatbot)

# Launch the app
if __name__ == "__main__":
    # Check GPU status
    if torch.cuda.is_available():
        print(f"CUDA available: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    else:
        print("CUDA not available, using CPU. This will be very slow for inference.")

    # Launch the Gradio app
    demo.launch(share=True)
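    # share=True requests a temporary public gradio.live link, which is mainly useful for
    # local runs; on Hugging Face Spaces the app is already served publicly.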