Commit

Merge pull request #694 from RasmusOrsoe/parquet_dataset_hangup
Update DataModule
RasmusOrsoe authored Apr 15, 2024
2 parents 7b83990 + fb80af8 commit 9fc3400
Showing 4 changed files with 105 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/graphnet/data/dataloader.py
@@ -31,7 +31,7 @@ def __init__(
         dataset: Dataset,
         batch_size: int = 1,
         shuffle: bool = False,
-        num_workers: int = 10,
+        num_workers: int = 1,
         persistent_workers: bool = True,
         collate_fn: Callable = collate_fn,
         prefetch_factor: int = 2,
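
The default `num_workers` drops from 10 to 1 here, so pipelines that relied on the old implicit parallelism must now request workers explicitly. A minimal sketch of doing so, assuming this class is exported as `graphnet.data.dataloader.DataLoader` and that `my_dataset` is an already-constructed graphnet `Dataset`:

    from graphnet.data.dataloader import DataLoader  # class name assumed from the file above

    # Ask for parallel loading explicitly now that the default is num_workers=1.
    loader = DataLoader(
        my_dataset,               # an existing graphnet Dataset (assumed)
        batch_size=16,
        shuffle=True,
        num_workers=4,            # the old implicit default was 10
        persistent_workers=True,
    )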
104 changes: 94 additions & 10 deletions src/graphnet/data/datamodule.py
@@ -26,9 +26,12 @@ def __init__(
         dataset_args: Dict[str, Any],
         selection: Optional[Union[List[int], List[List[int]]]] = None,
         test_selection: Optional[Union[List[int], List[List[int]]]] = None,
-        train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
-        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
-        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        train_dataloader_kwargs: Dict[str, Any] = {
+            "batch_size": 2,
+            "num_workers": 1,
+        },
+        validation_dataloader_kwargs: Dict[str, Any] = None,
+        test_dataloader_kwargs: Dict[str, Any] = None,
         train_val_split: Optional[List[float]] = [0.9, 0.10],
         split_seed: int = 42,
     ) -> None:
@@ -42,13 +45,14 @@ def __init__(
             selection: (Optional) a list of event id's used for training
                 and validation, Default None.
             test_selection: (Optional) a list of event id's used for testing,
-                Default None.
+                Defaults to None.
             train_dataloader_kwargs: Arguments for the training DataLoader,
-                Default None.
+                Defaults to {"batch_size": 2, "num_workers": 1}.
             validation_dataloader_kwargs: Arguments for the validation
-                DataLoader, Default None.
+                DataLoader. Defaults to
+                `train_dataloader_kwargs`.
             test_dataloader_kwargs: Arguments for the test DataLoader,
-                Default None.
+                Defaults to `train_dataloader_kwargs`.
             train_val_split (Optional): Split ratio for training and
                 validation sets. Default is [0.9, 0.10].
             split_seed: seed used for shuffling and splitting selections into
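
Put together, the new signature means a caller can omit the validation and test dataloader kwargs entirely and still get working loaders. A hedged sketch of typical usage; the class name `GraphNeTDataModule`, the `dataset_reference` argument name, and the `dataset_kwargs` dict are assumptions, not taken from the diff above:

    from graphnet.data.datamodule import GraphNeTDataModule  # class name assumed
    from graphnet.data.dataset.parquet.parquet_dataset import ParquetDataset

    dm = GraphNeTDataModule(
        dataset_reference=ParquetDataset,   # argument name assumed
        dataset_args=dataset_kwargs,        # dataset construction args built elsewhere (assumed)
        train_dataloader_kwargs={"batch_size": 16, "num_workers": 4},
        # validation/test kwargs are deliberately omitted: per the docstring above
        # they fall back to the training kwargs, with shuffle disabled for inference.
    )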
@@ -63,9 +67,11 @@ def __init__(
         self._train_val_split = train_val_split or [0.0]
         self._rng = split_seed
 
-        self._train_dataloader_kwargs = train_dataloader_kwargs or {}
-        self._validation_dataloader_kwargs = validation_dataloader_kwargs or {}
-        self._test_dataloader_kwargs = test_dataloader_kwargs or {}
+        self._set_dataloader_kwargs(
+            train_dataloader_kwargs,
+            validation_dataloader_kwargs,
+            test_dataloader_kwargs,
+        )
 
         # If multiple dataset paths are given, we should use EnsembleDataset
         self._use_ensemble_dataset = isinstance(
@@ -75,6 +81,84 @@ def __init__(
         # Create Dataloaders
         self.setup("fit")
 
+    def _set_dataloader_kwargs(
+        self,
+        train_dl_args: Dict[str, Any],
+        val_dl_args: Union[Dict[str, Any], None],
+        test_dl_args: Union[Dict[str, Any], None],
+    ) -> None:
+        """Copy train dataloader args to other dataloaders if not given.
+
+        Also checks that ParquetDataset dataloaders have the multiprocessing
+        context set to "spawn", as this is strictly required.
+        See: https://docs.pola.rs/user-guide/misc/multiprocessing/
+        """
+        if val_dl_args is None:
+            self.info(
+                "No `val_dataloader_kwargs` given. This arg has "
+                "been set to `train_dataloader_kwargs` with `shuffle` = False."
+            )
+            val_dl_args = deepcopy(train_dl_args)
+            val_dl_args["shuffle"] = False  # Important for inference
+        if (test_dl_args is None) & (self._test_selection is not None):
+            test_dl_args = deepcopy(train_dl_args)
+            test_dl_args["shuffle"] = False  # Important for inference
+            self.info(
+                "No `test_dataloader_kwargs` given. This arg has "
+                "been set to `train_dataloader_kwargs` with `shuffle` = False."
+            )
+
+        if self._dataset == ParquetDataset:
+            train_dl_args = self._add_context(train_dl_args, "training")
+            val_dl_args = self._add_context(val_dl_args, "validation")
+            if self._test_selection is not None:
+                assert test_dl_args is not None
+                test_dl_args = self._add_context(test_dl_args, "test")
+
+        self._train_dataloader_kwargs = train_dl_args
+        self._validation_dataloader_kwargs = val_dl_args
+        self._test_dataloader_kwargs = test_dl_args or {}
+
+    def _add_context(
+        self, dataloader_args: Dict[str, Any], dataloader_type: str
+    ) -> Dict[str, Any]:
+        """Handle assignment of `multiprocessing_context` arg to loaders.
+
+        Datasets relying on threaded libraries often require the
+        multiprocessing context to be set to "spawn" if `num_workers` > 0.
+        This method checks the arguments for this entry and raises an error
+        if the field is already assigned a wrong value. If the value is not
+        specified, it is added automatically with a log entry.
+        """
+        arg = "multiprocessing_context"
+        if dataloader_args["num_workers"] != 0:
+            # If using multiprocessing
+            if arg in dataloader_args:
+                if dataloader_args[arg] != "spawn":
+                    # Wrongly assigned by user
+                    self.error(
+                        "DataLoaders using `ParquetDataset` must have "
+                        "multiprocessing_context = 'spawn'. "
+                        f"Found '{dataloader_args[arg]}' in "
+                        f"{dataloader_type} dataloader."
+                    )
+                    raise ValueError("multiprocessing_context must be 'spawn'")
+                else:
+                    # Correctly assigned by user
+                    return dataloader_args
+            else:
+                # Forgotten assignment by user
+                dataloader_args[arg] = "spawn"
+                self.warning_once(
+                    f"{self.__class__.__name__} has automatically "
+                    "set multiprocessing_context = 'spawn' in "
+                    f"{dataloader_type} dataloader."
+                )
+                return dataloader_args
+        else:
+            return dataloader_args
+
     def prepare_data(self) -> None:
         """Prepare the dataset for training."""
         # Download method for curated datasets. Method for download is
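
The net effect of `_set_dataloader_kwargs` and `_add_context` for callers: with a `ParquetDataset` and `num_workers > 0`, an explicit non-"spawn" context is rejected, while an omitted one is filled in automatically. A short sketch, reusing the assumed names from the previous example:

    try:
        GraphNeTDataModule(
            dataset_reference=ParquetDataset,   # names assumed, as above
            dataset_args=dataset_kwargs,
            train_dataloader_kwargs={
                "batch_size": 16,
                "num_workers": 4,
                "multiprocessing_context": "fork",  # deliberately wrong
            },
        )
    except ValueError:
        print("ParquetDataset loaders must use multiprocessing_context='spawn'")

    # Omitting the key instead lets the DataModule insert "spawn" and log a warning.
    dm = GraphNeTDataModule(
        dataset_reference=ParquetDataset,
        dataset_args=dataset_kwargs,
        train_dataloader_kwargs={"batch_size": 16, "num_workers": 4},
    )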
3 changes: 3 additions & 0 deletions src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -57,6 +57,9 @@ def __init__(
     ):
         """Construct Dataset.
 
+        NOTE: DataLoaders using this Dataset should have
+        "multiprocessing_context = 'spawn'" set to avoid thread locking.
+
         Args:
             path: Path to the file(s) from which this `Dataset` should read.
             pulsemaps: Name(s) of the pulse map series that should be used to
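
The same constraint applies when wrapping a `ParquetDataset` in a plain `torch.utils.data.DataLoader` by hand, since polars' thread pool does not survive a forked worker process. A minimal sketch, assuming `parquet_dataset` is an already-constructed `ParquetDataset` (its constructor arguments are elided here) and reusing the `collate_fn` from `graphnet.data.dataloader`:

    from torch.utils.data import DataLoader
    from graphnet.data.dataloader import collate_fn

    # With num_workers > 0, workers must be spawned rather than forked,
    # otherwise reads from the parquet backend can deadlock.
    loader = DataLoader(
        parquet_dataset,                  # an existing ParquetDataset (assumed)
        batch_size=16,
        num_workers=2,
        collate_fn=collate_fn,
        multiprocessing_context="spawn",
    )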
9 changes: 7 additions & 2 deletions tests/data/test_datamodule.py
@@ -90,7 +90,7 @@ def dataset_setup(dataset_ref: pytest.FixtureRequest) -> tuple:
         "graph_definition": graph_definition,
     }
 
-    dataloader_kwargs = {"batch_size": 2, "num_workers": 1}
+    dataloader_kwargs = {"batch_size": 2, "num_workers": 1, "shuffle": True}
 
     return dataset_ref, dataset_kwargs, dataloader_kwargs
 
@@ -151,7 +151,7 @@ def test_single_dataset_without_selections(
     # validation loader should have shuffle = False by default
     assert isinstance(val_dataloader.sampler, SequentialSampler)
     # Should have identical batch_size
-    assert val_dataloader.batch_size != train_dataloader.batch_size
+    assert val_dataloader.batch_size == train_dataloader.batch_size
     # Training dataloader should contain more batches
    assert len(train_dataloader) > len(val_dataloader)
 
@@ -210,6 +210,11 @@ def test_single_dataset_with_selections(
     # Training dataloader should have more batches
     assert len(train_dataloader) > len(val_dataloader)
 
+    # validation loader should have shuffle = False by default
+    assert isinstance(val_dataloader.sampler, SequentialSampler)
+    # test loader should have shuffle = False by default
+    assert isinstance(test_dataloader.sampler, SequentialSampler)
+
 
 @pytest.mark.parametrize(
     "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True
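
The sampler assertions above work because torch's `DataLoader` derives its sampler from the `shuffle` flag; a self-contained illustration using only torch:

    import torch
    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

    toy = TensorDataset(torch.arange(10))

    # shuffle=False (the default) yields a SequentialSampler ...
    assert isinstance(DataLoader(toy, shuffle=False).sampler, SequentialSampler)
    # ... while shuffle=True yields a RandomSampler.
    assert isinstance(DataLoader(toy, shuffle=True).sampler, RandomSampler)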
