Question About register_cross_attention_hook and replace_call_method_for_sd3 in Attention Map Visualization #14

Open
Passenger12138 opened this issue Dec 27, 2024 · 2 comments

@Passenger12138

Hello,

First of all, thank you for your excellent work on visualizing attention maps for DiT (Diffusion Transformer). I am currently extending your approach to visualize attention maps for video-based DiT models.

While going through the source code, I encountered the following snippet:

```python
pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
pipeline.transformer = replace_call_method_for_sd3(pipeline.transformer)
```

I understand that the register_cross_attention_hook function is used to define a hook to capture the attention map during the forward pass. However, I am confused about the necessity of the second line, replace_call_method_for_sd3.

From my understanding, the second line replaces the forward method for SD3Transformer2DModel and its submodules. However, I noticed that the code does not seem to define a custom forward process for SD3Transformer2DModel, and it appears that the original attention computation is already sufficient.

Could you please explain:

  1. Why is replace_call_method_for_sd3 necessary in this context?
  2. If the forward process is not altered, what specific purpose does this replacement serve?

Any clarification or suggestions on this would be greatly appreciated. Thank you again for your work and support!

Best regards,

@wooyeolBaek
Owner

@Passenger12138
It’s interesting that you’re applying this to video models! I’d appreciate it if you could share it with me once it’s completed.
The reason I use replace_call_method_for_sd3 is to modify the attention operation and to reshape the attention map back to its proper size. In the original module, q, k, and v go in and only the attention result comes out, so the attention map itself is never exposed. To obtain it, the module has to be redefined so that it returns the attention map as well. In addition, restoring the flattened attention map to its original 2D layout requires the height, because the token count alone does not determine the height and width. To pass this information down to where the attention operation takes place, I replaced the call method. If there's a way to simplify this structure, I'd appreciate any suggestions.
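To make the two points in this reply concrete, here is a minimal, framework-free sketch in plain Python (not the repository's code): an attention function that returns its probabilities alongside the output, and a helper that restores a flattened map to a grid given the height. The names `attention_with_map` and `unflatten`, the row-major flattening, and the toy 2 x 3 patch grid are all hypothetical.

```python
import math


def softmax(row):
    # Numerically stable softmax over one row of scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_with_map(q, k, v):
    """Scaled dot-product attention that returns (output, probs).

    q: Nq query vectors, k: Nk key vectors, v: Nk value vectors (lists of floats).
    A fused attention call returns only `output`; also returning `probs`
    is what makes the attention map available to a hook.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    probs = [
        softmax([scale * sum(a * b for a, b in zip(qrow, krow)) for krow in k])
        for qrow in q
    ]
    dim_v = len(v[0])
    output = [
        [sum(p * vrow[j] for p, vrow in zip(prow, v)) for j in range(dim_v)]
        for prow in probs
    ]
    return output, probs


def unflatten(scores, height):
    """Reshape a row-major flattened map of height*width scores into a height x width grid."""
    if height <= 0 or len(scores) % height != 0:
        raise ValueError(f"{len(scores)} scores cannot be split into {height} rows")
    width = len(scores) // height
    return [scores[r * width:(r + 1) * width] for r in range(height)]


# Hypothetical example: one query (think: a text token) attending to
# 6 image tokens that came from a 2 x 3 patch grid, flattened row-major.
q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]] * 3
v = [[float(i)] for i in range(6)]
output, probs = attention_with_map(q, k, v)
grid = unflatten(probs[0], height=2)  # 2 rows x 3 columns
```

Without `height`, the same 6 scores could just as well be read as a 3 x 2 grid, which is why the height has to reach the place where the map is computed.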

@Passenger12138
Author

I have implemented basic functionality for visualizing the attention maps of video diffusion models (VDMs) such as CogVideo. My first version of the code is available at [my repository](https://github.com/Passenger12138/attention-map-diffusers-vdm.git).

Currently, the implementation uses hook functions to record the intermediate attention maps from each block. However, because of the 3D attention mechanism in CogVideo, each block's attention map spans the tokens of every frame and grows quadratically with the token count, so the maps are extremely large and cause out-of-memory (OOM) errors.
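A back-of-the-envelope sketch of why this runs out of memory: a full attention map over N tokens holds N x N entries per head. The helper names and the sizes below (13 frames of 30 x 45 patches, 30 heads, 2 bytes per element for fp16/bf16) are hypothetical illustration values, not numbers taken from CogVideo's configuration.

```python
def video_token_count(frames, patches_h, patches_w):
    """Tokens in a 3D (spatio-temporal) attention over frames x patches_h x patches_w patches."""
    return frames * patches_h * patches_w


def attention_map_bytes(num_tokens, num_heads=1, batch=1, bytes_per_elem=2):
    """Bytes needed for one block's attention maps of shape
    (batch, num_heads, num_tokens, num_tokens); bytes_per_elem=2 assumes fp16/bf16."""
    return batch * num_heads * num_tokens * num_tokens * bytes_per_elem


# Hypothetical sizes for illustration only (not from CogVideo's config).
tokens = video_token_count(frames=13, patches_h=30, patches_w=45)
per_block = attention_map_bytes(tokens, num_heads=30)
print(f"{tokens} tokens -> {per_block / 2**30:.1f} GiB per block per step")
# prints: 17550 tokens -> 17.2 GiB per block per step
```

Because the size scales with the square of the token count, doubling the number of frames quadruples the memory per map, and keeping one such map for every block at every denoising step multiplies it further.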

To address this, my current approach saves the intermediate results to disk and loads them during visualization. While functional, this method is too time-consuming for practical use.

In the next version, I plan to optimize this process by recording only the average attention maps instead of saving every intermediate map, so the final visualization step can work directly from the averaged maps. I am currently developing this version, but it may take some time because of my daily work commitments.
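The averaging plan can be sketched as a running sum plus a count per key, so memory stays at one map per block no matter how many denoising steps are recorded. This is plain Python under the assumption that each hook call hands over a flat list of floats for a given block; the class `RunningAttentionMean` and the block names are hypothetical.

```python
class RunningAttentionMean:
    """Accumulates a running mean of attention maps per key (e.g. per block),
    storing only one sum and one count per key instead of every map."""

    def __init__(self):
        self._sums = {}
        self._counts = {}

    def update(self, key, attn_map):
        acc = self._sums.get(key)
        if acc is None:
            # First map for this key: copy it so later in-place adds
            # never modify the caller's list.
            self._sums[key] = list(attn_map)
            self._counts[key] = 1
            return
        if len(acc) != len(attn_map):
            raise ValueError(f"size mismatch for {key!r}: {len(acc)} vs {len(attn_map)}")
        for i, x in enumerate(attn_map):
            acc[i] += x
        self._counts[key] += 1

    def mean(self, key):
        n = self._counts[key]
        return [s / n for s in self._sums[key]]


# Hypothetical usage: two denoising steps recorded for one block.
tracker = RunningAttentionMean()
tracker.update("block_0", [0.25, 0.75])
tracker.update("block_0", [0.75, 0.25])
print(tracker.mean("block_0"))  # [0.5, 0.5]
```

With tensors, the same pattern becomes an in-place `+=` into a preallocated accumulator, so no per-step maps need to be written to disk.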

Once I complete the new version, I would love to contribute by merging it into your repository. I will keep you updated on my progress.
