Skip to content

Commit

Permalink
Implement HooksMixin
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <[email protected]>
  • Loading branch information
kylesayrs committed Nov 14, 2024
1 parent 93832a6 commit e1c589e
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 0 deletions.
64 changes: 64 additions & 0 deletions src/llmcompressor/modifiers/utils/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import contextlib
from functools import wraps
from typing import Any, Callable, ClassVar, List

import torch
from pydantic import BaseModel
from torch.utils.hooks import RemovableHandle

__all__ = ["HooksMixin"]


class HooksMixin(BaseModel):
"""
Mixin to manage hook registration, disabling, and removal.
Modifiers should use `self.register_hook(module, hook, hook_type)`
for hook registration and `self.remove_hooks()` for removal.
Modifiers which implement hooks should register them using
`self.register_..._hook(module, hook)` rather than the usual
`module.register_..._hook(hook)`. Modifiers should remove hooks with
`self.remove_hooks()`
Lifecycle:
- modifier.register_forward_hook(module, hook)
- with HooksMixin.disable_hooks(): model.forward()
- modifier.remove_hooks()
"""

_HOOKS_DISABLED: ClassVar[bool] = False # attached to global HooksMixin
_hooks: List[RemovableHandle] = [] # attached to local subclasses

@classmethod
@contextlib.contextmanager
def disable_hooks(cls):
"""Disable all hooks across all modifiers"""
try:
cls._HOOKS_DISABLED = True
yield
finally:
cls._HOOKS_DISABLED = False

def register_hook(
self,
module: torch.nn.Module,
func: Callable[[Any], Any],
hook_type: str,
**kwargs,
):
@wraps(func)
def wrapped_hook(*args, **kwargs):
if HooksMixin._HOOKS_DISABLED:
return

return func(*args, **kwargs)

handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs)
self._hooks.append(handle)

def remove_hooks(self):
"""
Remove all hooks belonging to a modifier
"""
for hook in self._hooks:
hook.remove()
81 changes: 81 additions & 0 deletions tests/llmcompressor/modifiers/utils/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch

from llmcompressor.modifiers.utils.hooks import HooksMixin


class DummyModel(torch.nn.Module):
def __init__(self):
super(DummyModel, self).__init__()

self.linear1 = torch.nn.Linear(1, 2)
self.linear2 = torch.nn.Linear(2, 3)
self.linear3 = torch.nn.Linear(3, 1)
self.dummy_inputs = torch.tensor([0.0])

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)

return x


class DummyMod(HooksMixin):
hook_called: bool = False

def hook(self, *args, **kwargs):
self.hook_called = True


class ModA(DummyMod):
pass


class ModB(DummyMod):
pass


def test_register_hook():
model = DummyModel()

mod_a = ModA()
mod_a.register_hook(model.linear1, mod_a.hook, "forward")

mod_b = ModB()
mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")

model(model.dummy_inputs)
assert mod_a.hook_called and mod_b.hook_called


def test_remove_hooks():
model = DummyModel()

mod_a = ModA()
mod_a.register_hook(model.linear1, mod_a.hook, "forward")

mod_b = ModB()
mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")
mod_b.remove_hooks()

model(model.dummy_inputs)
assert mod_a.hook_called and not mod_b.hook_called


def test_disable_hooks():
model = DummyModel()

mod_a = ModA()
mod_a.register_hook(model.linear1, mod_a.hook, "forward")

mod_b = ModB()
mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")

with HooksMixin.disable_hooks():
model(model.dummy_inputs)
assert not mod_a.hook_called and not mod_b.hook_called

mod_a.hook_called = False
mod_b.hook_called = False
model(model.dummy_inputs)
assert mod_a.hook_called and mod_b.hook_called

0 comments on commit e1c589e

Please sign in to comment.