Skip to content

Commit

Permalink
If sharding is not None (that's passed to convert_element_type), only…
Browse files Browse the repository at this point in the history
… compare it with operand's sharding if the sharding is concrete. Otherwise doing `getattr(operand, 'sharding')` on a `Tracer` leads to weird timeouts.

PiperOrigin-RevId: 723595960
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Feb 5, 2025
1 parent f43d2b6 commit 0fb278a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,8 @@ def _convert_element_type(
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
isinstance(operand, Array) and
not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and
(sharding is None or getattr(operand, 'sharding', None) == sharding)):
(sharding is None or
(sharding._is_concrete and getattr(operand, 'sharding', None) == sharding))):
return operand
else:
return convert_element_type_p.bind(
Expand Down

0 comments on commit 0fb278a

Please sign in to comment.