Fix training crash on multi-node setups when dataloader_num_workers>0 #1721
base: main
Conversation
@Wei-Lin-Intel Also run `make style` and check for any errors, if you have not done so already.
See the possible code change below; adapt and test it.
```diff
@@ -333,6 +335,157 @@ def _move_model_to_device(self, model, device):
         if self.args.use_habana and hasattr(model, "tie_weights"):
             model.tie_weights()
 
+    def get_train_dataloader(self) -> DataLoader:
```
Suggest refactoring this code to avoid duplication. Here is a possible refactoring (not tested):
```python
# Assumes the usual trainer-module imports are in scope (multiprocessing,
# torch, datasets, DataLoader, Dataset, Optional, Union,
# is_datasets_available, seed_worker).

def get_multiprocessing_context(self) -> Optional["multiprocessing.context.BaseContext"]:
    # DataLoader only accepts a multiprocessing context when num_workers > 0,
    # so return None otherwise. "forkserver" avoids the undefined behavior of
    # the default "fork" start method described in this PR.
    if self.args.dataloader_num_workers > 0:
        return multiprocessing.get_context("forkserver")
    return None

def preprocess_dataset(self, dataset, description: str, data_collator):
    # Either strip unused columns from a datasets.Dataset directly, or wrap
    # the collator so it strips them at collation time.
    if is_datasets_available() and isinstance(dataset, datasets.Dataset):
        return self._remove_unused_columns(dataset, description=description), data_collator
    else:
        return dataset, self._get_collator_with_removed_columns(data_collator, description=description)

def get_dataloader_params(self, batch_size, dataset, data_collator) -> dict:
    dataloader_params = {
        "batch_size": batch_size,
        "collate_fn": data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
        "persistent_workers": self.args.dataloader_persistent_workers,
        "multiprocessing_context": self.get_multiprocessing_context(),
    }
    if not isinstance(dataset, torch.utils.data.IterableDataset):
        dataloader_params["sampler"] = (
            self._get_train_sampler() if batch_size == self._train_batch_size else self._get_eval_sampler(dataset)
        )
        dataloader_params["drop_last"] = self.args.dataloader_drop_last
        dataloader_params["worker_init_fn"] = seed_worker
        dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
    return dataloader_params

def get_dataloader(self, dataset, batch_size, description: str) -> DataLoader:
    if dataset is None:
        raise ValueError(f"Trainer: {description} requires a dataset.")
    dataset, data_collator = self.preprocess_dataset(dataset, description=description, data_collator=self.data_collator)
    dataloader_params = self.get_dataloader_params(batch_size, dataset, data_collator)
    return DataLoader(dataset, **dataloader_params)

def get_train_dataloader(self) -> DataLoader:
    return self.accelerator.prepare(self.get_dataloader(self.train_dataset, self._train_batch_size, "train_dataset"))

def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
    # Capture the cache key before resolving a named dataset; resolving first
    # would collapse every key to "eval".
    dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
    if isinstance(eval_dataset, str):
        eval_dataset = self.eval_dataset[eval_dataset]
    elif eval_dataset is None:
        eval_dataset = self.eval_dataset
    if dataloader_key in self._eval_dataloaders and self.args.dataloader_persistent_workers:
        return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
    eval_dataloader = self.get_dataloader(eval_dataset, self.args.eval_batch_size, "evaluation")
    if self.args.dataloader_persistent_workers:
        self._eval_dataloaders[dataloader_key] = eval_dataloader
    return self.accelerator.prepare(eval_dataloader)

def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
    return self.accelerator.prepare(self.get_dataloader(test_dataset, self.args.eval_batch_size, "test"))
```
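If adopted, this refactoring would make the three public getters thin wrappers around a single `get_dataloader` path, so the `multiprocessing_context` fix (and any future dataloader option) lives in one place, and persistent eval dataloaders are cached per dataset key rather than rebuilt on every evaluation.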
This feature requires 2 nodes with a NIC connection to test... I am not sure our pytest setup can handle such a case.
Sure, I will provide a test case for 1 node.
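In the meantime, a minimal single-node sketch of such a test might look like the following (the dataset, sizes, and test name are illustrative assumptions, not the PR's actual test); it only checks that `forkserver` workers start and deliver batches:

```python
import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset


def test_forkserver_workers_deliver_batches():
    # 16 samples of width 4, split into 4 batches by 2 forkserver workers.
    dataset = TensorDataset(torch.zeros(16, 4))
    loader = DataLoader(
        dataset,
        batch_size=4,
        num_workers=2,
        multiprocessing_context=multiprocessing.get_context("forkserver"),
    )
    batches = list(loader)
    assert len(batches) == 4
```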
@Wei-Lin-Intel Please provide tests and results for 1 node/8 HPUs, and I also suggest running at least some of the slow tests where this option is used. See the tests/baselines directory for where this option is used, e.g., .../optimum-habana/tests/test_examples.py
What does this PR do?
Fixes # (issue)
Relevant Issue: SW-207456
Background: In a Gaudi2 Host NIC environment, multi-node training either gets stuck in the "pt_data_worker" stage (Synapse 1.19) or throws errors like `RuntimeError: DataLoader worker (pid(s) 12844) exited unexpectedly` (Synapse 1.17) when `dataloader_num_workers` is set larger than 0.

According to the Habana documentation on torch-multiprocessing-for-dataloaders, the default start method for dataloader workers is `fork`, which may result in undefined behavior. It is therefore better to set `multiprocessing_context` to `forkserver` or `spawn` in the initialization stage of the Gaudi Trainer when `dataloader_num_workers` > 0.

On Unix systems, `forkserver` is faster than `spawn` at starting a new process, and only the necessary resources are inherited, so `forkserver` is preferred. In this PR, this change has been applied to `get_train_dataloader`, `get_eval_dataloader`, and `get_test_dataloader`, respectively.
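To make the fix concrete, here is a minimal, self-contained sketch (not the PR's code; the dataset and worker count are illustrative) of passing a `forkserver` context to a PyTorch `DataLoader`:

```python
import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset

num_workers = 2  # illustrative value; the context only matters when > 0

# DataLoader rejects a non-None multiprocessing context when num_workers == 0,
# so fall back to None in that case. "forkserver" starts workers from a clean
# helper process instead of fork()-ing the already-initialized parent.
context = multiprocessing.get_context("forkserver") if num_workers > 0 else None

dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=num_workers,
    multiprocessing_context=context,
)

if __name__ == "__main__":  # guard matters: forkserver (like spawn) re-imports the main module
    for (batch,) in loader:
        pass  # a real training step would consume `batch` here
```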
Before submitting