Iterative generation using Input embeds and past_key_values (#35890)
* Iterative generation using input embeds

* ruff fix

* Added Testcase

* Updated comment

* ♻️ Refactored testcase

* Skip test for these models

* Continue generation using input embeds and cache

* Skip generate_continue_from_embeds test

* Refactor `prepare_input_for_generation` func

* Continue generation using input embeds and cache

* Modular changes fix

* Overwrite 'prepare_inputs_for_generation' function
yaswanth19 authored and MekkCyber committed Feb 7, 2025
1 parent b44d36e commit d68779b
Showing 18 changed files with 276 additions and 34 deletions.
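A minimal sketch of the usage pattern this commit enables — generating from `inputs_embeds`, then continuing from the returned cache. It is not the PR's test case; the checkpoint name, prompt, and token counts are illustrative assumptions.

```python
# Sketch of the continuation pattern this commit enables: generate from
# `inputs_embeds`, then continue from the returned cache.
# Assumptions: a decoder-only checkpoint that supports generation from
# embeddings; "gpt2", the prompt, and the token counts are illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tokenizer("An increasing sequence: one,", return_tensors="pt").input_ids
prompt_embeds = model.get_input_embeddings()(prompt_ids)

# First call: the prompt is given only as embeddings, so the returned
# `sequences` hold just the newly generated token ids.
first = model.generate(
    inputs_embeds=prompt_embeds,
    max_new_tokens=4,
    use_cache=True,
    return_dict_in_generate=True,
)

# Continuation: re-embed the generated ids, append them to the original
# embeddings, and pass the cache back in. Only the unprocessed tail of
# `inputs_embeds` is fed through the model (Exception 4 in the diffs below).
new_embeds = model.get_input_embeddings()(first.sequences)
second = model.generate(
    inputs_embeds=torch.cat([prompt_embeds, new_embeds], dim=1),
    past_key_values=first.past_key_values,
    max_new_tokens=4,
    use_cache=True,
    return_dict_in_generate=True,
)
print(tokenizer.decode(second.sequences[0]))
```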
10 changes: 7 additions & 3 deletions src/transformers/generation/utils.py
@@ -381,9 +381,13 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens and
# generate the first token for each sequence. Later, use the generated input ids for continuation.
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -393,9 +397,9 @@

# 3. Prepare base model inputs
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
if not self.config.is_encoder_decoder:
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs[input_ids_key] = None
model_inputs["inputs_embeds"] = inputs_embeds
else:
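A toy shape walk-through of the new Exception 4 branch above; the tensor sizes are invented for illustration, and `use_embeds` stands in for the first-step check.

```python
# Toy shape walk-through of the Exception 4 branch (sizes are invented).
import torch

batch, hidden = 1, 8
prompt_len, generated = 5, 3  # 5 prompt embeddings, 3 tokens generated so far

# Continuation call: prompt + re-embedded generated tokens, but no raw ids.
inputs_embeds = torch.randn(batch, prompt_len + generated, hidden)
input_ids = torch.empty(batch, 0, dtype=torch.long)

# The cache already holds prompt_len + generated - 1 positions, so exactly one
# embedding (that of the last generated token) is still unprocessed.
cache_position = torch.tensor([prompt_len + generated - 1])

if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
    inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]

# The new first-step check: feed embeddings whenever they exactly cover the
# positions still to be processed.
use_embeds = len(cache_position) == inputs_embeds.shape[1]
print(inputs_embeds.shape, use_embeds)  # torch.Size([1, 1, 8]) True
```

Under the previous `cache_position[0] == 0` check, this continuation step would have fallen back to the empty `input_ids`, which appears to be why the condition was changed to compare lengths.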
8 changes: 6 additions & 2 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -895,8 +895,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens and
# generate the first token for each sequence. Later, use the generated input ids for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -905,7 +909,7 @@
input_ids = input_ids[:, cache_position]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
13 changes: 10 additions & 3 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1654,8 +1654,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens and
# generate the first token for each sequence. Later, use the generated input ids for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -1668,10 +1672,13 @@
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if inputs_embeds is not None and input_ids.shape[1] == 0:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
else:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
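A toy illustration of the `position_ids` branch added above, with invented values. Note that slicing by `-input_ids.shape[1]:` would return the whole tensor when `input_ids` is empty (since `-0 == 0`), which is what the new embeds-length branch avoids.

```python
# Toy illustration of the position_ids branch (values are invented).
import torch

attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1]])  # left padding + prompt + generated
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

# After the Exception 4 slice, only one unprocessed embedding remains.
inputs_embeds = torch.randn(1, 1, 8)
input_ids = torch.empty(1, 0, dtype=torch.long)
past_key_values = True  # stands in for a non-empty cache

if past_key_values:
    if inputs_embeds is not None and input_ids.shape[1] == 0:
        position_ids = position_ids[:, -inputs_embeds.shape[1] :]
    else:
        position_ids = position_ids[:, -input_ids.shape[1] :]

print(position_ids)  # tensor([[6]]) -- the position of the single unprocessed token
```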
16 changes: 12 additions & 4 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -1674,10 +1674,13 @@ def prepare_inputs_for_generation(
else:
model_inputs["pixel_values"] = pixel_values

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# If we have cache: let's slice `input_ids` or `inputs_embeds` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if inputs_embeds is not None:
input_ids = input_ids[:, -cache_position.shape[0] :]
if input_ids.shape[1] == 0:
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
else:
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
if image_attention_mask is not None:
@@ -1687,14 +1690,19 @@
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

# If past_key_values are present, slice the position ids to keep only the unprocessed tokens.
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if inputs_embeds is not None and input_ids.shape[1] == 0:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
else:
position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
else:
# The clone here is for the same reason as for `position_ids`.
66 changes: 52 additions & 14 deletions src/transformers/models/moshi/modeling_moshi.py
@@ -1901,8 +1901,7 @@ def forward(


@add_start_docstrings(
"The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, "
"for speech-to-speech.",
"The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, for speech-to-speech.",
MOSHI_START_DOCSTRING,
)
class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
@@ -2458,18 +2457,57 @@ def prepare_inputs_for_generation(
blank_user_audio_codes: Optional[torch.FloatTensor] = None,
**kwargs,
):
# Overwritten -- Moshi has custom post-processing
# 1. Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
# Overwritten -- Moshi has custom post-processing on the prepared inputs.

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)

if past_key_values is not None:
if (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device

attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
config=self.config,
past_key_values=past_key_values,
)

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"cache_position": cache_position,
}
)

# 2. Now that everything is prepared, generate audio_codes using the depth decoder
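For context on the mask construction Moshi's override relies on, a simplified stand-in for building a 4D causal mask aligned with `cache_position` is sketched below. It is not the private helper's actual implementation; the function name, shapes, and dtype handling are assumptions.

```python
# Simplified stand-in for the 4D causal mask construction used above
# (not the private helper's real code; shapes and dtype handling are assumed).
import torch

def make_4d_causal_mask(attention_mask_2d, sequence_length, target_length, cache_position, dtype):
    batch_size = attention_mask_2d.shape[0]
    min_value = torch.finfo(dtype).min
    # Start fully masked, then open up the positions each query may attend to.
    mask = torch.full((sequence_length, target_length), min_value, dtype=dtype)
    # A query at cache position p may attend to key positions <= p.
    causal = torch.arange(target_length) <= cache_position.reshape(-1, 1)
    mask = mask.masked_fill(causal, 0.0)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    # Also mask out padding columns from the 2D attention mask.
    padding = attention_mask_2d[:, None, None, :target_length] == 0
    mask = mask.masked_fill(padding, min_value)
    return mask

# Example: one new query at cache position 7, static cache of length 10.
m = make_4d_causal_mask(torch.ones(1, 10), 1, 10, torch.tensor([7]), torch.float32)
print(m.shape)  # torch.Size([1, 1, 1, 10])
```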
10 changes: 7 additions & 3 deletions src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1261,7 +1261,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and attention_mask.device.type in ["cuda", "xpu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
@@ -1872,8 +1872,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens and
# generate the first token for each sequence. Later, use the generated input ids for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -1886,7 +1890,7 @@
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
8 changes: 6 additions & 2 deletions src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -770,8 +770,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens and
# generate the first token for each sequence. Later, use the generated input ids for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -784,7 +788,7 @@
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
8 changes: 6 additions & 2 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1735,8 +1735,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens and
# generate the first token for each sequence. Later, use the generated input ids for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -1749,7 +1753,7 @@
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
2 changes: 1 addition & 1 deletion src/transformers/models/zamba2/modeling_zamba2.py
@@ -1557,7 +1557,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and attention_mask.device.type in ["cuda", "xpu"]
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
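The `"cuda"` → `["cuda", "xpu"]` changes in this and the qwen2_5_vl hunk widen an existing SDPA workaround to XPU devices. A minimal demonstration of the underlying issue — fully masked rows producing NaNs under `scaled_dot_product_attention` — is sketched below; the `safe_mask` trick illustrates the idea of unmasking such rows, not the library's actual implementation.

```python
# Why the fully-masked-row workaround exists (illustrative, not the library code):
# with SDPA, a query row whose keys are all masked produces NaNs, e.g. with left padding.
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 4)  # batch, heads, queries, head_dim
k = v = torch.randn(1, 1, 2, 4)
mask = torch.tensor([[[[False, False],   # padding query: everything masked -> NaN row
                       [True,  True]]]])
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 0])  # tensor([nan, nan, nan, nan])

# The fix applied on CUDA/XPU: let fully masked rows attend to everything; their
# output is unused anyway but stays finite, which the memory-efficient path needs.
safe_mask = mask | ~mask.any(dim=-1, keepdim=True)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=safe_mask)
print(out[0, 0, 0])  # finite values
```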