#!/usr/bin/env python3
"""
Gradio application for inference with Phi-2 model using LoRA/QLoRA adapters.
Pre-loads the model and provides a simple chat interface.
"""
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Define constants
DEFAULT_MODEL_PATH = "./adapters"       # Path to the trained adapters
DEFAULT_BASE_MODEL = "microsoft/phi-2"  # Base model name
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 50

# Global variables to store the model and tokenizer
model = None
tokenizer = None

def load_model(
    model_path=DEFAULT_MODEL_PATH,
    base_model=DEFAULT_BASE_MODEL,
    use_qlora=True,
    device="cuda"
):
    """
    Load the base model and adapter weights.
    """
    global model, tokenizer

    print(f"Loading tokenizer from {base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Configure model loading parameters
    model_kwargs = {"trust_remote_code": True}

    # Set up quantization for QLoRA if enabled
    if use_qlora:
        print("Using 4-bit quantization (QLoRA)")
        compute_dtype = torch.float16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            compute_dtype = torch.bfloat16
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True
        )
        model_kwargs["quantization_config"] = quantization_config
    else:
        model_kwargs["torch_dtype"] = torch.float16 if torch.cuda.is_available() else torch.float32

    # Check if adapter path exists
    if not os.path.exists(model_path):
        print(f"Warning: Model path '{model_path}' does not exist. Using base model only.")

    # Load base model
    print(f"Loading base model {base_model}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model,
        **model_kwargs
    )

    # Load adapter weights if available
    if os.path.exists(model_path) and os.path.exists(os.path.join(model_path, "adapter_config.json")):
        print(f"Loading {'QLoRA' if use_qlora else 'LoRA'} adapters from {model_path}...")
        model = PeftModel.from_pretrained(base_model, model_path)

        # Special handling for QLoRA - move norm layers to float32 for stability
        # and ensure model and adapter layers have consistent dtypes
        if use_qlora:
            print("Harmonizing model layer dtypes...")
            working_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            # First make sure important parts are in float16/32
            for name, module in model.named_modules():
                if any(x in name for x in ["lm_head", "embed_tokens"]):
                    module.to(working_dtype)
                elif "norm" in name:
                    module.to(torch.float32)  # Norms should be in fp32 for stability
    else:
        model = base_model
        print("Using base model without adapters")

    # Move model to device. Quantized (4-bit) models are already placed by
    # bitsandbytes and do not support .to(), so only move non-quantized models.
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    if not use_qlora:
        model = model.to(device)
    model.eval()

    print(f"Model loaded successfully on {device}!")
    return model, tokenizer
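
# Example of loading a different adapter checkpoint manually (hypothetical path,
# shown for illustration only):
# model, tokenizer = load_model(model_path="./my-other-adapters", use_qlora=True)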

def generate_response(prompt, chat_history):
    """
    Generate text response from the model.
    """
    global model, tokenizer

    if model is None or tokenizer is None:
        return chat_history + [(prompt, "Model not loaded yet. Please wait a moment.")]

    # Format prompt for Phi-2
    formatted_prompt = f"Instruct: {prompt}\nOutput:"

    # Tokenize input prompt
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    # Generate text with robust error handling
    try:
        with torch.no_grad():
            # Token IDs and attention masks must be integer tensors
            input_ids = input_ids.to(torch.long)
            attention_mask = attention_mask.to(torch.long)

            # First attempt with sampling parameters
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=DEFAULT_TEMPERATURE,
                top_p=DEFAULT_TOP_P,
                top_k=DEFAULT_TOP_K,
                pad_token_id=tokenizer.pad_token_id,
            )
    except Exception as e:
        print(f"Generation error: {str(e)}")
        try:
            # Fallback: greedy decoding under autocast for more stability
            print("Attempting fallback generation...")
            autocast_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            autocast_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
            with torch.autocast(device_type=autocast_device, dtype=autocast_dtype):
                generated_ids = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                    do_sample=False,  # Use greedy decoding for more stability
                    pad_token_id=tokenizer.pad_token_id,
                )
        except Exception as e2:
            return chat_history + [(prompt, f"Error generating response: {str(e2)}")]

    # Decode the generated text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Extract just the output part
    output = generated_text.split("Output:")[1].strip() if "Output:" in generated_text else generated_text

    # Update chat history
    return chat_history + [(prompt, output)]
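
# Quick smoke test outside the UI (uncomment once the model has loaded):
# history = generate_response("What is instruction tuning?", [])
# print(history[-1][1])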

# Example prompts to demonstrate model capabilities
examples = [
    ["Explain the concept of quantum computing in simple terms."],
    ["Write a short story about a robot that learns to paint."],
    ["What are some ethical considerations when developing AI systems?"],
    ["How can I improve my productivity while working from home?"],
    ["Create a meal plan for a vegetarian diet that provides sufficient protein."]
]

# Initialize the model at startup
print("Pre-loading the model...")
try:
    model, tokenizer = load_model()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    print("The app will still start, but you may need to check your model path.")

# Create the Gradio interface
with gr.Blocks(title="Supervised Fine-Tuned (SFT) Phi-2 with QLoRA Adapters") as demo:
    gr.Markdown("# Supervised Fine-Tuned (SFT) Phi-2 with QLoRA Adapters")
    gr.Markdown("- Base (foundation) model: Phi-2\n"
                "- Supervised fine-tuning (SFT) is used to fine-tune the model on the [OpenAssistant dataset](https://huggingface.co/datasets/OpenAssistant/oasst1?row=0)\n"
                "- QLoRA adapters keep the base model quantized to 4 bits and train only a small set of adapter parameters\n"
                "- This gives the model the ability to answer questions and follow instructions rather than just continue text\n"
                "- Chat with the SFT Phi-2 model with QLoRA adapters below")

    chatbot = gr.Chatbot(height=500)

    with gr.Row():
        msg = gr.Textbox(
            label="Type your message here",
            placeholder="Ask me anything...",
            show_label=False,
            scale=9
        )
        send_btn = gr.Button("Send", scale=1)

    clear = gr.Button("Clear Chat")

    # Add examples section (clicking an example fills the textbox; press Send to run it)
    gr.Markdown("### Example Capabilities")
    gr.Examples(
        examples=examples,
        inputs=msg,
        cache_examples=False,
        examples_per_page=5
    )

    # Set up event handlers
    send_btn.click(generate_response, [msg, chatbot], [chatbot]).then(
        lambda: "", None, msg  # Clear the input box after sending
    )
    msg.submit(generate_response, [msg, chatbot], [chatbot]).then(
        lambda: "", None, msg  # Clear the input box after sending
    )
    clear.click(lambda: [], None, chatbot)

# Launch the app
if __name__ == "__main__":
    # Check GPU status
    if torch.cuda.is_available():
        print(f"CUDA available: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    else:
        print("CUDA not available, using CPU. This will be very slow for inference.")

    # Launch the Gradio app
    demo.launch(share=True)
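    # Alternative launch options (assumption: self-hosting instead of a temporary
    # share link) - bind to all interfaces on Gradio's default port:
    # demo.launch(server_name="0.0.0.0", server_port=7860, share=False)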