diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..1336f3c80 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -1318,7 +1318,7 @@ def distributed_shampoo( ### inverse_failure_threshold=0.1, moving_average_for_momentum=True, - skip_preconditioning_dim_size_gt=0, + skip_preconditioning_dim_size_gt=4096, clip_by_scaled_gradient_norm=None, precision=lax.Precision.HIGHEST, tensordot_precision=None,