diff --git a/generate.py b/generate.py index 8446d11..ffc3954 100644 --- a/generate.py +++ b/generate.py @@ -169,9 +169,8 @@ def generate( draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty + seq = torch.empty(T_new, dtype=dtype, device=device) + seq[:T] = prompt input_pos = torch.arange(0, T, device=device) next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)