Iterative generation using Input embeds and past_key_values #35890

Merged
10 changes: 7 additions & 3 deletions src/transformers/generation/utils.py
@@ -381,9 +381,13 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: if `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# generate the first token for each sequence, and then use the generated `input_ids` for continuation.
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -393,9 +397,9 @@

# 3. Prepare base model inputs
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
if not self.config.is_encoder_decoder:
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs[input_ids_key] = None
model_inputs["inputs_embeds"] = inputs_embeds
else:
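For orientation, here is a minimal usage sketch of the iterative-generation pattern this change enables: generate once from `inputs_embeds` while keeping the cache, then call `generate` again with the grown embedding sequence plus the returned `past_key_values`, so that only the positions the cache has not seen are processed (Exception 4 above). The checkpoint name and the way the continuation embeddings are built are illustrative assumptions, not code taken from this PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder checkpoint; any decoder-only model should behave similarly
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt_ids = tokenizer("An increasing sequence: one,", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(prompt_ids)

# First call: generate from embeddings only and keep the cache.
first = model.generate(
    inputs_embeds=inputs_embeds,
    max_new_tokens=4,
    return_dict_in_generate=True,
    use_cache=True,
)

# Second call: append the embeddings of the freshly generated tokens and reuse the cache.
# With `input_ids` arriving empty, Exception 4 slices `inputs_embeds` by `cache_position`,
# so only the positions the cache has not processed yet are fed to the model.
continued_embeds = torch.cat(
    [inputs_embeds, model.get_input_embeddings()(first.sequences)], dim=1
)
second = model.generate(
    inputs_embeds=continued_embeds,
    past_key_values=first.past_key_values,
    max_new_tokens=4,
    return_dict_in_generate=True,
    use_cache=True,
)
print(tokenizer.batch_decode(second.sequences))
```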
8 changes: 6 additions & 2 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -895,8 +895,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: if `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# generate the first token for each sequence, and then use the generated `input_ids` for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -905,7 +909,7 @@
input_ids = input_ids[:, cache_position]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the
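A tiny numeric illustration (made-up shapes, not from the PR) of why the first-step check above changes from `cache_position[0] == 0` to `len(cache_position) == inputs_embeds.shape[1]`: with a pre-filled cache the first unprocessed position is no longer 0, but the number of unprocessed positions still matches the sliced embedding length.

```python
import torch

# Suppose 6 embedding positions are supplied but the first 4 are already in the cache.
cache_position = torch.arange(4, 6)                          # unprocessed positions: [4, 5]
full_embeds = torch.randn(1, 6, 32)                          # [batch, seq, hidden]
inputs_embeds = full_embeds[:, -cache_position.shape[0]:]    # Exception 4 slice -> shape (1, 2, 32)

old_check = cache_position[0] == 0                           # False: the embeddings would be ignored
new_check = len(cache_position) == inputs_embeds.shape[1]    # True: the embeddings are used this step
print(bool(old_check), new_check)                            # False True
```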
13 changes: 10 additions & 3 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1654,8 +1654,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: if `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# generate the first token for each sequence, and then use the generated `input_ids` for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -1668,10 +1672,13 @@
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if inputs_embeds is not None and input_ids.shape[1] == 0:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
else:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
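A small sketch (illustrative shapes only) of the `position_ids` handling in the hunk above: the ids are derived from the padding mask with a cumulative sum, and with a cache present they are sliced to the embedding length when the step is driven by `inputs_embeds`, otherwise to the `input_ids` length.

```python
import torch

# Left-padded batch: position ids come from a cumulative sum over the attention mask.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1      # [[-1, -1, 0, 1, 2]]
position_ids.masked_fill_(attention_mask == 0, 1)         # [[ 1,  1, 0, 1, 2]]

# With a cache, keep only the positions for the tokens processed this step.
input_ids = torch.empty(1, 0, dtype=torch.long)           # empty: the step is embedding-driven
inputs_embeds = torch.randn(1, 2, 32)                     # two unprocessed embedding positions
if inputs_embeds is not None and input_ids.shape[1] == 0:
    step_position_ids = position_ids[:, -inputs_embeds.shape[1]:]   # last 2 positions
else:
    step_position_ids = position_ids[:, -input_ids.shape[1]:]
print(step_position_ids)                                   # tensor([[1, 2]])
```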
16 changes: 12 additions & 4 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -1674,10 +1674,13 @@ def prepare_inputs_for_generation(
else:
model_inputs["pixel_values"] = pixel_values

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# If we have cache: let's slice `input_ids` or `inputs_embeds` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if inputs_embeds is not None:
input_ids = input_ids[:, -cache_position.shape[0] :]
if input_ids.shape[1] == 0:
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
else:
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
if image_attention_mask is not None:
@@ -1687,14 +1690,19 @@
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

# If past_key_values are present, slice the position ids to keep only the unprocessed tokens.
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if inputs_embeds is not None and input_ids.shape[1] == 0:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
else:
position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
else:
# The clone here is for the same reason as for `position_ids`.
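A short illustration of why the code above clones `position_ids` with a contiguous memory format: slicing keeps the parent tensor's strides, so the layout can differ from step to step and retrigger CUDA graph capture under `torch.compile(mode="reduce-overhead")`; the clone normalizes it. The shapes are made up for the example.

```python
import torch

position_ids = torch.arange(10).unsqueeze(0)   # shape (1, 10), stride (10, 1)
sliced = position_ids[:, -3:]                   # shape (1, 3), still stride (10, 1)

# In the batch-size-1 case the slice already reports as contiguous,
# yet its stride depends on where the slice starts.
print(sliced.is_contiguous())                   # True

cloned = sliced.clone(memory_format=torch.contiguous_format)
print(sliced.stride(), cloned.stride())         # (10, 1) (3, 1)
```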
66 changes: 52 additions & 14 deletions src/transformers/models/moshi/modeling_moshi.py
@@ -1901,8 +1901,7 @@ def forward(


@add_start_docstrings(
"The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, "
"for speech-to-speech.",
"The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, for speech-to-speech.",
MOSHI_START_DOCSTRING,
)
class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
@@ -2458,18 +2457,57 @@ def prepare_inputs_for_generation(
blank_user_audio_codes: Optional[torch.FloatTensor] = None,
**kwargs,
):
# Overwritten -- Moshi has custom post-processing
# 1. Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
# Overwritten -- Moshi has custom post-processing on the prepared inputs.

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)

if past_key_values is not None:
if (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device

attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
config=self.config,
past_key_values=past_key_values,
)

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"cache_position": cache_position,
}
)

# 2. Now that everything is prepared, generate audio_codes using the depth decoder
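For the static-cache branch above, a rough conceptual sketch of what a `[batch, 1, query_len, kv_len]` additive causal mask indexed by `cache_position` looks like; this is not the library helper (which also merges the 2D padding mask), just a toy version under those assumptions.

```python
import torch

def toy_causal_mask(cache_position, target_length, batch_size, dtype=torch.float32):
    """Build an additive causal mask for the query positions listed in `cache_position`."""
    min_value = torch.finfo(dtype).min
    kv_positions = torch.arange(target_length)
    # A query at absolute position p may attend to key/value positions <= p.
    blocked = kv_positions[None, :] > cache_position[:, None]   # [query_len, kv_len]
    mask = blocked.to(dtype) * min_value
    return mask[None, None, :, :].expand(batch_size, 1, -1, -1)

mask = toy_causal_mask(cache_position=torch.arange(4, 6), target_length=8, batch_size=1)
print(mask.shape)   # torch.Size([1, 1, 2, 8])
```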
10 changes: 7 additions & 3 deletions src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1261,7 +1261,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and attention_mask.device.type in ["cuda", "xpu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
@@ -1872,8 +1872,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: if `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# generate the first token for each sequence, and then use the generated `input_ids` for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -1886,7 +1890,7 @@
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
8 changes: 6 additions & 2 deletions src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -770,8 +770,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: if `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# generate the first token for each sequence, and then use the generated `input_ids` for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -784,7 +788,7 @@
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
8 changes: 6 additions & 2 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1735,8 +1735,12 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: if `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# generate the first token for each sequence, and then use the generated `input_ids` for continuation.
if past_key_values is not None:
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -1749,7 +1753,7 @@
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
2 changes: 1 addition & 1 deletion src/transformers/models/zamba2/modeling_zamba2.py
@@ -1557,7 +1557,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and attention_mask.device.type in ["cuda", "xpu"]
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.