# Phi-2-fine-tuned-with-GRPO / compare_models.py

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import argparse
import time
import sys
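

# --- Optional helper (illustrative sketch, not called anywhere below) ---------
# main() loads the base model in plain bfloat16 before attaching the GRPO/QLoRA
# adapter. If you instead want to mirror the 4-bit NF4 setup that QLoRA training
# typically uses, a load along these lines works (requires the bitsandbytes
# package). The specific flag values are assumptions, not taken from the
# original training run.
def load_base_model_4bit(model_name="microsoft/phi-2"):
    from transformers import BitsAndBytesConfig  # local import: only needed for this sketch

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # store weights in 4 bits
        bnb_4bit_quant_type="nf4",              # NF4 quantization, the usual QLoRA choice
        bnb_4bit_compute_dtype=torch.bfloat16,  # do the matmuls in bfloat16
        bnb_4bit_use_double_quant=True,         # double quantization saves a little more memory
    )
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )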


def stream_response(model, tokenizer, prompt, max_new_tokens=256):
    """Generate a streaming response from the model for the given prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Store the input length to identify the response part
    input_length = len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))

    # Initialize generation
    generated_ids = inputs.input_ids.clone()
    past_key_values = None
    response_text = ""

    # Stop sequences used to detect natural endings
    stop_sequences = ["\n\n", "\nExercise:", "\nQuestion:"]
    was_truncated = False  # Track if the response was truncated

    # Generate tokens one by one
    for _ in range(max_new_tokens):
        with torch.no_grad():
            # Forward pass (only the newest token once the KV cache is populated)
            outputs = model(
                input_ids=generated_ids[:, -1:] if past_key_values is not None else generated_ids,
                past_key_values=past_key_values,
                use_cache=True
            )

        # Get logits and past key values
        logits = outputs.logits
        past_key_values = outputs.past_key_values

        # Sample next token
        next_token_logits = logits[:, -1, :]
        next_token_logits = next_token_logits / 0.7  # Apply temperature

        # Apply top-p (nucleus) sampling
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > 0.9
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        next_token_logits[indices_to_remove] = -float('Inf')

        # Sample from the filtered distribution
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to the generated ids
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)

        # Decode the full sequence generated so far
        current_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Extract only the new part (the response)
        if len(current_text) > input_length:
            new_text = current_text[input_length:]
            # Print only the characters that have not been printed yet
            new_chars = new_text[len(response_text):]
            sys.stdout.write(new_chars)
            sys.stdout.flush()
            response_text = new_text

            # Add a small delay to simulate typing
            time.sleep(0.01)

        # Stop if we generate an EOS token
        if next_token[0, 0].item() == tokenizer.eos_token_id:
            break

        # Check for natural stopping points
        if any(stop_seq in response_text for stop_seq in stop_sequences):
            # If we find a stop sequence, keep only the text up to that point
            for stop_seq in stop_sequences:
                if stop_seq in response_text:
                    stop_idx = response_text.find(stop_seq)
                    if stop_idx > 0:  # Only trim if we already have some content
                        was_truncated = True
                        response_text = response_text[:stop_idx]
                        sys.stdout.write("\n")  # Add a newline for cleaner output
                        sys.stdout.flush()
                        return response_text, was_truncated

    # Return the full response
    return response_text, was_truncated
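

# --- Alternative streaming path (sketch, not used by main) --------------------
# The hand-rolled loop above gives fine-grained control over sampling and the
# custom stop sequences. If that control is not needed, transformers' built-in
# TextStreamer can stream tokens from model.generate() directly; the generation
# settings below simply mirror the temperature/top-p values used above.
def stream_response_with_generate(model, tokenizer, prompt, max_new_tokens=256):
    from transformers import TextStreamer  # local import: only needed for this sketch

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    output_ids = model.generate(
        **inputs,
        streamer=streamer,                    # prints tokens to stdout as they are produced
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
    )
    # Return only the newly generated text (everything after the prompt tokens)
    return tokenizer.decode(output_ids[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)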


def main():
    parser = argparse.ArgumentParser(description="Compare base and fine-tuned Phi-2 models")
    parser.add_argument("--base-only", action="store_true", help="Use only the base model")
    parser.add_argument("--finetuned-only", action="store_true", help="Use only the fine-tuned model")
    parser.add_argument("--adapter-path", type=str, default="./phi2-grpo-qlora-final",
                        help="Path to the fine-tuned adapter")
    args = parser.parse_args()

    # Load the tokenizer for the base model
    base_model_name = "microsoft/phi-2"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # Configure tokenizer
    tokenizer.pad_token = tokenizer.eos_token

    # Load models based on arguments
    models = {}

    if not args.finetuned_only:
        print("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        models["Base Phi-2"] = base_model

    if not args.base_only:
        print(f"Loading fine-tuned model from {args.adapter_path}...")
        # Load a fresh copy of the base model (bfloat16 here; the QLoRA adapter is applied on top)
        base_model_for_ft = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        # Load the adapter on top of the base model
        finetuned_model = PeftModel.from_pretrained(base_model_for_ft, args.adapter_path)
        models["Fine-tuned Phi-2"] = finetuned_model

    # Interactive prompt loop
    print("\n" + "="*50)
    print("Interactive Phi-2 Model Comparison (Streaming Mode)")
    print("Type 'exit' to quit")
    print("="*50 + "\n")

    while True:
        # Get user input
        user_prompt = input("\nEnter your prompt: ")
        if user_prompt.lower() == 'exit':
            break

        print("\n" + "-"*50)

        # Generate a response from each loaded model
        for model_name, model in models.items():
            print(f"\n{model_name} response:")
            response, was_truncated = stream_response(model, tokenizer, user_prompt)
            if was_truncated:
                print("\n[Note: Response was truncated at a natural stopping point]")
            print("\n" + "-"*30)

        print("-"*50)


if __name__ == "__main__":
    main()
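

# Example invocations (adapter path shown is the script default; adjust to your setup):
#   python compare_models.py                      # compare base and fine-tuned side by side
#   python compare_models.py --base-only          # prompt only the base Phi-2 model
#   python compare_models.py --finetuned-only --adapter-path ./phi2-grpo-qlora-final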