Signed-off-by: Kyle Sayers <[email protected]>
Showing 2 changed files with 145 additions and 0 deletions.
@@ -0,0 +1,64 @@

import contextlib
from functools import wraps
from typing import Any, Callable, ClassVar, List

import torch
from pydantic import BaseModel
from torch.utils.hooks import RemovableHandle

__all__ = ["HooksMixin"]


class HooksMixin(BaseModel):
    """
    Mixin to manage the registration, disabling, and removal of PyTorch hooks.

    Modifiers that implement hooks should register them with
    `self.register_hook(module, hook, hook_type)` rather than the usual
    `module.register_{hook_type}_hook(hook)`, and should remove them with
    `self.remove_hooks()`.

    Lifecycle:
    - modifier.register_hook(module, hook, hook_type)
    - with HooksMixin.disable_hooks(): model.forward()
    - modifier.remove_hooks()
    """

    _HOOKS_DISABLED: ClassVar[bool] = False  # shared across all HooksMixin subclasses
    _hooks: List[RemovableHandle] = []  # per-modifier list of registered handles

    @classmethod
    @contextlib.contextmanager
    def disable_hooks(cls):
        """Disable all hooks across all modifiers"""
        try:
            cls._HOOKS_DISABLED = True
            yield
        finally:
            cls._HOOKS_DISABLED = False

    def register_hook(
        self,
        module: torch.nn.Module,
        func: Callable[[Any], Any],
        hook_type: str,
        **kwargs,
    ):
        # wrap the hook so it becomes a no-op while hooks are globally disabled
        @wraps(func)
        def wrapped_hook(*args, **kwargs):
            if HooksMixin._HOOKS_DISABLED:
                return

            return func(*args, **kwargs)

        # resolve e.g. hook_type="forward" to module.register_forward_hook;
        # extra kwargs are forwarded to the underlying PyTorch registration call
        handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs)
        self._hooks.append(handle)

    def remove_hooks(self):
        """Remove all hooks belonging to a modifier"""
        for hook in self._hooks:
            hook.remove()
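To make the intended usage concrete, here is a minimal sketch of a modifier built on this mixin. The `ActivationCounter` class, its `attach` method, and its counting logic are illustrative only and not part of this commit:

import torch

from llmcompressor.modifiers.utils.hooks import HooksMixin


class ActivationCounter(HooksMixin):
    """Hypothetical modifier that counts forward passes through a module."""

    forward_count: int = 0

    def attach(self, module: torch.nn.Module):
        # the counting hook becomes a no-op inside HooksMixin.disable_hooks()
        self.register_hook(module, self._count, "forward")

    def _count(self, module, args, output):
        self.forward_count += 1


linear = torch.nn.Linear(4, 4)
counter = ActivationCounter()
counter.attach(linear)

linear(torch.randn(1, 4))
assert counter.forward_count == 1

with HooksMixin.disable_hooks():
    linear(torch.randn(1, 4))  # hook is skipped, count unchanged
assert counter.forward_count == 1

counter.remove_hooks()  # detach before discarding the modifier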
@@ -0,0 +1,81 @@

import torch

from llmcompressor.modifiers.utils.hooks import HooksMixin


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.linear1 = torch.nn.Linear(1, 2)
        self.linear2 = torch.nn.Linear(2, 3)
        self.linear3 = torch.nn.Linear(3, 1)
        self.dummy_inputs = torch.tensor([0.0])

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)

        return x


class DummyMod(HooksMixin):
    hook_called: bool = False

    def hook(self, *args, **kwargs):
        self.hook_called = True


class ModA(DummyMod):
    pass


class ModB(DummyMod):
    pass


def test_register_hook():
    model = DummyModel()

    mod_a = ModA()
    mod_a.register_hook(model.linear1, mod_a.hook, "forward")

    mod_b = ModB()
    mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")

    # both hooks fire during a normal forward pass
    model(model.dummy_inputs)
    assert mod_a.hook_called and mod_b.hook_called


def test_remove_hooks():
    model = DummyModel()

    mod_a = ModA()
    mod_a.register_hook(model.linear1, mod_a.hook, "forward")

    mod_b = ModB()
    mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")
    mod_b.remove_hooks()

    # removal affects only the modifier whose hooks were removed
    model(model.dummy_inputs)
    assert mod_a.hook_called and not mod_b.hook_called


def test_disable_hooks():
    model = DummyModel()

    mod_a = ModA()
    mod_a.register_hook(model.linear1, mod_a.hook, "forward")

    mod_b = ModB()
    mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")

    # disable_hooks() silences hooks from all modifiers at once
    with HooksMixin.disable_hooks():
        model(model.dummy_inputs)
    assert not mod_a.hook_called and not mod_b.hook_called

    # hooks fire again once the context exits
    mod_a.hook_called = False
    mod_b.hook_called = False
    model(model.dummy_inputs)
    assert mod_a.hook_called and mod_b.hook_called
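One detail the tests do not exercise: `register_hook` forwards its `**kwargs` to the underlying `register_{hook_type}_hook` call, so standard PyTorch options such as `with_kwargs=True` pass straight through. A short sketch under that assumption; the `KwargLogger` modifier and its `log_call` hook are hypothetical names, not part of this commit:

import torch

from llmcompressor.modifiers.utils.hooks import HooksMixin


class KwargLogger(HooksMixin):
    """Hypothetical modifier demonstrating kwargs passthrough."""

    def log_call(self, module, args, kwargs, output):
        # with_kwargs=True changes the forward-hook signature to
        # hook(module, args, kwargs, output)
        print(f"{type(module).__name__} called with kwargs: {list(kwargs)}")


model = torch.nn.Linear(2, 2)
logger = KwargLogger()

# with_kwargs=True is forwarded to module.register_forward_hook
logger.register_hook(model, logger.log_call, "forward", with_kwargs=True)
model(torch.randn(1, 2))
logger.remove_hooks()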