PrecollatorForGeneAndCellClassification has no attribute save_pretrained

#528
by zzbb2266 - opened

Hi there, thanks for the great pipeline and for maintaining it! I am currently trying to reproduce cell_classification.ipynb from the examples folder, but I got an error from the PrecollatorForGeneAndCellClassification module right after the first epoch finished:

Epoch	Training Loss	Validation Loss	Accuracy	Macro F1
0	0.133900	0.411489	0.883736	0.684715
  0%|          | 0/1 [28:44<?, ?it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 6
      1 train_valid_id_split_dict = {"attr_key": "individual",
      2                             "train": train_ids,
      3                             "eval": eval_ids}
      5 # Example 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors
----> 6 all_metrics = cc.validate(model_directory="/home/data1/geneformer/Geneformer/gf-6L-30M-i2048/",
      7                           prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
      8                           id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
      9                           output_directory=output_dir,
     10                           output_prefix=output_prefix,
     11                           split_id_dict=train_valid_id_split_dict)
     12                           # to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/geneformer/classifier.py:791, in Classifier.validate(self, model_directory, prepared_input_data_file, id_class_dict_file, output_directory, output_prefix, split_id_dict, attr_to_split, attr_to_balance, gene_balance, max_trials, pval_threshold, save_eval_output, predict_eval, predict_trainer, n_hyperopt_trials, save_gene_split_datasets, debug_gene_split_datasets)
    789     train_data = data.select(train_indices)
    790 if n_hyperopt_trials == 0:
--> 791     trainer = self.train_classifier(
    792         model_directory,
    793         num_classes,
    794         train_data,
    795         eval_data,
    796         ksplit_output_dir,
    797         predict_trainer,
    798     )
    799 else:
    800     trainer = self.hyperopt_classifier(
    801         model_directory,
    802         num_classes,
   (...)    806         n_trials=n_hyperopt_trials,
    807     )

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/geneformer/classifier.py:1269, in Classifier.train_classifier(self, model_directory, num_classes, train_data, eval_data, output_directory, predict)
   1259 trainer = Trainer(
   1260     model=model,
   1261     args=training_args_init,
   (...)   1265     compute_metrics=cu.compute_metrics,
   1266 )
   1268 # train the classifier
-> 1269 trainer.train()
   1270 trainer.save_model(output_directory)
   1271 if predict is True:
   1272     # make eval predictions and save predictions and metrics

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:2245, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2243         hf_hub_utils.enable_progress_bars()
   2244 else:
-> 2245     return inner_training_loop(
   2246         args=args,
   2247         resume_from_checkpoint=resume_from_checkpoint,
   2248         trial=trial,
   2249         ignore_keys_for_eval=ignore_keys_for_eval,
   2250     )

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:2661, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2658     self.control.should_training_stop = True
   2660 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 2661 self._maybe_log_save_evaluate(
   2662     tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
   2663 )
   2665 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   2666     if is_torch_xla_available():
   2667         # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:3103, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate)
   3100         self.control.should_save = is_new_best_metric
   3102 if self.control.should_save:
-> 3103     self._save_checkpoint(model, trial)
   3104     self.control = self.callback_handler.on_save(self.args, self.state, self.control)

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:3200, in Trainer._save_checkpoint(self, model, trial)
   3198 run_dir = self._get_output_dir(trial=trial)
   3199 output_dir = os.path.join(run_dir, checkpoint_folder)
-> 3200 self.save_model(output_dir, _internal_call=True)
   3202 if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
   3203     best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:3902, in Trainer.save_model(self, output_dir, _internal_call)
   3899         self.model_wrapped.save_checkpoint(output_dir)
   3901 elif self.args.should_save:
-> 3902     self._save(output_dir)
   3904 # Push to the Hub when `save_model` is called by the user.
   3905 if self.args.push_to_hub and not _internal_call:

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:4018, in Trainer._save(self, output_dir, state_dict)
   4012 elif (
   4013     self.data_collator is not None
   4014     and hasattr(self.data_collator, "tokenizer")
   4015     and self.data_collator.tokenizer is not None
   4016 ):
   4017     logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
-> 4018     self.data_collator.tokenizer.save_pretrained(output_dir)
   4020 # Good practice: save your training arguments together with the trained model
   4021 torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/tokenization_utils_base.py:1108, in SpecialTokensMixin.__getattr__(self, key)
   1105         return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None
   1107 if key not in self.__dict__:
-> 1108     raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
   1109 else:
   1110     return super().__getattr__(key)

AttributeError: PrecollatorForGeneAndCellClassification has no attribute save_pretrained

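Should I manually add a `save_pretrained` method in trainer.py or classifier.py? It seems to me that the issue may be caused by a different version of the transformers package: the traceback shows `Trainer._save` falling back to calling `data_collator.tokenizer.save_pretrained(...)`, which the Geneformer precollator does not implement. My environment is: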

Package                   Version
------------------------- -----------
absl-py                   2.2.2
accelerate                1.6.0
accumulation_tree         0.6.4
aiohappyeyeballs          2.6.1
aiohttp                   3.11.18
aiosignal                 1.3.2
alembic                   1.15.2
anndata                   0.11.4
annotated-types           0.7.0
array_api_compat          1.11.2
asttokens                 3.0.0
attrs                     25.3.0
certifi                   2025.4.26
charset-normalizer        3.4.1
click                     8.1.8
colorlog                  6.9.0
comm                      0.2.2
contourpy                 1.3.2
cycler                    0.12.1
datasets                  3.5.1
debugpy                   1.8.14
decorator                 5.2.1
dill                      0.3.8
docker-pycreds            0.4.0
exceptiongroup            1.2.2
executing                 2.2.0
filelock                  3.18.0
fonttools                 4.57.0
frozenlist                1.6.0
fsspec                    2025.3.0
geneformer                0.1.0
gitdb                     4.0.12
GitPython                 3.1.44
greenlet                  3.2.1
grpcio                    1.71.0
h5py                      3.13.0
huggingface-hub           0.30.2
idna                      3.10
importlib_metadata        8.6.1
ipykernel                 6.29.5
ipython                   9.2.0
ipython_pygments_lexers   1.1.1
jedi                      0.19.2
Jinja2                    3.1.6
joblib                    1.4.2
jsonschema                4.23.0
jsonschema-specifications 2025.4.1
jupyter_client            8.6.3
jupyter_core              5.7.2
kiwisolver                1.4.8
legacy-api-wrap           1.4.1
llvmlite                  0.44.0
loompy                    3.0.8
Mako                      1.3.10
Markdown                  3.8
MarkupSafe                3.0.2
matplotlib                3.10.1
matplotlib-inline         0.1.7
mpmath                    1.3.0
msgpack                   1.1.0
multidict                 6.4.3
multiprocess              0.70.16
natsort                   8.4.0
nest_asyncio              1.6.0
networkx                  3.4.2
numba                     0.61.2
numpy                     2.2.5
numpy-groupies            0.11.2
nvidia-cublas-cu12        12.6.4.1
nvidia-cuda-cupti-cu12    12.6.80
nvidia-cuda-nvrtc-cu12    12.6.77
nvidia-cuda-runtime-cu12  12.6.77
nvidia-cudnn-cu12         9.5.1.17
nvidia-cufft-cu12         11.3.0.4
nvidia-cufile-cu12        1.11.1.6
nvidia-curand-cu12        10.3.7.77
nvidia-cusolver-cu12      11.7.1.2
nvidia-cusparse-cu12      12.5.4.2
nvidia-cusparselt-cu12    0.6.3
nvidia-nccl-cu12          2.26.2
nvidia-nvjitlink-cu12     12.6.85
nvidia-nvtx-cu12          12.6.77
optuna                    4.3.0
optuna-integration        4.3.0
packaging                 25.0
pandas                    2.2.3
parso                     0.8.4
patsy                     1.0.1
peft                      0.15.2
pexpect                   4.9.0
pickleshare               0.7.5
pillow                    11.2.1
pip                       25.1
platformdirs              4.3.7
prompt_toolkit            3.0.51
propcache                 0.3.1
protobuf                  6.30.2
psutil                    7.0.0
ptyprocess                0.7.0
pure_eval                 0.2.3
pyarrow                   20.0.0
pydantic                  2.11.4
pydantic_core             2.33.2
Pygments                  2.19.1
pynndescent               0.5.13
pyparsing                 3.2.3
python-dateutil           2.9.0.post0
pytz                      2025.2
pyudorandom               1.0.0
PyYAML                    6.0.2
pyzmq                     26.4.0
ray                       2.45.0
referencing               0.36.2
regex                     2024.11.6
requests                  2.32.3
rpds-py                   0.24.0
safetensors               0.5.3
scanpy                    1.11.1
scikit-learn              1.5.2
scikit-misc               0.5.1
scipy                     1.15.2
seaborn                   0.13.2
sentry-sdk                2.27.0
session-info2             0.1.2
setproctitle              1.3.6
setuptools                80.0.1
six                       1.17.0
smmap                     5.0.2
SQLAlchemy                2.0.40
stack_data                0.6.3
statsmodels               0.14.4
sympy                     1.14.0
tdigest                   0.5.2.2
tensorboard               2.19.0
tensorboard-data-server   0.7.2
threadpoolctl             3.6.0
tokenizers                0.21.1
torch                     2.7.0
tornado                   6.4.2
tqdm                      4.67.1
traitlets                 5.14.3
transformers              4.51.3
triton                    3.3.0
typing_extensions         4.13.2
typing-inspection         0.4.0
tzdata                    2025.2
umap-learn                0.5.7
urllib3                   2.4.0
wandb                     0.19.10
wcwidth                   0.2.13
Werkzeug                  3.1.3
xxhash                    3.5.0
yarl                      1.20.0
zipp                      3.21.0

How can I solve this issue? Or should I just downgrade the transformers package? Thank you!

Update: This issue has been solved by downgrading the transformers package. The downgrade also resolved a separate issue with the `evaluation_strategy` argument. FYI.

Sign up or log in to comment