From 4a070bc5ec2cd0a2ccbf28c915efa57d83fb54a7 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Wed, 10 Apr 2024 17:49:12 +0200
Subject: [PATCH 1/8] fix hangup

---
 src/graphnet/data/datamodule.py               | 102 ++++++++++++++++--
 .../data/dataset/parquet/parquet_dataset.py   |   3 +
 2 files changed, 95 insertions(+), 10 deletions(-)

diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py
index f012c4abe..afda6471c 100644
--- a/src/graphnet/data/datamodule.py
+++ b/src/graphnet/data/datamodule.py
@@ -24,9 +24,12 @@ def __init__(
         dataset_args: Dict[str, Any],
         selection: Optional[Union[List[int], List[List[int]]]] = None,
         test_selection: Optional[Union[List[int], List[List[int]]]] = None,
-        train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
-        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
-        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        train_dataloader_kwargs: Dict[str, Any] = {
+            "batch_size": 2,
+            "num_workers": 1,
+        },
+        validation_dataloader_kwargs: Dict[str, Any] = None,
+        test_dataloader_kwargs: Dict[str, Any] = None,
         train_val_split: Optional[List[float]] = [0.9, 0.10],
         split_seed: int = 42,
     ) -> None:
@@ -40,13 +43,14 @@ def __init__(
             selection: (Optional) a list of event id's used for training
                 and validation, Default None.
             test_selection: (Optional) a list of event id's used for testing,
-                Default None.
+                Defaults to None.
             train_dataloader_kwargs: Arguments for the training DataLoader,
-                Default None.
+                Defaults to {"batch_size": 2, "num_workers": 1}.
             validation_dataloader_kwargs: Arguments for the validation
-                DataLoader, Default None.
+                DataLoader. Defaults to
+                `train_dataloader_kwargs`.
             test_dataloader_kwargs: Arguments for the test DataLoader,
-                Default None.
+                Defaults to `train_dataloader_kwargs`.
             train_val_split (Optional): Split ratio for training and
                 validation sets. Default is [0.9, 0.10].
             split_seed: seed used for shuffling and splitting selections into
@@ -61,9 +65,11 @@ def __init__(
         self._train_val_split = train_val_split or [0.0]
         self._rng = split_seed
 
-        self._train_dataloader_kwargs = train_dataloader_kwargs or {}
-        self._validation_dataloader_kwargs = validation_dataloader_kwargs or {}
-        self._test_dataloader_kwargs = test_dataloader_kwargs or {}
+        self._set_dataloader_kwargs(
+            train_dataloader_kwargs,
+            validation_dataloader_kwargs,
+            test_dataloader_kwargs,
+        )
 
         # If multiple dataset paths are given, we should use EnsembleDataset
         self._use_ensemble_dataset = isinstance(
@@ -72,6 +78,82 @@ def __init__(
 
         self.setup("fit")
 
+    def _set_dataloader_kwargs(
+        self,
+        train_dl_args: Dict[str, Any],
+        val_dl_args: Union[Dict[str, Any], None],
+        test_dl_args: Union[Dict[str, Any], None],
+    ) -> None:
+        """Copy train dataloader args to other dataloaders if not given.
+
+        Also checks that ParquetDataset dataloaders have multiprocessing
+        context set to "spawn" as this is strictly required.
+
+        See: https://docs.pola.rs/user-guide/misc/multiprocessing/
+        """
+        if val_dl_args is None:
+            self.info(
+                "No `val_dataloader_kwargs` given. This arg has "
+                "been set to `train_dataloader_kwargs` automatically."
+            )
+            val_dl_args = deepcopy(train_dl_args)
+        if (test_dl_args is None) & (self._test_selection is not None):
+            test_dl_args = deepcopy(train_dl_args)
+            self.info(
+                "No `test_dataloader_kwargs` given. This arg has "
+                "been set to `train_dataloader_kwargs` automatically."
+            )
+
+        if self._dataset == ParquetDataset:
+            self._train_dataloader_kwargs = self._add_context(
+                train_dl_args, "training"
+            )
+            self._validation_dataloader_kwargs = self._add_context(
+                val_dl_args, "validation"
+            )
+            if self._test_selection is not None:
+                assert test_dl_args is not None
+                self._test_dataloader_kwargs = self._add_context(
+                    test_dl_args, "test"
+                )
+            else:
+                self._test_dataloader_kwargs = {}
+
+    def _add_context(
+        self, dataloader_args: Dict[str, Any], dataloader_type: str
+    ) -> Dict[str, Any]:
+        """Handle assignment of `multiprocessing_context` arg to loaders.
+
+        Datasets relying on threaded libraries often require the
+        multiprocessing context to be set to "spawn". This method will check
+        the arguments for this entry and throw and error if the field is
+        already assigned to a wrong value. If the value is not specified, it is
+        added automatically with a log entry.
+        """
+        arg = "multiprocessing_context"
+        if arg in dataloader_args:
+            if dataloader_args[arg] != "spawn":
+                # Wrongly assigned by user
+                self.error(
+                    "DataLoaders using `ParquetDataset` must have "
+                    "multiprocessing_context = 'spawn'. "
+                    f" Found '{dataloader_args[arg]}' in ",
+                    f"{dataloader_type} dataloader.",
+                )
+                raise ValueError("multiprocessing_context must be 'spawn'")
+            else:
+                # Correctly assigned by user
+                return dataloader_args
+        else:
+            # Forgotten assignment by user
+            dataloader_args[arg] = "spawn"
+            self.warning_once(
+                f"{self.__class__.__name__} has automatically "
+                "set multiprocessing_context = 'spawn' in "
+                f"{dataloader_type} dataloader."
+            )
+            return dataloader_args
+
     def prepare_data(self) -> None:
         """Prepare the dataset for training."""
         # Download method for curated datasets. Method for download is
diff --git a/src/graphnet/data/dataset/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py
index 0feec8da9..0186a7a72 100644
--- a/src/graphnet/data/dataset/parquet/parquet_dataset.py
+++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -57,6 +57,9 @@
     ):
         """Construct Dataset.
 
+        NOTE: DataLoaders using this Dataset should have
+        "multiprocessing_context = 'spawn'" set to avoid thread locking.
+
         Args:
             path: Path to the file(s) from which this `Dataset` should read.
             pulsemaps: Name(s) of the pulse map series that should be used to

From ab2cf69cf3a113a4d272067f5c0980da0a720b51 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Thu, 11 Apr 2024 08:29:45 +0200
Subject: [PATCH 2/8] update docstring in `_add_context`

---
 src/graphnet/data/datamodule.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py
index afda6471c..8eed6fa14 100644
--- a/src/graphnet/data/datamodule.py
+++ b/src/graphnet/data/datamodule.py
@@ -126,9 +126,9 @@ def _add_context(
 
         Datasets relying on threaded libraries often require the
         multiprocessing context to be set to "spawn". This method will check
-        the arguments for this entry and throw and error if the field is
-        already assigned to a wrong value. If the value is not specified, it is
-        added automatically with a log entry.
+        the arguments for this entry and throw an error if the field is already
+        assigned to a wrong value. If the value is not specified, it is added
+        automatically with a log entry.
""" arg = "multiprocessing_context" if arg in dataloader_args: From 69b5c643ce049f41fbf51403baa7f0a2ce60cb66 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Thu, 11 Apr 2024 09:10:34 +0200 Subject: [PATCH 3/8] fix arg assignments for SQLite --- src/graphnet/data/datamodule.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 8eed6fa14..323e1ce1c 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -118,6 +118,10 @@ def _set_dataloader_kwargs( ) else: self._test_dataloader_kwargs = {} + else: + self._train_dataloader_kwargs = train_dl_args + self._validation_dataloader_kwargs = val_dl_args + self._test_dataloader_kwargs = test_dl_args or {} def _add_context( self, dataloader_args: Dict[str, Any], dataloader_type: str From 7ef14979e48f6670fbec8d460a51d3db2a9ee0ce Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Thu, 11 Apr 2024 09:12:00 +0200 Subject: [PATCH 4/8] update arg assignment for dataloaders --- src/graphnet/data/datamodule.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 323e1ce1c..0c252a462 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -105,19 +105,12 @@ def _set_dataloader_kwargs( ) if self._dataset == ParquetDataset: - self._train_dataloader_kwargs = self._add_context( - train_dl_args, "training" - ) - self._validation_dataloader_kwargs = self._add_context( - val_dl_args, "validation" - ) + train_dl_args = self._add_context(train_dl_args, "training") + val_dl_args = self._add_context(val_dl_args, "validation") if self._test_selection is not None: assert test_dl_args is not None - self._test_dataloader_kwargs = self._add_context( - test_dl_args, "test" - ) - else: - self._test_dataloader_kwargs = {} + test_dl_args = self._add_context(test_dl_args, "test") + else: self._train_dataloader_kwargs = train_dl_args self._validation_dataloader_kwargs = val_dl_args From 9e0feec2348b66bab63c03d50c746a12cbdc5202 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Thu, 11 Apr 2024 09:31:02 +0200 Subject: [PATCH 5/8] update unit test --- tests/data/test_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 1257a3f75..d4f305d7a 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -151,7 +151,7 @@ def test_single_dataset_without_selections( # validation loader should have shuffle = False by default assert isinstance(val_dataloader.sampler, SequentialSampler) # Should have identical batch_size - assert val_dataloader.batch_size != train_dataloader.batch_size + assert val_dataloader.batch_size == train_dataloader.batch_size # Training dataloader should contain more batches assert len(train_dataloader) > len(val_dataloader) From ba60583893426180a52ebc7950585789c02947fd Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Thu, 11 Apr 2024 09:56:03 +0200 Subject: [PATCH 6/8] polish arg logic for dataloaders --- src/graphnet/data/datamodule.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 0c252a462..704ee28b0 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -111,10 +111,9 @@ def _set_dataloader_kwargs( assert test_dl_args is not None test_dl_args = self._add_context(test_dl_args, "test") - 
else: - self._train_dataloader_kwargs = train_dl_args - self._validation_dataloader_kwargs = val_dl_args - self._test_dataloader_kwargs = test_dl_args or {} + self._train_dataloader_kwargs = train_dl_args + self._validation_dataloader_kwargs = val_dl_args + self._test_dataloader_kwargs = test_dl_args or {} def _add_context( self, dataloader_args: Dict[str, Any], dataloader_type: str From db090ae60da54c9681388c15ff996037e476be0a Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 12 Apr 2024 10:01:55 +0200 Subject: [PATCH 7/8] comments --- src/graphnet/data/dataloader.py | 2 +- src/graphnet/data/datamodule.py | 54 ++++++++++++++++++--------------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/graphnet/data/dataloader.py b/src/graphnet/data/dataloader.py index 1ded6fa37..4c02c595c 100644 --- a/src/graphnet/data/dataloader.py +++ b/src/graphnet/data/dataloader.py @@ -31,7 +31,7 @@ def __init__( dataset: Dataset, batch_size: int = 1, shuffle: bool = False, - num_workers: int = 10, + num_workers: int = 1, persistent_workers: bool = True, collate_fn: Callable = collate_fn, prefetch_factor: int = 2, diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 704ee28b0..e4562b0dc 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -94,14 +94,16 @@ def _set_dataloader_kwargs( if val_dl_args is None: self.info( "No `val_dataloader_kwargs` given. This arg has " - "been set to `train_dataloader_kwargs` automatically." + "been set to `train_dataloader_kwargs` with `shuffle` = False." ) val_dl_args = deepcopy(train_dl_args) + val_dl_args["shuffle"] = False # Important for inference if (test_dl_args is None) & (self._test_selection is not None): test_dl_args = deepcopy(train_dl_args) + test_dl_args["shuffle"] = False # Important for inference self.info( "No `test_dataloader_kwargs` given. This arg has " - "been set to `train_dataloader_kwargs` automatically." + "been set to `train_dataloader_kwargs` with `shuffle` = False." ) if self._dataset == ParquetDataset: @@ -121,33 +123,37 @@ def _add_context( """Handle assignment of `multiprocessing_context` arg to loaders. Datasets relying on threaded libraries often require the - multiprocessing context to be set to "spawn". This method will check - the arguments for this entry and throw an error if the field is already - assigned to a wrong value. If the value is not specified, it is added - automatically with a log entry. + multiprocessing context to be set to "spawn" if "num_workers" > 0. This + method will check the arguments for this entry and throw an error if + the field is already assigned to a wrong value. If the value is not + specified, it is added automatically with a log entry. """ arg = "multiprocessing_context" - if arg in dataloader_args: - if dataloader_args[arg] != "spawn": - # Wrongly assigned by user - self.error( - "DataLoaders using `ParquetDataset` must have " - "multiprocessing_context = 'spawn'. " - f" Found '{dataloader_args[arg]}' in ", - f"{dataloader_type} dataloader.", - ) - raise ValueError("multiprocessing_context must be 'spawn'") + if dataloader_args["num_workers"] != 0: + # If using multiprocessing + if arg in dataloader_args: + if dataloader_args[arg] != "spawn": + # Wrongly assigned by user + self.error( + "DataLoaders using `ParquetDataset` must have " + "multiprocessing_context = 'spawn'. 
" + f" Found '{dataloader_args[arg]}' in ", + f"{dataloader_type} dataloader.", + ) + raise ValueError("multiprocessing_context must be 'spawn'") + else: + # Correctly assigned by user + return dataloader_args else: - # Correctly assigned by user + # Forgotten assignment by user + dataloader_args[arg] = "spawn" + self.warning_once( + f"{self.__class__.__name__} has automatically " + "set multiprocessing_context = 'spawn' in " + f"{dataloader_type} dataloader." + ) return dataloader_args else: - # Forgotten assignment by user - dataloader_args[arg] = "spawn" - self.warning_once( - f"{self.__class__.__name__} has automatically " - "set multiprocessing_context = 'spawn' in " - f"{dataloader_type} dataloader." - ) return dataloader_args def prepare_data(self) -> None: From fb80af8673a7de6ba786499321982632fbaaaa9a Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 12 Apr 2024 10:05:07 +0200 Subject: [PATCH 8/8] update unit test to check for shuffle --- tests/data/test_datamodule.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index d4f305d7a..a3e0a0921 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -90,7 +90,7 @@ def dataset_setup(dataset_ref: pytest.FixtureRequest) -> tuple: "graph_definition": graph_definition, } - dataloader_kwargs = {"batch_size": 2, "num_workers": 1} + dataloader_kwargs = {"batch_size": 2, "num_workers": 1, "shuffle": True} return dataset_ref, dataset_kwargs, dataloader_kwargs @@ -210,6 +210,11 @@ def test_single_dataset_with_selections( # Training dataloader should have more batches assert len(train_dataloader) > len(val_dataloader) + # validation loader should have shuffle = False by default + assert isinstance(val_dataloader.sampler, SequentialSampler) + # test loader should have shuffle = False by default + assert isinstance(test_dataloader.sampler, SequentialSampler) + @pytest.mark.parametrize( "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True