# Travereel-Model-V1 / src/generate.py
# Rahman Azhar
# Switch to FLAN-T5 model for better accessibility
# 70ee247
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
import argparse
import json
class ItineraryGenerator:
    """Generate travel itineraries with a T5-style sequence-to-sequence model.

    The tokenizer and model are loaded once from ``model_path`` and reused for
    every call to :meth:`generate_itinerary`.
    """

    def __init__(self, model_path: str):
        """Load the tokenizer and model from *model_path* and enter eval mode.

        ``device_map="auto"`` lets transformers decide where to place the
        weights (this option needs the ``accelerate`` package installed).
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",
        )
        # Inference only: disables dropout and other training-time behavior.
        self.model.eval()

    @staticmethod
    def _build_prompt(destination: str, duration: int, preferences: str, budget: str) -> str:
        """Return the instruction prompt fed to the model."""
        return (
            f"Generate a detailed travel itinerary for {destination} for {duration} days.\n"
            f"Preferences: {preferences}\n"
            f"Budget: {budget}\n"
            "Detailed Itinerary:"
        )

    @staticmethod
    def _sampling_kwargs(temperature: float, top_p: float) -> dict:
        """Return the decoding options passed to ``generate``.

        ``temperature`` and ``top_p`` only take effect when sampling is
        enabled, so a positive temperature turns on sampling. A non-positive
        temperature falls back to deterministic greedy decoding instead of
        passing an invalid temperature to the sampler.
        """
        if temperature > 0:
            return {"do_sample": True, "temperature": temperature, "top_p": top_p}
        return {"do_sample": False}

    @staticmethod
    def _strip_prompt(generated_text: str, prompt: str) -> str:
        """Return *generated_text* with any echoed *prompt* removed, stripped.

        An encoder-decoder model such as T5 emits only the answer, so the text
        is normally returned unchanged. The prompt is cut off only when the
        output really starts with it (as decoder-only models' outputs do).
        Blindly slicing ``len(prompt)`` characters would discard the start of
        the answer.
        """
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):]
        return generated_text.strip()

    def generate_itinerary(
        self,
        destination: str,
        duration: int,
        preferences: str,
        budget: str,
        max_length: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """Generate an itinerary for one trip and return it as plain text.

        Args:
            destination: Place to visit, inserted verbatim into the prompt.
            duration: Trip length in days, inserted verbatim into the prompt.
            preferences: Free-text travel preferences.
            budget: Free-text budget description.
            max_length: Token limit used both to truncate the tokenized prompt
                and as ``max_length`` for the generated sequence.
            temperature: Sampling temperature. Values <= 0 select greedy decoding.
            top_p: Nucleus-sampling probability mass. Only used when sampling.

        Returns:
            The decoded itinerary with special tokens and surrounding
            whitespace removed.
        """
        prompt = self._build_prompt(destination, duration, preferences, budget)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(self.model.device)

        # Use the tokenizer's real pad token (T5 has one). Fall back to EOS only
        # for tokenizers that define no pad token.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                pad_token_id=pad_token_id,
                **self._sampling_kwargs(temperature, top_p),
            )
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._strip_prompt(generated_text, prompt)
def main():
    """Command-line entry point: generate one itinerary, then print or save it.

    Parses the CLI arguments, loads the model from ``--model_path``, generates
    an itinerary, and either writes a JSON record (inputs + itinerary) to
    ``--output`` or prints the itinerary to stdout.
    """
    parser = argparse.ArgumentParser(description="Generate travel itineraries using a fine-tuned T5 model")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model")
    parser.add_argument("--destination", type=str, required=True, help="Travel destination")
    parser.add_argument("--duration", type=int, required=True, help="Number of days")
    parser.add_argument("--preferences", type=str, required=True, help="Travel preferences")
    parser.add_argument("--budget", type=str, required=True, help="Travel budget")
    parser.add_argument("--output", type=str, help="Output file path (optional)")
    args = parser.parse_args()

    # Reject a nonsensical trip length before paying the model-loading cost.
    if args.duration < 1:
        parser.error("--duration must be a positive number of days")

    generator = ItineraryGenerator(args.model_path)
    itinerary = generator.generate_itinerary(
        destination=args.destination,
        duration=args.duration,
        preferences=args.preferences,
        budget=args.budget,
    )

    output = {
        "destination": args.destination,
        "duration": args.duration,
        "preferences": args.preferences,
        "budget": args.budget,
        "generated_itinerary": itinerary,
    }

    if args.output:
        # Write UTF-8 explicitly and keep non-ASCII place names readable
        # instead of \u-escaping them. The result is still valid JSON.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, ensure_ascii=False)
        print(f"Itinerary saved to {args.output}")
    else:
        print("\nGenerated Itinerary:")
        print("=" * 50)
        print(itinerary)
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()