diff --git a/bundle/python_bundle_workflow/scripts/inference.py b/bundle/python_bundle_workflow/scripts/inference.py index 9d05c1218b..a98b47bb64 100644 --- a/bundle/python_bundle_workflow/scripts/inference.py +++ b/bundle/python_bundle_workflow/scripts/inference.py @@ -108,7 +108,7 @@ def get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_dataset_dir(self): - return "./infer" + return self.dataset_dir def get_network_def(self): return UNet( diff --git a/bundle/python_bundle_workflow/scripts/train.py b/bundle/python_bundle_workflow/scripts/train.py index d1282daf6e..29a53b46a1 100644 --- a/bundle/python_bundle_workflow/scripts/train.py +++ b/bundle/python_bundle_workflow/scripts/train.py @@ -79,6 +79,7 @@ def __init__(self, dataset_dir: str = "./train"): # define buckets to store the generated properties and set properties self._props = {} self._set_props = {} + self.dataset_dir = dataset_dir # besides the predefined properties, this bundle workflow can also provide `network`, `loss`, `optimizer` self.add_property(name="network", required=True, desc="network for the training.") @@ -133,7 +134,7 @@ def get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_dataset_dir(self): - return "./train" + return self.dataset_dir def get_network(self): return UNet(