diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 44d58ebb45..e5946d31ab 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -22,6 +22,7 @@
 import inspect
 import json
 import math
+import multiprocessing
 import os
 import random
 import shutil
@@ -87,6 +88,7 @@
     find_executable_batch_size,
     get_last_checkpoint,
     has_length,
+    seed_worker,
 )
 from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
 from transformers.utils import (
@@ -333,6 +335,157 @@ def _move_model_to_device(self, model, device):
         if self.args.use_habana and hasattr(model, "tie_weights"):
             model.tie_weights()
 
+    def get_train_dataloader(self) -> DataLoader:
+        """
+        Returns the training [`~torch.utils.data.DataLoader`].
+
+        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+        training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+        """
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+
+        train_dataset = self.train_dataset
+        data_collator = self.data_collator
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            train_dataset = self._remove_unused_columns(train_dataset, description="training")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
+
+        if self.args.dataloader_num_workers > 0:
+            multiprocessing_context = multiprocessing.get_context("forkserver")
+        else:
+            multiprocessing_context = multiprocessing.get_context("fork")
+
+        dataloader_params = {
+            "batch_size": self._train_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+            "multiprocessing_context": multiprocessing_context,
+        }
+
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_train_sampler()
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = seed_worker
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+            dataloader_params["multiprocessing_context"] = multiprocessing_context
+
+        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
+        """
+        Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
+                If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
+        """
+        if eval_dataset is None and self.eval_dataset is None:
+            raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+        # If we have persistent workers, don't do a fork bomb especially as eval datasets
+        # don't change during training
+        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
+        if (
+            hasattr(self, "_eval_dataloaders")
+            and dataloader_key in self._eval_dataloaders
+            and self.args.dataloader_persistent_workers
+        ):
+            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
+
+        eval_dataset = (
+            self.eval_dataset[eval_dataset]
+            if isinstance(eval_dataset, str)
+            else eval_dataset
+            if eval_dataset is not None
+            else self.eval_dataset
+        )
+        data_collator = self.data_collator
+
+        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
+
+        if self.args.dataloader_num_workers > 0:
+            multiprocessing_context = multiprocessing.get_context("forkserver")
+        else:
+            multiprocessing_context = multiprocessing.get_context("fork")
+
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+            "multiprocessing_context": multiprocessing_context,
+        }
+
+        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+            dataloader_params["multiprocessing_context"] = multiprocessing_context
+
+        # accelerator.free_memory() will destroy the references, so
+        # we need to store the non-prepared version
+        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
+        if self.args.dataloader_persistent_workers:
+            if hasattr(self, "_eval_dataloaders"):
+                self._eval_dataloaders[dataloader_key] = eval_dataloader
+            else:
+                self._eval_dataloaders = {dataloader_key: eval_dataloader}
+
+        return self.accelerator.prepare(eval_dataloader)
+
+    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
+        """
+        Returns the test [`~torch.utils.data.DataLoader`].
+
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            test_dataset (`torch.utils.data.Dataset`, *optional*):
+                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
+                `model.forward()` method are automatically removed. It must implement `__len__`.
+        """
+        data_collator = self.data_collator
+
+        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
+            test_dataset = self._remove_unused_columns(test_dataset, description="test")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
+
+        if self.args.dataloader_num_workers > 0:
+            multiprocessing_context = multiprocessing.get_context("forkserver")
+        else:
+            multiprocessing_context = multiprocessing.get_context("fork")
+
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+            "multiprocessing_context": multiprocessing_context,
+        }
+
+        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+            dataloader_params["multiprocessing_context"] = multiprocessing_context
+
+        # We use the same batch_size as for eval.
+        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.train_dataset is None or not has_length(self.train_dataset):
             return None
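Below is a minimal, self-contained sketch of the worker start-method selection this patch applies to the train, eval, and test dataloaders. It is not part of the diff: the build_loader helper, the toy TensorDataset, and the batch size are made up for illustration. It also deviates in one spot: the patch builds a "fork" context even when dataloader_num_workers is 0, whereas the sketch passes None in that case, since torch.utils.data.DataLoader only accepts a multiprocessing_context together with num_workers > 0.

# Illustrative sketch only -- not part of the patch. It mirrors the start-method
# selection added to get_train_dataloader/get_eval_dataloader/get_test_dataloader:
# a "forkserver" context whenever worker processes are used.
import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset


def build_loader(dataset, num_workers: int) -> DataLoader:
    # Hypothetical helper (not in the patch). DataLoader only accepts a
    # multiprocessing_context together with num_workers > 0, so fall back to
    # None (the library default) for in-process loading.
    multiprocessing_context = multiprocessing.get_context("forkserver") if num_workers > 0 else None
    return DataLoader(
        dataset,
        batch_size=4,  # arbitrary value for the demo
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        multiprocessing_context=multiprocessing_context,
    )


if __name__ == "__main__":
    # 32 dummy samples of shape (1,); each printed batch has shape [4, 1].
    toy_dataset = TensorDataset(torch.arange(32, dtype=torch.float32).unsqueeze(1))
    for (batch,) in build_loader(toy_dataset, num_workers=2):
        print(batch.shape)

The design intent, as far as the patch itself shows, is presumably to avoid fork-ing the training process once the device runtime and its threads are live: forkserver starts each worker from a clean server process, at the cost of slightly slower worker startup.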