Skip to content

Commit

Permalink
Fix tests (#669)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssarkar2 authored Jan 30, 2024
1 parent 937a411 commit ada8b22
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,8 +1368,8 @@ def greedy_search(
hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps)
hb_profer.start()
this_peer_finished = False # used by synced_gpus only
bucket_size = model_kwargs["bucket_size"]
reduce_recompile = model_kwargs["reduce_recompile"]
bucket_size = model_kwargs.get("bucket_size", -1)
reduce_recompile = model_kwargs.get("reduce_recompile", False)

prompt_len = input_ids.shape[-1]
if bucket_size >= 0:
Expand Down Expand Up @@ -2121,8 +2121,8 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
hb_profer.start()
this_peer_finished = False # used by synced_gpus only

bucket_size = model_kwargs["bucket_size"]
reduce_recompile = model_kwargs["reduce_recompile"]
bucket_size = model_kwargs.get("bucket_size", -1)
reduce_recompile = model_kwargs.get("reduce_recompile", False)
prompt_len = input_ids.shape[-1]
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, prompt_len))
Expand Down

0 comments on commit ada8b22

Please sign in to comment.