From c0bb19182b0d42367ad6e21a922118a2845cc251 Mon Sep 17 00:00:00 2001 From: yaswanth Date: Sat, 25 Jan 2025 21:01:32 +0530 Subject: [PATCH 01/12] Iterative generation using input embeds --- src/transformers/generation/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 45558bd22a4e..1fd9da870217 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -383,7 +383,9 @@ def prepare_inputs_for_generation( # (we can't check exception 3 while compiling) if past_key_values is not None: model_inputs["past_key_values"] = past_key_values - if ( + if inputs_embeds is not None and input_ids.shape[1]==0: + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( inputs_embeds is not None # Exception 1 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 ): @@ -393,9 +395,9 @@ def prepare_inputs_for_generation( # 3. Prepare base model inputs input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt. if not self.config.is_encoder_decoder: - if inputs_embeds is not None and cache_position[0] == 0: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: model_inputs[input_ids_key] = None model_inputs["inputs_embeds"] = inputs_embeds else: From 837f6b49a5db112369863f2ee37e7f141959c7c5 Mon Sep 17 00:00:00 2001 From: yaswanth Date: Sat, 25 Jan 2025 21:32:44 +0530 Subject: [PATCH 02/12] ruff fix --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1fd9da870217..067db65a616d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -383,7 +383,7 @@ def prepare_inputs_for_generation( # (we can't check exception 3 while compiling) if past_key_values is not None: model_inputs["past_key_values"] = past_key_values - if inputs_embeds is not None and input_ids.shape[1]==0: + if inputs_embeds is not None and input_ids.shape[1] == 0: inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] elif ( inputs_embeds is not None # Exception 1 From cc93cf7e1d2104b6a36839d6152e03245775e949 Mon Sep 17 00:00:00 2001 From: yaswanth Date: Thu, 30 Jan 2025 22:21:02 +0530 Subject: [PATCH 03/12] Added Testcase --- tests/generation/test_utils.py | 66 ++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 321803a2179b..08b7bac7c6ca 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1857,6 +1857,72 @@ def test_generate_continue_from_past_key_values(self): ) ) + @pytest.mark.generate + def test_continue_generate_from_inputs_embeds(self): + """Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous `generate` call.""" + for model_class in self.all_generative_model_classes: + if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]): + self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") + if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): + self.skipTest(reason="TODO: needs 
modeling or test input preparation fixes for compatibility")
+
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            if config.is_encoder_decoder:
+                self.skipTest(reason="This model is encoder-decoder")
+            if not hasattr(config, "use_cache"):
+                self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+
+            model = model_class(config).to(torch_device).eval()
+
+            if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
+                self.skipTest(reason="This model does not support `inputs_embeds` in generation")
+
+            pixel_values_is_mutually_exclusive = any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
+            )
+            if pixel_values_is_mutually_exclusive:
+                inputs_dict.pop("pixel_values", None)
+                inputs_dict.pop("pixel_values_videos", None)
+                inputs_dict.pop("pixel_values_images", None)
+
+            input_ids = inputs_dict.pop("input_ids")
+
+            model.config.use_cache = True
+            model.config.is_decoder = True
+            max_cache_len = 10
+
+            inputs_embeds = model.get_input_embeddings()(input_ids[0].unsqueeze(0))
+            generation_kwargs = {
+                "max_length": max_cache_len,
+                "return_dict_in_generate": True,
+                "do_sample": False,
+            }
+
+            # Generate the first batch of tokens and capture the `past_key_values` in cache
+            with torch.no_grad():
+                prompt_cache = model(inputs_embeds=inputs_embeds, **generation_kwargs).past_key_values
+
+            # Concatenate the new input embeddings for continuation.
+            new_inputs_embeds = torch.cat(
+                [inputs_embeds, model.get_input_embeddings()(input_ids[1].unsqueeze(0))], dim=1
+            )
+
+            # Continue generation using the concatenated `inputs_embeds` and the original `past_key_values`
+            outputs_continued = model.generate(
+                inputs_embeds=new_inputs_embeds, past_key_values=prompt_cache, **generation_kwargs
+            )
+            # Generate the sequence by combining the original two input_ids and generating the entire sequence in one go.
+            combined_inputs = input_ids.flatten().unsqueeze(0)
+            combined_embeds = model.get_input_embeddings()(combined_inputs)
+
+            # Generate using the combined input embeddings (no cache passed)
+            outputs_combined = model.generate(inputs_embeds=combined_embeds, **generation_kwargs)
+
+            # Verify that the generated sequences are identical
+            self.assertListEqual(outputs_continued.sequences.tolist(), outputs_combined.sequences.tolist())
+
     @parameterized.expand([("offloaded",)])  # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
     @require_torch_gpu
     @pytest.mark.generate

