From 5a4d80c17c1d93ddf9a3142e0a68171893c4759b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 4 Mar 2024 15:06:40 +0800 Subject: [PATCH] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- bundle/python_bundle_workflow/scripts/inference.py | 2 +- bundle/python_bundle_workflow/scripts/train.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bundle/python_bundle_workflow/scripts/inference.py b/bundle/python_bundle_workflow/scripts/inference.py index 9d05c1218b..a98b47bb64 100644 --- a/bundle/python_bundle_workflow/scripts/inference.py +++ b/bundle/python_bundle_workflow/scripts/inference.py @@ -108,7 +108,7 @@ def get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_dataset_dir(self): - return "./infer" + return self.dataset_dir def get_network_def(self): return UNet( diff --git a/bundle/python_bundle_workflow/scripts/train.py b/bundle/python_bundle_workflow/scripts/train.py index d1282daf6e..29a53b46a1 100644 --- a/bundle/python_bundle_workflow/scripts/train.py +++ b/bundle/python_bundle_workflow/scripts/train.py @@ -79,6 +79,7 @@ def __init__(self, dataset_dir: str = "./train"): # define buckets to store the generated properties and set properties self._props = {} self._set_props = {} + self.dataset_dir = dataset_dir # besides the predefined properties, this bundle workflow can also provide `network`, `loss`, `optimizer` self.add_property(name="network", required=True, desc="network for the training.") @@ -133,7 +134,7 @@ def get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_dataset_dir(self): - return "./train" + return self.dataset_dir def get_network(self): return UNet(