Skip to content

Commit

Permalink
feature: add argument checks
Browse files Browse the repository at this point in the history
  • Loading branch information
sokovninn committed Dec 17, 2023
1 parent dac9fb7 commit 6d2b340
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions src/datadreamer/pipelines/generate_dataset_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

Expand Down Expand Up @@ -139,6 +140,58 @@ def parse_args():
return parser.parse_args()


def check_args(args):
# Check save_dir
if not os.path.exists(args.save_dir):
try:
os.makedirs(args.save_dir)
except OSError as e:
raise ValueError(f"Cannot create directory {args.save_dir}: {e}") from e

# Check class_names
if not args.class_names or any(
not isinstance(name, str) for name in args.class_names
):
raise ValueError("--class_names must be a non-empty list of strings")

# Check prompts_number
if args.prompts_number <= 0:
raise ValueError("--prompts_number must be a positive integer")

# Check num_objects_range
if (
len(args.num_objects_range) != 2
or not all(isinstance(n, int) for n in args.num_objects_range)
or args.num_objects_range[0] > args.num_objects_range[1]
):
raise ValueError(
"--num_objects_range must be two integers where the first is less than or equal to the second"
)

# Check num_objects_range[1]
if args.num_objects_range[1] > len(args.class_names):
raise ValueError(
"--num_objects_range[1] must be less than or equal to the number of class names"
)

# Check conf_threshold
if not 0 <= args.conf_threshold <= 1:
raise ValueError("--conf_threshold must be between 0 and 1")

# Check image_tester_patience
if args.image_tester_patience < 0:
raise ValueError("--image_tester_patience must be a non-negative integer")

# Check device availability (for 'cuda')
if args.device == "cuda":
if not torch.cuda.is_available():
raise ValueError("CUDA is not available. Please use --device cpu")

# Check seed
if args.seed < 0:
raise ValueError("--seed must be a non-negative integer")


def save_det_annotations_to_json(
image_paths,
boxes_list,
Expand Down Expand Up @@ -179,6 +232,7 @@ def save_clf_annotations_to_json(

def main():
args = parse_args()
check_args(args)

save_dir = args.save_dir

Expand Down

0 comments on commit 6d2b340

Please sign in to comment.