diff --git a/python_coreml_stable_diffusion/torch2coreml.py b/python_coreml_stable_diffusion/torch2coreml.py
index fc4e633e..3bbf7ea5 100644
--- a/python_coreml_stable_diffusion/torch2coreml.py
+++ b/python_coreml_stable_diffusion/torch2coreml.py
@@ -322,6 +322,8 @@ def bundle_resources_for_swift_cli(args):
 from transformers.models.clip import modeling_clip
 
 # Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
+# Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
+# For backward compatibility with versions < 4.35.0, both functions are patched here.
 def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length: int = 0):
     """ Patch to replace torch.finfo(dtype).min with -1e4
     """
@@ -334,8 +336,9 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
     if past_key_values_length > 0:
         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-modeling_clip._make_causal_mask = patched_make_causal_mask
+
+modeling_clip._make_causal_mask = patched_make_causal_mask # For transformers >= 4.30.0 and transformers < 4.35.0
+modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask # For transformers >= 4.35.0
 
 def convert_text_encoder(text_encoder, tokenizer, submodule_name, args):
     """ Converts the text encoder component of Stable Diffusion
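
For context only, not part of the diff above: the patch could also be applied in a version-aware way by guarding each assignment with hasattr, so that only the helper actually defined by the installed transformers release is replaced. The sketch below assumes the patched_make_causal_mask function defined earlier in torch2coreml.py and is one possible alternative, not how this PR applies the patch.

    # Sketch only (not part of this PR): patch whichever mask helper the
    # installed transformers release exposes, instead of assigning both names.
    from transformers.models.clip import modeling_clip

    def apply_causal_mask_patch():
        # transformers >= 4.30.0 and < 4.35.0 build the CLIP causal mask via _make_causal_mask
        if hasattr(modeling_clip, "_make_causal_mask"):
            modeling_clip._make_causal_mask = patched_make_causal_mask
        # transformers >= 4.35.0 use _create_4d_causal_attention_mask instead
        if hasattr(modeling_clip, "_create_4d_causal_attention_mask"):
            modeling_clip._create_4d_causal_attention_mask = patched_make_causal_mask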