|
import torch |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import argparse |
|
import json |
|
|
|
class ItineraryGenerator:
    """Generate travel itineraries with a fine-tuned T5ForConditionalGeneration model."""

    def __init__(self, model_path: str):
        """Load the tokenizer and model from ``model_path`` and put the model in eval mode.

        Args:
            model_path: Path to the fine-tuned checkpoint; passed unchanged to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        # NOTE(review): device_map="auto" relies on the `accelerate` package in
        # recent transformers releases — confirm it is installed where this runs.
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",
        )
        self.model.eval()

    @staticmethod
    def _build_prompt(destination: str, duration: int, preferences: str, budget: str) -> str:
        """Return the instruction prompt that is fed to the encoder."""
        # NOTE(review): a fine-tuned model is sensitive to the exact prompt layout;
        # keep this text in sync with the prompt format used during fine-tuning.
        return (
            f"Generate a detailed travel itinerary for {destination} for {duration} days.\n"
            f"Preferences: {preferences}\n"
            f"Budget: {budget}\n"
            "\n"
            "Detailed Itinerary:"
        )

    @staticmethod
    def _strip_prompt_echo(generated_text: str, prompt: str) -> str:
        """Remove a leading copy of ``prompt`` (only if present) and surrounding whitespace.

        For an encoder-decoder model such as T5, ``generate`` returns decoder tokens
        only, so the prompt is normally NOT part of the decoded text. Blindly slicing
        ``len(prompt)`` characters off the front would discard the start of the
        itinerary; the prefix is therefore removed only when it is really echoed.
        """
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):]
        return generated_text.strip()

    @staticmethod
    def _decoding_kwargs(temperature: float, top_p: float, do_sample: bool) -> dict:
        """Build the decoding-strategy keyword arguments for ``model.generate``.

        transformers ignores ``temperature``/``top_p`` unless ``do_sample=True``, so
        they are forwarded only when sampling. A non-positive temperature selects
        greedy decoding rather than being handed to the sampler, which expects a
        strictly positive temperature.
        """
        if do_sample and temperature > 0:
            return {"do_sample": True, "temperature": temperature, "top_p": top_p}
        return {"do_sample": False}

    def generate_itinerary(
        self,
        destination: str,
        duration: int,
        preferences: str,
        budget: str,
        max_length: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
    ) -> str:
        """Generate a travel itinerary for the given trip description.

        Args:
            destination: Place to visit; interpolated into the prompt.
            duration: Number of days; interpolated into the prompt.
            preferences: Free-text travel preferences.
            budget: Free-text budget description.
            max_length: Token limit used both to truncate the prompt and to cap the
                length of the generated sequence.
            temperature: Sampling temperature; only used when sampling.
            top_p: Nucleus-sampling threshold; only used when sampling.
            do_sample: Sample using ``temperature``/``top_p`` (default). Pass False
                for greedy decoding; greedy is also used when ``temperature <= 0``.

        Returns:
            The decoded itinerary with special tokens and surrounding whitespace removed.
        """
        prompt = self._build_prompt(destination, duration, preferences, budget)

        # Truncate over-long prompts so the encoder input is bounded by max_length tokens.
        inputs = self.tokenizer(
            prompt, return_tensors="pt", max_length=max_length, truncation=True
        ).to(self.model.device)

        # T5 tokenizers define a real pad token; fall back to EOS only if one is missing.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                pad_token_id=pad_token_id,
                **self._decoding_kwargs(temperature, top_p, do_sample),
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._strip_prompt_echo(generated_text, prompt)
|
|
|
def main(argv=None):
    """Command-line entry point: parse arguments, generate one itinerary, emit it.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so ``main()`` behaves exactly like a normal CLI call.

    The result is written as pretty-printed UTF-8 JSON to ``--output`` when that is
    given; otherwise the itinerary text is printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Generate travel itineraries using a fine-tuned T5 model")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model")
    parser.add_argument("--destination", type=str, required=True, help="Travel destination")
    parser.add_argument("--duration", type=int, required=True, help="Number of days")
    parser.add_argument("--preferences", type=str, required=True, help="Travel preferences")
    parser.add_argument("--budget", type=str, required=True, help="Travel budget")
    parser.add_argument("--output", type=str, help="Output file path (optional)")
    # Optional generation knobs; defaults match generate_itinerary's own defaults.
    parser.add_argument(
        "--max_length", type=int, default=1024,
        help="Token limit for both the prompt and the generated itinerary (default: 1024)",
    )
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature (default: 0.7)")
    parser.add_argument("--top_p", type=float, default=0.9, help="Nucleus sampling threshold (default: 0.9)")

    args = parser.parse_args(argv)

    # A trip needs at least one day; reject nonsense before loading the model.
    if args.duration < 1:
        parser.error("--duration must be a positive integer")

    generator = ItineraryGenerator(args.model_path)

    itinerary = generator.generate_itinerary(
        destination=args.destination,
        duration=args.duration,
        preferences=args.preferences,
        budget=args.budget,
        max_length=args.max_length,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    output = {
        "destination": args.destination,
        "duration": args.duration,
        "preferences": args.preferences,
        "budget": args.budget,
        "generated_itinerary": itinerary,
    }

    if args.output:
        # Explicit UTF-8 + ensure_ascii=False keeps non-ASCII place names readable.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, ensure_ascii=False)
        print(f"Itinerary saved to {args.output}")
    else:
        print("\nGenerated Itinerary:")
        print("=" * 50)
        print(itinerary)
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|
|