# tablecell-htr/train_trocr.py
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
import argparse
from evaluate import load
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
import torchvision.transforms as transforms
from augments import RandAug, RandRotate
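# RandAug and RandRotate are custom augmentation transforms provided by the accompanying augments module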
parser = argparse.ArgumentParser('arguments for the code')
parser.add_argument('--root_path', type=str, default="",
help='Root path to data files.')
parser.add_argument('--tr_data_path', type=str, default="/path/to/train_data.csv",
help='Path to .csv file containing the training data.')
parser.add_argument('--val_data_path', type=str, default="/path/to/val_data.csv",
help='Path to .csv file containing the validation data.')
parser.add_argument('--output_path', type=str, default="./output/path/",
help='Path for saving training results.')
parser.add_argument('--model_path', type=str, default="/model/path/",
help='Path to trocr model')
parser.add_argument('--processor_path', type=str, default="/processor/path/",
help='Path to trocr processor')
parser.add_argument('--epochs', type=int, default=15,
help='Training epochs.')
parser.add_argument('--batch_size', type=int, default=16,
                    help='Training batch size.')
parser.add_argument('--device', type=str, default="cuda:0",
help='Device used for training.')
parser.add_argument('--augment', type=int, default=0,
                    help='Set to 1 to apply image augmentations during training, 0 to disable them.')
args = parser.parse_args()
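# Example invocation (illustrative only; paths and model/processor IDs are placeholders):
# python train_trocr.py --root_path /data/images/ \
#     --tr_data_path /data/train_data.csv --val_data_path /data/val_data.csv \
#     --model_path microsoft/trocr-base-handwritten --processor_path microsoft/trocr-base-handwritten \
#     --output_path ./output/ --epochs 15 --batch_size 16 --augment 1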
# Initialize processor and model
processor = TrOCRProcessor.from_pretrained(args.processor_path)
model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
model.to(args.device)
# Initialize metrics
cer_metric = load("cer")
wer_metric = load("wer")
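# CER (character error rate) and WER (word error rate) are edit-distance based metrics; lower values are better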
# Load train and validation data to dataframes
train_df = pd.read_csv(args.tr_data_path)
val_df = pd.read_csv(args.val_data_path)
# Reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)
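# The CSV files are expected to contain at least the two columns used below; an illustrative row:
# file_name,text
# cells/page_001_cell_03.jpg,"Helsinki 1932"
# file_name is resolved relative to --root_path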
# Torch dataset
class TextlineDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128, augment=False):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.augment = augment
        self.augmentator = RandAug()
        self.rotator = RandRotate()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")
        if self.augment:
            image = self.augmentator(image)
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length", truncation=True,
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding
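# Each dataset item is a dict with:
#   pixel_values - the processor's normalized image tensor (typically 3 x 384 x 384 for TrOCR processors)
#   labels       - token ids of the transcription, padded to max_target_length with pad tokens replaced by -100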
# Create train and validation datasets
train_dataset = TextlineDataset(root_dir=args.root_path,
df=train_df,
processor=processor,
augment=args.augment)
eval_dataset = TextlineDataset(root_dir=args.root_path,
df=val_df,
processor=processor,
augment=False)
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))
# Define model configuration
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 1
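# Note: these generation settings are used during evaluation (predict_with_generate=True) and at inference;
# with num_beams=1 decoding is greedy, so early_stopping and length_penalty only take effect if num_beams is increased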
# Set arguments for model training
# For all arguments see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
eval_strategy="epoch",
save_strategy="epoch",
logging_strategy="steps",
logging_steps=50,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
load_best_model_at_end=True,
metric_for_best_model='cer',
greater_is_better=False,
fp16=True,
num_train_epochs=args.epochs,
save_total_limit=1,
output_dir=args.output_path,
optim='adamw_torch'
)
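# Note: fp16=True requires a CUDA-capable GPU; set it to False when training on CPU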
# Function for computing CER and WER metrics for the prediction results
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # replace the -100 padding markers with the pad token id before decoding the references
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer, "wer": wer}
# Instantiate trainer
# For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
trainer = Seq2SeqTrainer(
model=model,
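    # the image processor is passed as `tokenizer` so that it is saved together with model checkpoints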
tokenizer=processor.image_processor,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=default_data_collator,
)
# Train the model
trainer.train()
#trainer.train(resume_from_checkpoint = True)
model.save_pretrained(args.output_path)
processor.save_pretrained(args.output_path + "/processor")
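# Reloading the fine-tuned model for inference (illustrative sketch; "cell.jpg" is a placeholder path):
# processor = TrOCRProcessor.from_pretrained(args.output_path + "/processor")
# model = VisionEncoderDecoderModel.from_pretrained(args.output_path).to(args.device)
# pixel_values = processor(Image.open("cell.jpg").convert("RGB"), return_tensors="pt").pixel_values
# generated_ids = model.generate(pixel_values.to(args.device))
# text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]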