Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Nov 15, 2024
1 parent e1c589e commit ec59d6c
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/llmcompressor/modifiers/utils/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,33 @@ def disable_hooks(cls):
def register_hook(
self,
module: torch.nn.Module,
func: Callable[[Any], Any],
hook: Callable[[Any], Any],
hook_type: str,
**kwargs,
):
@wraps(func)
"""
Registers a hook on a specified module with the option to disable it with
HooksMixin.disable_hooks
:param module: the module on which the hook should be registered
:param hook: the hook to register
:param hook_type: the type of hook to register corresponding to the
`register_{hook_type}_hook` attribute on torch.nn.Module.
Ex. "forward", "forward_pre", "full_backward", "state_dict_post"
:param kwargs: keyword arguments to pass to register hook method
"""

@wraps(hook)
def wrapped_hook(*args, **kwargs):
if HooksMixin._HOOKS_DISABLED:
return

return func(*args, **kwargs)
return hook(*args, **kwargs)

handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs)
self._hooks.append(handle)

def remove_hooks(self):
"""
Remove all hooks belonging to a modifier
"""
"""Remove all hooks belonging to a modifier"""
for hook in self._hooks:
hook.remove()

0 comments on commit ec59d6c

Please sign in to comment.