#!/usr/bin/env python3
"""
Gradio application for inference with the Phi-2 model using LoRA/QLoRA adapters.
Pre-loads the model at startup and provides a simple chat interface.
"""

import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Define constants
DEFAULT_MODEL_PATH = "./adapters"  # Path to the trained adapters
DEFAULT_BASE_MODEL = "microsoft/phi-2"  # Base model name
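# Default generation settings: max_new_tokens caps the reply length; temperature,
# top_p and top_k control how random/diverse the sampled continuation is.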
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 50

# Global variables to store the model and tokenizer
model = None
tokenizer = None

def load_model(
    model_path=DEFAULT_MODEL_PATH,
    base_model=DEFAULT_BASE_MODEL,
    use_qlora=True,
    device="cuda"
):
    """
    Load the base model and adapter weights.
    """
    global model, tokenizer
    
    print(f"Loading tokenizer from {base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
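    # Phi-2's tokenizer does not define a pad token; fall back to EOS so padding/generation works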
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Configure model loading parameters
    model_kwargs = {"trust_remote_code": True}
    
    # Set up 4-bit quantization (QLoRA) if enabled
    if use_qlora:
        print("Using 4-bit quantization (QLoRA)")
        # Prefer bfloat16 compute where the GPU supports it, otherwise float16
        compute_dtype = torch.float16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            compute_dtype = torch.bfloat16

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
            bnb_4bit_compute_dtype=compute_dtype,  # dtype used for matmuls
            bnb_4bit_use_double_quant=True         # also quantize the quantization constants
        )
        model_kwargs["quantization_config"] = quantization_config
        # Let accelerate place the quantized weights; .to() is not supported on 4-bit models
        model_kwargs["device_map"] = "auto"
    else:
        model_kwargs["torch_dtype"] = torch.float16 if torch.cuda.is_available() else torch.float32
    
    # Check if adapter path exists
    if not os.path.exists(model_path):
        print(f"Warning: Model path '{model_path}' does not exist. Using base model only.")
    
    # Load base model
    print(f"Loading base model {base_model}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model, 
        **model_kwargs
    )
    
    # Load adapter weights if available
    if os.path.exists(model_path) and os.path.exists(os.path.join(model_path, "adapter_config.json")):
        print(f"Loading {'QLoRA' if use_qlora else 'LoRA'} adapters from {model_path}...")
        model = PeftModel.from_pretrained(base_model, model_path)
        
        # Special handling for QLoRA - move norm layers to float32 for stability
        # and ensure model and adapter layers have consistent dtypes
        if use_qlora:
            print("Harmonizing model layer dtypes...")
            working_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            
            # First make sure important parts are in float16/32
            for name, module in model.named_modules():
                if any(x in name for x in ["lm_head", "embed_tokens"]):
                    module.to(working_dtype)
                elif "norm" in name:
                    module.to(torch.float32)  # Norms should be in fp32 for stability
    else:
        model = base_model
        print("Using base model without adapters")
    
    # Move the model to the target device (skipped for 4-bit models, which bitsandbytes
    # has already placed and which do not support .to()) and switch to inference mode
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    if not use_qlora:
        model = model.to(device)
    model.eval()
    
    print(f"Model loaded successfully on {device}!")
    return model, tokenizer
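
# load_model() can also be pointed at another adapter directory or run without
# quantization, e.g. (the path below is purely illustrative):
#   model, tokenizer = load_model(model_path="./my-adapters", use_qlora=False)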


def generate_response(prompt, chat_history):
    """
    Generate text response from the model.
    """
    global model, tokenizer
    
    if model is None or tokenizer is None:
        return chat_history + [(prompt, "Model is not loaded. Check the console output and the model/adapter paths, then restart the app.")]
    
    # Format prompt for Phi-2
    formatted_prompt = f"Instruct: {prompt}\nOutput:"
    
    # Tokenize the input prompt; the tokenizer also returns a matching attention mask
    device = next(model.parameters()).device
    encoded = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    input_ids = encoded["input_ids"]            # token IDs stay int64
    attention_mask = encoded["attention_mask"]  # integer mask, as generate() expects
    
    # Generate text with robust error handling
    try:
        with torch.no_grad():
            # First attempt: sampling with the default decoding parameters
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=DEFAULT_TEMPERATURE,
                top_p=DEFAULT_TOP_P,
                top_k=DEFAULT_TOP_K,
                pad_token_id=tokenizer.pad_token_id,
            )
    except Exception as e:
        print(f"Generation error: {str(e)}")
        try:
            # Fallback: greedy decoding; on GPU, wrap in autocast to paper over dtype mismatches
            print("Attempting fallback generation...")
            ctx = (torch.autocast(device_type="cuda", dtype=torch.float16)
                   if torch.cuda.is_available() else torch.no_grad())
            with ctx:
                generated_ids = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                    do_sample=False,  # Greedy decoding for more stability
                    pad_token_id=tokenizer.pad_token_id,
                )
        except Exception as e2:
            return chat_history + [(prompt, f"Error generating response: {str(e2)}")]
    
    # Decode the generated text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    # Extract just the output part
    output = generated_text.split("Output:", 1)[1].strip() if "Output:" in generated_text else generated_text
    
    # Update chat history
    return chat_history + [(prompt, output)]


# Example prompts to demonstrate model capabilities
examples = [
    ["Explain the concept of quantum computing in simple terms."],
    ["Write a short story about a robot that learns to paint."],
    ["What are some ethical considerations when developing AI systems?"],
    ["How can I improve my productivity while working from home?"],
    ["Create a meal plan for a vegetarian diet that provides sufficient protein."]
]
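# Clicking an example only pre-fills the message box (cache_examples=False below);
# press Send to actually run it through the model.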


# Initialize the model at startup
print("Pre-loading the model...")
try:
    model, tokenizer = load_model()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    print("The app will still start, but you may need to check your model path.")

# Create the Gradio interface
with gr.Blocks(title="Supervised Fine Tuned (SFT) Phi-2 with QLoRA Adapters") as demo:
    gr.Markdown("# Supervised Fine Tuned (SFT) Phi-2 with QLoRA Adapters")
    gr.Markdown("- Base model (foundation model) Phi-2\n"
                 "- Supervised Fine Tuned (SFT) method is used to fine-tune the model on [OpenAssistant dataset](https://huggingface.co/datasets/OpenAssistant/oasst1?row=0)\n"
                 "- QLoRA Adapters are used to reduce the number of parameters in the model\n"
                 "- This gives the model an ability to answer questions rather than just generating text\n"
                 "- Chat with SFT Phi-2 model with QLoRA Adapters")
    
    chatbot = gr.Chatbot(height=500)
    
    with gr.Row():
        msg = gr.Textbox(
            label="Type your message here", 
            placeholder="Ask me anything...",
            show_label=False,
            scale=9
        )
        send_btn = gr.Button("Send", scale=1)
    
    clear = gr.Button("Clear Chat")
    
    # Add examples section
    gr.Markdown("### Example Capabilities")
    gr.Examples(
        examples=examples,
        inputs=msg,
        outputs=chatbot,
        fn=generate_response,
        cache_examples=False,
        examples_per_page=5
    )
    
    # Set up event handlers
    send_btn.click(generate_response, [msg, chatbot], [chatbot]).then(
        lambda: "", None, msg  # Clear the input box after sending
    )
    
    msg.submit(generate_response, [msg, chatbot], [chatbot]).then(
        lambda: "", None, msg  # Clear the input box after sending
    )
    
    clear.click(lambda: [], None, chatbot)

# Launch the app
if __name__ == "__main__":
    # Check GPU status
    if torch.cuda.is_available():
        print(f"CUDA available: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    else:
        print("CUDA not available, using CPU. This will be very slow for inference.")
    
    # Launch the Gradio app
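    # share=True additionally exposes a temporary public *.gradio.live URL alongside the local server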
    demo.launch(share=True)