Skip to content

Commit

Permalink
yapf with updated deps
Browse files Browse the repository at this point in the history
  • Loading branch information
tvmarino committed Dec 21, 2024
1 parent 1018316 commit 293970c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,10 @@ def __init__(
max_horizon_to_explore=np.inf,
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
obs_action_specs: Optional[Tuple[time_step.TimeStep,
tensor_spec.BoundedTensorSpec,]] = None,
obs_action_specs: Optional[Tuple[
time_step.TimeStep,
tensor_spec.BoundedTensorSpec,
]] = None,
reward_key: str = '',
keep_temps: Optional[str] = None,
**kwargs,
Expand Down Expand Up @@ -751,8 +753,10 @@ def __init__(
exploration_policy_paths: Optional[str] = None,
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
obs_action_specs: Optional[Tuple[time_step.TimeStep,
tensor_spec.BoundedTensorSpec,]] = None,
obs_action_specs: Optional[Tuple[
time_step.TimeStep,
tensor_spec.BoundedTensorSpec,
]] = None,
base_path: Optional[str] = None,
partitions: List[float] = [
0.,
Expand Down
4 changes: 2 additions & 2 deletions compiler_opt/rl/policy_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def __init__(self, policy_dict: Dict[str, tf_policy.TFPolicy]):
self._policy_saver_dict: Dict[str, Tuple[
policy_saver.PolicySaver, tf_policy.TFPolicy]] = {
policy_name: (policy_saver.PolicySaver(
policy, batch_size=1, use_nest_path_signatures=False), policy)
for policy_name, policy in policy_dict.items()
policy, batch_size=1, use_nest_path_signatures=False), policy
) for policy_name, policy in policy_dict.items()
}

def _write_output_signature(
Expand Down
4 changes: 2 additions & 2 deletions compiler_opt/rl/train_locally.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def sequence_example_iterator_fn(seq_ex: List[str]):

# Repeat for num_policy_iterations iterations.
t1 = time.time()
while (llvm_trainer.global_step_numpy() <
num_policy_iterations * num_iterations):
while (llvm_trainer.global_step_numpy()
< num_policy_iterations * num_iterations):
t2 = time.time()
logging.info('Last iteration took: %f', t2 - t1)
t1 = t2
Expand Down

0 comments on commit 293970c

Please sign in to comment.