From 53d63ebe0b4b1c9f6fac2f07c023c56644a78a6e Mon Sep 17 00:00:00 2001 From: Antonios Sarikas Date: Fri, 10 Jan 2025 13:24:26 +0200 Subject: [PATCH] refactor(prepare_data): use `None` as default for `split_ratio` Changed the default value of `split_ratio` parameter from `(0.8, 0.1, 0.1)` to `None`. This change improves the readability of the function signature in both code and documentation. Fixes #35 --- src/aidsorb/data.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/aidsorb/data.py b/src/aidsorb/data.py index b800f40..d9a6fae 100644 --- a/src/aidsorb/data.py +++ b/src/aidsorb/data.py @@ -31,7 +31,7 @@ from . transforms import upsample_pcd -def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int = SEED): +def prepare_data(source: str, split_ratio: Sequence = None, seed: int = SEED): r""" Split point clouds into train, validation and test sets. @@ -50,9 +50,9 @@ def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int = ---------- source : str Absolute or relative path to the directory holding the point clouds. - split_ratio : sequence, default=(0.8, 0.1, 0.1) + split_ratio : sequence, default=None Absolute sizes or fractions of splits of the form ``(train, val, - test)``. + test)``. If ``None``, it is set to ``(0.8, 0.1, 0.1)``. seed : int, default=1 Controls randomness of the ``rng`` used for splitting. @@ -84,6 +84,10 @@ def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int = path = Path(source).parent pcd_names = [name.removesuffix('.npy') for name in os.listdir(source)] + # Set default split ratio. + if split_ratio is None: + split_ratio = (0.8, 0.1, 0.1) + # Split the names of the point clouds. train, val, test = random_split(pcd_names, split_ratio, generator=rng)