Add per-sample gradient norm computation as a functionality (#724)
Summary:

Per-sample gradient norms are already computed for Ghost Clipping, but they can be useful more generally. This diff exposes them as a public property.


```
...

loss.backward()
per_sample_norms = model.per_sample_gradient_norms

```
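The elided setup above follows the fast gradient clipping workflow. A minimal end-to-end sketch of how the new property could be read (not part of this commit; `DPOptimizerFastGradientClipping` and `DPLossFastGradientClipping` come from the wider Opacus library, and all layer sizes and hyperparameter values below are illustrative assumptions):

```
# A minimal sketch, not part of this commit: per-sample norms are populated
# during the clipping backward pass of the fast gradient clipping workflow.
# Layer sizes and hyperparameters below are illustrative only.
import torch
import torch.nn as nn

from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
    GradSampleModuleFastGradientClipping,
)
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

model = GradSampleModuleFastGradientClipping(
    nn.Linear(16, 2), max_grad_norm=1.0, use_ghost_clipping=True
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(model.parameters(), lr=0.1),
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    expected_batch_size=8,
)
criterion = DPLossFastGradientClipping(model, optimizer, nn.CrossEntropyLoss())

x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
loss = criterion(model(x), y)
loss.backward()  # triggers per-sample gradient norm computation via the hooks

per_sample_norms = model.per_sample_gradient_norms  # one norm per example, not privatized
print(per_sample_norms.shape)  # torch.Size([8])
```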

Reviewed By: iden-kalemaj

Differential Revision: D68634969
EnayatUllah authored and facebook-github-bot committed Feb 6, 2025
1 parent e4eb3fb commit 8359416
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -120,6 +120,7 @@ def __init__(
self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
self.max_grad_norm = max_grad_norm
self.use_ghost_clipping = use_ghost_clipping
self._per_sample_gradient_norms = None

def get_clipping_coef(self) -> torch.Tensor:
"""Get per-example gradient scaling factor for clipping."""
@@ -131,6 +132,7 @@ def get_norm_sample(self) -> torch.Tensor:
norm_sample = torch.stack(
[param._norm_sample for param in self.trainable_parameters], dim=0
).norm(2, dim=0)
self.per_sample_gradient_norms = norm_sample
return norm_sample

def capture_activations_hook(
@@ -231,3 +233,17 @@ def capture_backprops_hook(
if len(module.activations) == 0:
if hasattr(module, "max_batch_len"):
del module.max_batch_len

@property
def per_sample_gradient_norms(self) -> torch.Tensor:
"""Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings settings"""
if self._per_sample_gradient_norms is not None:
return self._per_sample_gradient_norms
else:
raise AttributeError(
"per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
)

@per_sample_gradient_norms.setter
def per_sample_gradient_norms(self, value):
self._per_sample_gradient_norms = value
