Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 21, 2024
1 parent 26f74a5 commit f057def
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,18 +179,18 @@ def abstract(

# Infer output shape
if batched_output:
assert lhs_aval.ndim > 2 and rhs_aval.ndim == 2, (
"Batched output requires batched LHS and non-batched RHS operands."
)
assert (
lhs_aval.ndim > 2 and rhs_aval.ndim == 2
), "Batched output requires batched LHS and non-batched RHS operands."
out_shape = (
*lhs_batch_shape,
lhs_aval.shape[lhs_outer_dim],
rhs_aval.shape[rhs_outer_dim]
rhs_aval.shape[rhs_outer_dim],
)
else:
assert lhs_aval.ndim == rhs_aval.ndim, (
"Non-batched output requires LHS and RHS operands with same number of dimensions."
)
assert (
lhs_aval.ndim == rhs_aval.ndim
), "Non-batched output requires LHS and RHS operands with same number of dimensions."
if lhs_aval.ndim > 2:
rhs_bdims = [
dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim]
Expand Down

0 comments on commit f057def

Please sign in to comment.