|
import torch |
|
from transformers import ( |
|
T5ForConditionalGeneration, |
|
T5Tokenizer, |
|
TrainingArguments, |
|
Trainer, |
|
DataCollatorForSeq2Seq |
|
) |
|
from datasets import load_dataset |
|
import os |
|
import json |
|
from typing import Dict, List |
|
|
|
class ItineraryDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns itinerary records into seq2seq features.

    Each record is a dict with the keys ``destination``, ``duration``,
    ``preferences``, ``budget`` and ``itinerary``. The first four are rendered
    into a text prompt (the encoder input); ``itinerary`` is the text the
    decoder is trained to produce.

    Args:
        data_path: Path to a JSON file containing a list of records.
        tokenizer: Callable tokenizer; it is invoked with ``truncation``,
            ``max_length``, ``padding`` and ``return_tensors`` keyword
            arguments and must return ``input_ids`` and ``attention_mask``
            tensors with a leading batch dimension of 1.
        max_length: Length that both the prompt and the target are
            truncated/padded to.

    Raises:
        ValueError: If the JSON file does not contain a top-level list.
    """

    # Label value that PyTorch's cross-entropy loss skips (its default
    # ``ignore_index``); used so padding never contributes to the loss.
    IGNORE_INDEX = -100

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[Dict]:
        """Read the JSON file at *data_path* and return its list of records."""
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(data_path, "r", encoding="utf-8") as f:
            records = json.load(f)
        if not isinstance(records, list):
            raise ValueError(
                f"Expected a JSON list of examples in {data_path!r}, "
                f"got {type(records).__name__}"
            )
        return records

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.examples)

    def _encode(self, text: str):
        """Tokenize *text* to a fixed ``max_length`` and return the encoding."""
        return self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one record.

        ``input_ids``/``attention_mask`` encode only the prompt; ``labels``
        encode only the itinerary, with padding positions set to
        ``IGNORE_INDEX``.
        """
        example = self.examples[idx]
        prompt = (
            f"Generate a detailed travel itinerary for {example['destination']} "
            f"for {example['duration']} days.\n"
            f"Preferences: {example['preferences']}\n"
            f"Budget: {example['budget']}"
        )

        # Encoder and decoder sides are tokenized separately. Feeding
        # prompt+target as the input and cloning it as the labels would train
        # an encoder-decoder model to copy its input instead of generating
        # the itinerary. No literal "</s>" is appended here: T5 tokenizers
        # add the end-of-sequence token themselves.
        source = self._encode(prompt)
        target = self._encode(example["itinerary"])

        # Mask padded target positions so the loss ignores them.
        labels = target["input_ids"][0].masked_fill(
            target["attention_mask"][0] == 0, self.IGNORE_INDEX
        )

        return {
            "input_ids": source["input_ids"][0],
            "attention_mask": source["attention_mask"][0],
            "labels": labels,
        }
|
|
|
def train_itinerary_model(
    model_name: str = "google/flan-t5-base",
    data_path: str = "data/itineraries.json",
    output_dir: str = "output",
    num_epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 2e-5,
):
    """Fine-tune a T5 model on itinerary data and save it to *output_dir*.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained`` for
            both the tokenizer and the model.
        data_path: JSON file handed to ``ItineraryDataset``.
        output_dir: Directory for checkpoints, the final model and the
            tokenizer files.
        num_epochs: Number of training epochs.
        batch_size: Per-device training batch size (gradients are additionally
            accumulated over 4 steps).
        learning_rate: Learning rate passed to ``TrainingArguments``.
    """
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(
        model_name,
        device_map="auto",
    )

    dataset = ItineraryDataset(data_path, tokenizer)

    # fp16 mixed precision needs an accelerator; requesting it unconditionally
    # makes TrainingArguments fail on CPU-only machines, so enable it only
    # when CUDA is actually present.
    # NOTE(review): T5 checkpoints are reported to overflow under fp16 on some
    # setups - consider bf16 on hardware that supports it; confirm on target GPUs.
    use_fp16 = torch.cuda.is_available()

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        warmup_steps=100,
        logging_steps=10,
        save_steps=100,
        fp16=use_fp16,
        report_to="tensorboard",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        # Passing the model lets the collator prepare decoder inputs from the
        # labels when the model provides that hook.
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    )

    trainer.train()

    # save_model() with no argument writes to training_args.output_dir; the
    # tokenizer is saved alongside so the directory is loadable on its own.
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)
|
|
|
if __name__ == "__main__":
    # Script entry point: run a training job with every default argument.
    train_itinerary_model()
|
|