Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu committed Mar 4, 2024
1 parent e94cf15 commit 5a4d80c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion bundle/python_bundle_workflow/scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_device(self):
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_dataset_dir(self):
return "./infer"
return self.dataset_dir

def get_network_def(self):
return UNet(
Expand Down
3 changes: 2 additions & 1 deletion bundle/python_bundle_workflow/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, dataset_dir: str = "./train"):
# define buckets to store the generated properties and set properties
self._props = {}
self._set_props = {}
self.dataset_dir = dataset_dir

# besides the predefined properties, this bundle workflow can also provide `network`, `loss`, `optimizer`
self.add_property(name="network", required=True, desc="network for the training.")
Expand Down Expand Up @@ -133,7 +134,7 @@ def get_device(self):
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_dataset_dir(self):
return "./train"
return self.dataset_dir

def get_network(self):
return UNet(
Expand Down

0 comments on commit 5a4d80c

Please sign in to comment.