
Commit 9e29436
Add more tests for OptaxOptimizer.
PiperOrigin-RevId: 623282400
laurentes authored and pax authors committed Apr 9, 2024
1 parent 0cc8f10 commit 9e29436
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions paxml/learners_test.py
@@ -851,6 +851,24 @@ def _get_raw_grad_transformation(self, lr):
    self.assertEqual(pspec_3[opt_idx].c.repeat_prefix, [2, 2])
    self.assertEqual(pspec_3[opt_idx].c.repeat_prefix_split_dims_mapping,
                     [('data', 'mdl'), None])

    for idx in [0, 1, 3]:
      pspec_1_count = typing.cast(NestedMap, pspec_1[idx]).count
      pspec_2_count = typing.cast(NestedMap, pspec_2[idx]).count
      pspec_3_count = typing.cast(NestedMap, pspec_3[idx]).count
      self.assertEqual(pspec_1_count.shape, [])
      self.assertEqual(pspec_1_count.repeat_prefix, [2])
      self.assertEqual(pspec_1_count.repeat_prefix_split_dims_mapping, [-1])
      self.assertEqual(pspec_2_count.shape, [])
      self.assertEqual(pspec_2_count.repeat_prefix, [])
      self.assertEqual(pspec_2_count.repeat_prefix_split_dims_mapping, [])
      self.assertEqual(pspec_3_count.shape, [])
      self.assertEqual(pspec_3_count.repeat_prefix, [2, 2])
      self.assertEqual(
          pspec_3_count.repeat_prefix_split_dims_mapping,
          [('data', 'mdl'), None],
      )

    logging.info(f'Prefix vectorization partition spec .. {partition_spec} ')
    state = grad_tx.init(variables)
    logging.info('Prefix vectorization state after init .. ')
@@ -946,6 +964,14 @@ def _opt_update(updates, state, params):
        [('data', 'mdl'), None],
    )

    for idx in [0, 1, 3]:
      pspec_1_count = typing.cast(NestedMap, pspec_1[idx]).count
      pspec_2_count = typing.cast(NestedMap, pspec_2[idx]).count
      pspec_3_count = typing.cast(NestedMap, pspec_3[idx]).count
      self.assertEqual(pspec_1_count.shape, (2,))
      self.assertEqual(pspec_2_count.shape, ())
      self.assertEqual(pspec_3_count.shape, (2, 2))

    logging.info('Prefix vectorization state after init .. ')
    # Computed update is 0 + state, and state is sum of each variable.
    update, state = grad_tx.update(
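For readers unfamiliar with the `count` slot asserted on above, the following is a minimal stand-alone optax sketch (not part of this commit; the parameter tree and learning rate are illustrative assumptions) showing where that scalar counter state comes from before pax's prefix vectorization attaches repeat prefixes such as [2] or [2, 2]:

    # Minimal sketch, assuming a toy parameter tree and learning rate.
    import jax.numpy as jnp
    import optax

    params = {'w': jnp.zeros((4, 8))}    # illustrative parameter tree
    tx = optax.adam(learning_rate=1e-3)  # any optax optimizer with a step counter
    state = tx.init(params)

    # The first element of Adam's state is ScaleByAdamState, whose `count`
    # field is a scalar int32 array, i.e. shape (). The new assertions in the
    # diff check that this scalar keeps an empty shape spec and only gains
    # repeat prefixes like (2,) or (2, 2) when the optimizer state is
    # vectorized over repeated variables.
    print(state[0].count.shape)  # ()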
