Commit
Fix clipping bound scaling by using the reciprocal of the number of devices explicitly in PercoreClippedDpSgdGradient.

PiperOrigin-RevId: 588481110
The paxml Authors committed Dec 6, 2023
1 parent cfc5dfb commit fbf3705
Showing 1 changed file with 41 additions and 8 deletions.
49 changes: 41 additions & 8 deletions paxml/sgf.py
@@ -132,6 +132,18 @@ class PercoreClippedDpSgdGradient(BaseStochasticGradient):
l2_norm_clip: The L2 clipping bound used to clip per-core gradients.
noise_multiplier: The noise multiplier used to decide the noise scale. See
Section 5.3.2 of https://arxiv.org/pdf/2303.00654.pdf for more details.
clipping_bound_scaling: The scaling applied to the clipping bound to match
the gradient scaling. When running under pmap, losses and gradients are
scaled by aux.loss_weight on each TPU core, so the clipping bound should be
scaled accordingly. Legal values are: 1) `None` (default): set
clipping_bound_scaling to 1. / jax.device_count() under pmap, and raise an
error otherwise. 2) 'use_aux_loss_weight': set clipping_bound_scaling to
aux.loss_weight. Note that for models like CtcModel, aux.loss_weight differs
across TPU cores, which can give a weaker privacy guarantee than expected
under differential privacy. We strongly recommend using this option only for
empirical privacy; if it is used for differential privacy, extra care must
be taken when accounting for the privacy budget. 3) `float`: set
clipping_bound_scaling to this value manually.
normalize_gradients: Whether to apply Gradient Normalization as implemented
in eqn 3 of https://arxiv.org/abs/2204.13650 to reduce the dependence
between clipping value and learning rate. Note that normalization is only
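A minimal configuration sketch (illustrative, not part of this commit) of the three legal values of clipping_bound_scaling described in the docstring above. Whether the class is built directly or through a Fiddle/pax_fiddle config depends on the paxml setup, so the keyword arguments are the point here; the numeric values are assumptions chosen for illustration.

from paxml import sgf

# 1) None (default): resolved to 1. / jax.device_count() under pmap;
#    grad_fn raises an error when not running under pmap.
dp_default = sgf.PercoreClippedDpSgdGradient(
    l2_norm_clip=1.0, noise_multiplier=0.5)

# 2) 'use_aux_loss_weight': uses aux.loss_weight, which may differ across
#    cores; recommended for empirical privacy only.
dp_aux = sgf.PercoreClippedDpSgdGradient(
    l2_norm_clip=1.0, noise_multiplier=0.5,
    clipping_bound_scaling='use_aux_loss_weight')

# 3) An explicit float, e.g. the reciprocal of the number of data-parallel
#    replicas.
dp_manual = sgf.PercoreClippedDpSgdGradient(
    l2_norm_clip=1.0, noise_multiplier=0.5,
    clipping_bound_scaling=1.0 / 8)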
@@ -144,6 +156,7 @@ class PercoreClippedDpSgdGradient(BaseStochasticGradient):

l2_norm_clip: float = 0.0
noise_multiplier: float = 0.0
clipping_bound_scaling: float | str | None = None
normalize_gradients: bool = False
adaptive_clipping_method: str | None = None

@@ -152,7 +165,10 @@ def _clip_gradients(
) -> tuple[NestedMap, jax.Array, Any]:
assert (
self.adaptive_clipping_method is not None or self.l2_norm_clip > 0.0
), f'Clipping bound must be positive. {l2_norm_clip} is provided.'
), (
f'Clipping bound must be either adaptive or positive. {l2_norm_clip} is'
' provided.'
)

# Clip the per-core mean gradient.
grads_flat, grads_treedef = jax.tree_util.tree_flatten(grads)
@@ -168,7 +184,7 @@ def _add_noise( # pytype: disable=annotation-type-mismatch # jax-ndarray
self,
grads: NestedMap,
noise_stddev: float,
loss_weight: float,
clipping_bound_scaling: float,
prng_key: PRNGKey = None,
) -> NestedMap:
prng_keys = jax.random.split(
@@ -179,14 +195,14 @@ def _add_noise( # pytype: disable=annotation-type-mismatch # jax-ndarray
)

if base_layer.is_running_under_pmap():
# Note: when running under pmap, loss_weight is set to 1/num_devices.
# In this case, the *global* batch size is batch_size / loss_weight.
# Note: Because `l2_norm_clip` is scaled by `clipping_bound_scaling`, we
# need to scale `noise_stddev` accordingly to compensate for the change.
# Moreover, each device adds independent Gaussian noise, and the noisy
# gradients are then summed with `psum`. Because the sum of `num_devices`
# independent Gaussian noise draws is equivalent to a single Gaussian with
# std scaled by `sqrt(num_devices)`, we further scale `noise_stddev` on
# each device to correct for this.
noise_stddev *= loss_weight * jnp.sqrt(loss_weight)
noise_stddev *= clipping_bound_scaling * jnp.sqrt(clipping_bound_scaling)

def _add_noise_to_array(x, prng):
return x + noise_stddev * jax.random.normal(prng, shape=x.shape)
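An illustrative check of the comment above (not part of the commit): with the default clipping_bound_scaling = 1 / num_devices, the per-device factor scaling * sqrt(scaling) makes the psum of the independently drawn noises have std equal to noise_multiplier * (scaling * l2_norm_clip), i.e. noise_multiplier times the scaled clipping bound. A small simulation, with all constants chosen for illustration:

import jax
import jax.numpy as jnp

num_devices = 8
scaling = 1.0 / num_devices  # default clipping_bound_scaling under pmap
noise_multiplier, l2_norm_clip = 1.0, 1.0

base_std = noise_multiplier * l2_norm_clip
per_device_std = base_std * scaling * jnp.sqrt(scaling)

# Each device draws independent noise; psum adds the num_devices copies.
keys = jax.random.split(jax.random.PRNGKey(0), num_devices)
noise = jnp.stack(
    [per_device_std * jax.random.normal(k, (1_000_000,)) for k in keys])
summed = noise.sum(axis=0)

# The summed noise std matches noise_multiplier * (scaling * l2_norm_clip).
print(summed.std(), base_std * scaling)  # both ~0.125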
@@ -214,11 +230,26 @@ def grad_fn(
)
aux = self.process_aux_info(aux)

if self.clipping_bound_scaling is None:
if base_layer.is_running_under_pmap():
self.clipping_bound_scaling = 1.0 / jax.device_count()
else:
raise ValueError(
'clipping_bound_scaling must be set explicitly when '
'not running under pmap.'
)
elif self.clipping_bound_scaling == 'use_aux_loss_weight':
self.clipping_bound_scaling = aux.loss_weight
elif not isinstance(self.clipping_bound_scaling, float):
raise ValueError(
f'Unsupported clipping_bound_scaling: {self.clipping_bound_scaling}'
)

if self.adaptive_clipping_method == 'min':
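# 'min': adaptively set the clipping bound to the smallest per-core gradient
# norm across the pmap axis; the division by clipping_bound_scaling below
# compensates for the multiplication by the same factor in _clip_gradients.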
grads_norm = optax.global_norm(grads)
self.l2_norm_clip = (
jax.lax.pmin(grads_norm, axis_name=PMAP_PARALLEL_AXIS_NAME)
/ aux.loss_weight
/ self.clipping_bound_scaling
)
elif self.adaptive_clipping_method is not None:
raise ValueError(
@@ -227,7 +258,7 @@ def grad_fn(
)

grads, num_clipped, grad_norm = self._clip_gradients(
grads, aux.loss_weight * self.l2_norm_clip
grads, self.clipping_bound_scaling * self.l2_norm_clip
)

if self.normalize_gradients:
@@ -238,7 +269,9 @@ def grad_fn(
noise_stddev = self.noise_multiplier * (
1.0 if self.normalize_gradients else self.l2_norm_clip
)
grads = self._add_noise(grads, noise_stddev, aux.loss_weight, prng_key)
grads = self._add_noise(
grads, noise_stddev, self.clipping_bound_scaling, prng_key
)

return (
values,
