Skip to content

Commit

Permalink
Fixed an issue where single-value parameters could cause errors in ba…
Browse files Browse the repository at this point in the history
…yes_opt, extended tests to test for this
  • Loading branch information
fjwillemsen committed Oct 12, 2023
1 parent a4a284b commit c27081f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
12 changes: 7 additions & 5 deletions kernel_tuner/strategies/bayes_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,13 @@ def find_param_config_unvisited_index(self, param_config: tuple) -> int:
return self.unvisited_cache.index(param_config)

def normalize_param_config(self, param_config: tuple) -> tuple:
"""Normalizes a parameter configuration."""
normalized = tuple(
self.normalized_dict[self.param_names[index]][param_value] for index, param_value in enumerate(param_config)
)
return normalized
"""Normalizes a parameter configuration. Skips over pruned values."""
param_config = self.unprune_param_config(param_config)
normalized = list()
for index, param_value in enumerate(param_config):
if self.removed_tune_params[index] is None:
normalized.append(self.normalized_dict[self.param_names[index]][param_value])
return tuple(normalized)

def denormalize_param_config(self, param_config: tuple) -> tuple:
"""Denormalizes a parameter configuration."""
Expand Down
23 changes: 19 additions & 4 deletions test/strategies/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,30 @@ def test_strategies(vector_add, strategy):

assert len(results) > 0

# check if the number of valid unique configurations is less than or equal to max_fevals
if not strategy == "brute_force":
# check if the number of valid unique configurations is less then max_fevals

tune_params = vector_add[-1]
unique_results = {}

for result in results:
x_int = ",".join([str(v) for k, v in result.items() if k in tune_params])
if not isinstance(result["time"], util.InvalidConfig):
unique_results[x_int] = result["time"]

assert len(unique_results) <= filter_options["max_fevals"]

# check whether the returned dictionaries contain exactly the expected keys and the appropriate type
expected_items = {
'block_size_x': int,
'time': (float, int),
'times': list,
'compile_time': (float, int),
'verification_time': (float, int),
'benchmark_time': (float, int),
'strategy_time': (float, int),
'framework_time': (float, int),
'timestamp': str
}
for res in results:
assert len(res) == len(expected_items)
for expected_key, expected_type in expected_items.items():
assert expected_key in res
assert isinstance(res[expected_key], expected_type)

0 comments on commit c27081f

Please sign in to comment.