From 1b9d24731be84e214b4263b19d33abe431e31fb5 Mon Sep 17 00:00:00 2001
From: yaswanth
Date: Thu, 30 Jan 2025 22:36:28 +0530
Subject: [PATCH 04/12] Updated comment

---
 src/transformers/generation/utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 067db65a616d..f0f7f2b0b6b5 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -381,9 +381,11 @@ def prepare_inputs_for_generation(
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         # (we can't check exception 3 while compiling)
+        # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
+        # generate the first token for each sequence.
Later use the generated Input ids for continuation. if past_key_values is not None: model_inputs["past_key_values"] = past_key_values - if inputs_embeds is not None and input_ids.shape[1] == 0: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] elif ( inputs_embeds is not None # Exception 1 From 380a694ced3c863b6a30336e62e748c5d25b9125 Mon Sep 17 00:00:00 2001 From: yaswanth Date: Sat, 1 Feb 2025 10:46:01 +0530 Subject: [PATCH 05/12] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactored=20testcas?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/generation/test_utils.py | 61 ++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 08b7bac7c6ca..c7c8c7f8c108 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1858,7 +1858,7 @@ def test_generate_continue_from_past_key_values(self): ) @pytest.mark.generate - def test_continue_generate_from_inputs_embeds(self): + def test_generate_continue_from_inputs_embeds(self): """Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous `generate` call.""" for model_class in self.all_generative_model_classes: if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]): @@ -1868,6 +1868,9 @@ def test_continue_generate_from_inputs_embeds(self): config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if "token_type_ids" in inputs_dict: + del inputs_dict["token_type_ids"] + if config.is_encoder_decoder: self.skipTest(reason="This model is encoder-decoder") if not hasattr(config, "use_cache"): @@ -1878,6 +1881,11 @@ def test_continue_generate_from_inputs_embeds(self): if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): self.skipTest(reason="This model does not support `inputs_embeds` in generation") + # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) + outputs = model(**inputs_dict) + if "past_key_values" not in outputs: + self.skipTest(reason="This model doesn't return `past_key_values`") + pixel_values_is_mutually_exclusive = any( model_name in model_class.__name__.lower() for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"] @@ -1889,39 +1897,42 @@ def test_continue_generate_from_inputs_embeds(self): input_ids = inputs_dict.pop("input_ids") - model.config.use_cache = True + model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 + model.generation_config.forced_eos_token_id = None model.config.is_decoder = True - max_cache_len = 10 + model.generation_config.use_cache = True - inputs_embeds = model.get_input_embeddings()(input_ids[0].unsqueeze(0)) generation_kwargs = { - "max_length": max_cache_len, "return_dict_in_generate": True, "do_sample": False, } - # Generate the first batch of tokens and capture the `past_key_values` in cache - with torch.no_grad(): - prompt_cache = model(inputs_embeds=inputs_embeds, **generation_kwargs).past_key_values - - # Concatenate the new input embeddings for continuation. 
- new_inputs_embeds = torch.cat( - [inputs_embeds, model.get_input_embeddings()(input_ids[1].unsqueeze(0))], dim=1 - ) - - # Continue generation using the concatenated `inputs_embeds` and the original `past_key_values` - outputs_continued = model.generate( - inputs_embeds=new_inputs_embeds, past_key_values=prompt_cache, **generation_kwargs + # Traditional way of generating text, with `return_dict_in_generate` to return the past key values. + input_embeds = model.get_input_embeddings()(input_ids) + outputs = model.generate(inputs_embeds=input_embeds, max_new_tokens=4, **generation_kwargs) + + # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens) + initial_output = model.generate(inputs_embeds=input_embeds, max_new_tokens=3, **generation_kwargs) + continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1) + cached_output = model.generate( + inputs_embeds=continued_embeds, + max_new_tokens=1, + past_key_values=initial_output.past_key_values, + **generation_kwargs, ) - # Generate the sequence by combining the original two input_ids and generating the entire sequence in one go. - combined_inputs = input_ids.flatten().unsqueeze(0) - combined_embeds = model.get_input_embeddings()(combined_inputs) - # Generate using the combined input embeddings (no cache passed) - outputs_combined = model.generate(inputs_embeds=combined_embeds, **generation_kwargs) - - # Verify that the generated sequences are identical - self.assertListEqual(outputs_continued.sequences.tolist(), outputs_combined.sequences.tolist()) + # Combine the (3 + 1) generated tokens and verify it matches with full generation. + combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1) + self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist()) + # The two sets of past kv should be equal to each other + for layer_idx in range(len(cached_output.past_key_values)): + for kv_idx in range(len(cached_output.past_key_values[layer_idx])): + self.assertTrue( + torch.allclose( + outputs.past_key_values[layer_idx][kv_idx], + cached_output.past_key_values[layer_idx][kv_idx], + ) + ) @parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5) @require_torch_gpu From 03a106dabd8d92000916320f2e30b2116a21b304 Mon Sep 17 00:00:00 2001 From: yaswanth Date: Sat, 1 Feb 2025 16:23:11 +0530 Subject: [PATCH 06/12] Skip test for these models --- tests/models/gemma2/test_modeling_gemma2.py | 4 ++++ tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py | 4 ++++ tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py | 6 ++++++ 3 files changed, 14 insertions(+) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 1fb7bdfa8994..a0563aed90cb 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -146,6 +146,10 @@ def test_generate_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self): pass + @unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. 
Though it could, it shouldn't support.") + def test_generate_continue_from_inputs_embeds(self): + pass + # overwrite because HybridCache has fixed length for key/values def _check_attentions_for_generate( self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 1ac2db408123..c854a7e71167 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -450,6 +450,10 @@ def test_disk_offload(self): def test_past_key_values_format(self): pass + @unittest.skip(reason="BigCodeGPT has a non-standard KV cache format and breaks this test.") + def test_generate_continue_from_inputs_embeds(self): + pass + def test_gpt_bigcode_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs) diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index d76a7ba1e20c..de02783d93eb 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -342,6 +342,12 @@ def test_beam_search_low_memory(self): def test_generate_from_inputs_embeds_with_static_cache(self): pass + @unittest.skip( + reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs" + ) + def test_generate_continue_from_inputs_embeds(self): + pass + @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`") def test_generate_compile_fullgraph(self): pass From e5af893356c2d9730f7474e43fe513d382ccf1a9 Mon Sep 17 00:00:00 2001 From: yaswant19 Date: Wed, 5 Feb 2025 11:28:28 +0530 Subject: [PATCH 07/12] Continue generation using input embeds and cache --- src/transformers/models/bloom/modeling_bloom.py | 8 ++++++-- src/transformers/models/chameleon/modeling_chameleon.py | 8 ++++++-- src/transformers/models/idefics/modeling_idefics.py | 9 ++++++--- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 8 ++++++-- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 8 ++++++-- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 8 ++++++-- 6 files changed, 36 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 19ca679ad0df..9a03a793a613 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -895,8 +895,12 @@ def prepare_inputs_for_generation( # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. 
if past_key_values is not None: - if ( + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( inputs_embeds is not None # Exception 1 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 ): @@ -905,7 +909,7 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, cache_position] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 1e088fcaba00..8aa937413431 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1654,8 +1654,12 @@ def prepare_inputs_for_generation( # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. if past_key_values is not None: - if ( + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( inputs_embeds is not None # Exception 1 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 ): @@ -1671,7 +1675,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 4dbe4ad4c7f9..48cf500b7454 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1674,10 +1674,13 @@ def prepare_inputs_for_generation( else: model_inputs["pixel_values"] = pixel_values - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # If we have cache: let's slice `input_ids` or `input embeds` through `cache_position`, to keep only the unprocessed tokens if past_key_values is not None: if inputs_embeds is not None: - input_ids = input_ids[:, -cache_position.shape[0] :] + if input_ids.shape[1] == 0: + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + else: + input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: input_ids = input_ids[:, cache_position] if image_attention_mask is not None: @@ -1694,7 +1697,7 @@ def 
prepare_inputs_for_generation( position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None}) else: # The clone here is for the same reason as for `position_ids`. diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 242622d293a2..0b7348671148 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1872,8 +1872,12 @@ def prepare_inputs_for_generation( # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. if past_key_values is not None: - if ( + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( inputs_embeds is not None # Exception 1 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 ): @@ -1886,7 +1890,7 @@ def prepare_inputs_for_generation( pixel_values_videos = None # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: model_inputs = {"input_ids": input_ids, "inputs_embeds": None} diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 87216988b717..601ad373771c 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -770,8 +770,12 @@ def prepare_inputs_for_generation( # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. 
if past_key_values is not None: - if ( + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( inputs_embeds is not None # Exception 1 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 ): @@ -784,7 +788,7 @@ def prepare_inputs_for_generation( pixel_values_videos = None # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: model_inputs = {"input_ids": input_ids, "inputs_embeds": None} diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index d94daa39a729..51d8fe9430b5 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1735,8 +1735,12 @@ def prepare_inputs_for_generation( # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. if past_key_values is not None: - if ( + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( inputs_embeds is not None # Exception 1 or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 ): @@ -1749,7 +1753,7 @@ def prepare_inputs_for_generation( pixel_values_videos = None # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: model_inputs = {"input_ids": input_ids, "inputs_embeds": None} From 21a3341a2ee9785ee92f67631da6e7351f62b8be Mon Sep 17 00:00:00 2001 From: yaswant19 Date: Wed, 5 Feb 2025 11:29:03 +0530 Subject: [PATCH 08/12] Skip generate_continue_from_embeds test --- tests/models/clvp/test_modeling_clvp.py | 4 ++++ tests/models/cohere2/test_modeling_cohere2.py | 4 ++++ tests/models/fuyu/test_modeling_fuyu.py | 4 ++++ tests/models/moshi/test_modeling_moshi.py | 8 ++++++++ tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py | 6 ------ tests/models/zamba2/test_modeling_zamba2.py | 4 ++++ 6 files changed, 24 insertions(+), 6 deletions(-) diff --git a/tests/models/clvp/test_modeling_clvp.py b/tests/models/clvp/test_modeling_clvp.py index 84a0101f6f28..334f01004936 100644 --- a/tests/models/clvp/test_modeling_clvp.py +++ b/tests/models/clvp/test_modeling_clvp.py @@ -334,6 +334,10 @@ def test_training_gradient_checkpointing(self): loss = model(**inputs).loss loss.backward() + @unittest.skip(reason="Clvp `prepare_inputs_for_generation` function doesn't have cache position.") + def test_generate_continue_from_inputs_embeds(self): + pass + class ClvpModelForConditionalGenerationTester: def __init__(self, parent, is_training=False): diff --git 
a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 436f1f965e90..81ea53b49f88 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -131,6 +131,10 @@ def test_generate_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self): pass + @unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.") + def test_generate_continue_from_inputs_embeds(self): + pass + # overwrite because HybridCache has fixed length for key/values def _check_attentions_for_generate( self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py index 0444ad14f269..634dfcf61565 100644 --- a/tests/models/fuyu/test_modeling_fuyu.py +++ b/tests/models/fuyu/test_modeling_fuyu.py @@ -325,6 +325,10 @@ def test_disk_offload_safetensors(self): def test_model_parallelism(self): super().test_model_parallelism() + @unittest.skip(reason="Fuyu `prepare_inputs_for_generation` function doesn't have cache position.") + def test_generate_continue_from_inputs_embeds(): + pass + @slow @require_torch_accelerator diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index adaf0fcc34ac..f51d0199156c 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -358,6 +358,10 @@ def test_disk_offload_bin(self): def test_disk_offload_safetensors(self): pass + @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities input.") + def test_generate_continue_from_inputs_embeds(self): + pass + @is_flaky(max_attempts=5, description="flaky on some models.") def test_save_load(self): super().test_save_load() @@ -919,6 +923,10 @@ def test_disk_offload_bin(self): def test_disk_offload_safetensors(self): pass + @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities") + def test_generate_continue_from_inputs_embeds(self): + pass + @is_flaky(max_attempts=5, description="flaky on some models.") def test_save_load(self): super().test_save_load() diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index de02783d93eb..d76a7ba1e20c 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -342,12 +342,6 @@ def test_beam_search_low_memory(self): def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip( - reason="VLMs can't generate from inputs embeds and pixels. 
This can be tested as part of bacbone LM, no need to run the tes for VLMs"
-    )
-    def test_generate_continue_from_inputs_embeds(self):
-        pass
-
     @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
     def test_generate_compile_fullgraph(self):
         pass
diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py
index 2bd6732514c6..c876e598e867 100644
--- a/tests/models/zamba2/test_modeling_zamba2.py
+++ b/tests/models/zamba2/test_modeling_zamba2.py
@@ -333,6 +333,10 @@ def test_past_key_values_format(self):
         """
         pass
 
+    @unittest.skip(reason="Zamba2 has hybrid cache.")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+
     @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
     def test_multi_gpu_data_parallel_forward(self):
         pass

From 52f9394ede926b5acda09c99a3755e44f96839c8 Mon Sep 17 00:00:00 2001
From: yaswant19
Date: Wed, 5 Feb 2025 13:22:59 +0530
Subject: [PATCH 09/12] Refactor `prepare_input_for_generation` func

---
 src/transformers/models/chameleon/modeling_chameleon.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 8aa937413431..3bc6a43d6f56 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1672,7 +1672,10 @@ def prepare_inputs_for_generation(
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                if inputs_embeds is not None and input_ids.shape[1] == 0:
+                    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
+                else:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:

From 8b718dc502f9fa28e90133da808bb0ca319682e4 Mon Sep 17 00:00:00 2001
From: yaswant19
Date: Wed, 5 Feb 2025 18:10:44 +0530
Subject: [PATCH 10/12] Continue generation using input embeds and cache

---
 .../models/idefics/modeling_idefics.py        |  7 ++-
 tests/models/idefics/test_modeling_idefics.py | 59 +++++++++++++++++++
 tests/models/moshi/test_modeling_moshi.py     |  3 +-
 3 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 48cf500b7454..6857fb624c0f 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1690,8 +1690,13 @@ def prepare_inputs_for_generation(
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
+
+            # If past_key_values are present then slice the position ids for only the unprocessed tokens.
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                if inputs_embeds is not None and input_ids.shape[1] == 0:
+                    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
+                else:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
 
             # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
             # input `position_ids` would have various stride during the decoding.
Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. position_ids = position_ids.clone(memory_format=torch.contiguous_format) diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 01871e81b30e..cc9efc967db2 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -755,6 +755,65 @@ def test_generate_without_input_ids(self): ) self.assertIsNotNone(output_ids_generate) + @pytest.mark.generate + def test_generate_continue_from_inputs_embeds(self): + """Overwrite for IDEFICS: Ensure image attention mask is processed while continuing from `inputs_embeds`.""" + + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + print(inputs) + + model = model_class(config).to(torch_device).eval() + + model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 + model.generation_config.forced_eos_token_id = None + model.generation_config.use_cache = True + + input_ids = inputs.pop("input_ids") + input_embeds = model.get_input_embeddings()(input_ids) + + generation_kwargs = { + "return_dict_in_generate": True, + "do_sample": False, + } + + inputs["inputs_embeds"] = input_embeds + + # Traditional way of generating text, with `return_dict_in_generate` to return the past key values + outputs = model.generate(**inputs, max_new_tokens=4, **generation_kwargs) + # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the + # inputs may need to be tweaked across `generate` calls (like the attention mask). + initial_output = model.generate(**inputs, max_new_tokens=3, **generation_kwargs) + inputs["past_key_values"] = initial_output.past_key_values + + new_attention_len = input_ids.shape[1] + initial_output.sequences.shape[-1] + continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1) + inputs["inputs_embeds"] = continued_embeds + + if "attention_mask" in inputs: + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], + (0, new_attention_len - inputs["attention_mask"].shape[1]), + mode="constant", + value=1, + ) + if "image_attention_mask" in inputs: + inputs["image_attention_mask"] = inputs["image_attention_mask"][..., -1:, :] + + cached_output = model.generate(**inputs, max_new_tokens=1, **generation_kwargs) + + # Verify that the combined outputs match the full generation. 
+ combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1) + self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist()) + for layer_idx in range(len(cached_output.past_key_values)): + for kv_idx in range(len(cached_output.past_key_values[layer_idx])): + self.assertTrue( + torch.allclose( + outputs.past_key_values[layer_idx][kv_idx], + cached_output.past_key_values[layer_idx][kv_idx], + ) + ) + def _check_attentions_for_generate( self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 ): diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index f51d0199156c..09278f0d24c4 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -358,7 +358,7 @@ def test_disk_offload_bin(self): def test_disk_offload_safetensors(self): pass - @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities input.") + @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.") def test_generate_continue_from_inputs_embeds(self): pass @@ -828,6 +828,7 @@ def test_generate_without_input_ids(self): output_ids_generate = model.generate( do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True ) + print(output_ids_generate) self.assertIsNotNone(output_ids_generate) @unittest.skip(reason="The audio encoder has no gradients.") From 5eb97d455435f82db8d638b687cf3ec69f160e30 Mon Sep 17 00:00:00 2001 From: yaswant19 Date: Wed, 5 Feb 2025 20:00:33 +0530 Subject: [PATCH 11/12] Modular changes fix --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- src/transformers/models/zamba2/modeling_zamba2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 0b7348671148..82b112ad3665 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1261,7 +1261,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index f2d7d21a743e..8f00780f341d 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1557,7 +1557,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
From 73b0ed8f244ff94156550d019177843ebfcebe38 Mon Sep 17 00:00:00 2001 From: yaswant19 Date: Wed, 5 Feb 2025 21:14:44 +0530 Subject: [PATCH 12/12] Overwrite 'prepare_inputs_for_generation' function --- .../models/moshi/modeling_moshi.py | 66 +++++++++++++++---- 1 file changed, 52 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 01d2ff1940fc..6bde89f9aab5 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1901,8 +1901,7 @@ def forward( @add_start_docstrings( - "The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, " - "for speech-to-speech.", + "The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, for speech-to-speech.", MOSHI_START_DOCSTRING, ) class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): @@ -2458,18 +2457,57 @@ def prepare_inputs_for_generation( blank_user_audio_codes: Optional[torch.FloatTensor] = None, **kwargs, ): - # Overwritten -- Moshi has custom post-processing - # 1. Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds - model_inputs = super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=position_ids, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - **kwargs, + # Overwritten -- Moshi has custom post-processing on the prepared inputs. + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+ # (we can't check exception 3 while compiling) + + if past_key_values is not None: + if ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cache_position": cache_position, + } ) # 2. Now that everything is prepared, generate audio_codes using the depth decoder