Skip to content

Commit

Permalink
handle folds greater than max fold constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
suzhoum committed Nov 12, 2023
1 parent e270d3c commit 6f6b520
Showing 1 changed file with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,12 @@ def process_benchmark_runs(module_configs: dict, amlb_benchmark_search_dirs: lis
module_configs["fold_to_run"].setdefault(benchmark, {})
for task in module_configs["amlb_task"][benchmark]:
if module_configs["fold_to_run"][benchmark].get(task):
tasks = module_configs["fold_to_run"][benchmark][task]
module_configs["fold_to_run"][benchmark][task] = [t for t in tasks if t < default_max_folds]
folds = module_configs["fold_to_run"][benchmark][task]
else:
module_configs["fold_to_run"][benchmark][task] = amlb_task_folds[benchmark][task]
folds = amlb_task_folds[benchmark][task]
module_configs["fold_to_run"][benchmark][task] = [f for f in folds if f < default_max_folds]
if not module_configs["fold_to_run"][benchmark][task]:
del module_configs["fold_to_run"][benchmark][task]


def get_cloudwatch_logs_url(region: str, job_id: str, log_group_name: str = "aws/batch/job"):
Expand All @@ -219,6 +221,9 @@ def generate_config_combinations(config, metrics_bucket, batch_job_queue, batch_
else:
raise ValueError("Invalid module. Choose either 'tabular', 'timeseries', or 'multimodal'.")

if len(job_configs) == 0:
return {parent_job_id: "No job submitted"}

benchmark_name = config["benchmark_name"]
config_s3_path = upload_config(config_list=job_configs, bucket=metrics_bucket, benchmark_name=benchmark_name)
env = [{"name": "config_file", "value": config_s3_path}]
Expand Down

0 comments on commit 6f6b520

Please sign in to comment.