Support Half/BFloat16 in op_allclose
Pull Request resolved: pytorch#7766

We incorrectly required tensors of these types to be bitwise-identical rather than merely close: without dedicated branches, Half and BFloat16 fell through to the generic memcmp fallback.

(I had to develop this internally because the op_allclose_test doesn't run in OSS.)
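To illustrate the bug, here is a minimal standalone C++ sketch (using float as a stand-in for Half/BFloat16; the tolerances are the standard allclose defaults rtol=1e-5, atol=1e-8, not values taken from this file). Two values can be within tolerance of each other yet differ bitwise, so a memcmp-based comparison wrongly reports them as not close:

#include <cmath>
#include <cstdio>
#include <cstring>

int main() {
  float a = 1.0f;
  float b = 1.000001f; // within rtol=1e-5, atol=1e-8 of a
  // Bitwise comparison, as the old fallback path effectively did:
  bool bitwise_equal = std::memcmp(&a, &b, sizeof a) == 0;
  // Tolerance-based comparison, as allclose is supposed to do:
  bool close = std::fabs(a - b) <= 1e-8 + 1e-5 * std::fabs(b);
  std::printf("bitwise_equal=%d close=%d\n", bitwise_equal, close); // prints 0 1
}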

Differential Revision: [D68366831](https://our.internmc.facebook.com/intern/diff/D68366831/)
ghstack-source-id: 262600586

Co-authored-by: Scott Wolchok <[email protected]>
2 people authored and Zonglin Peng committed Jan 30, 2025
1 parent 589aacc commit fcfa618
Showing 2 changed files with 88 additions and 220 deletions.
14 changes: 14 additions & 0 deletions kernels/portable/cpu/op_allclose.cpp
@@ -84,6 +84,20 @@ bool tensors_are_close(
         a.numel(),
         rtol,
         atol);
+  } else if (a.scalar_type() == ScalarType::Half) {
+    return data_is_close<Half>(
+        a.const_data_ptr<Half>(),
+        b.const_data_ptr<Half>(),
+        a.numel(),
+        rtol,
+        atol);
+  } else if (a.scalar_type() == ScalarType::BFloat16) {
+    return data_is_close<BFloat16>(
+        a.const_data_ptr<BFloat16>(),
+        b.const_data_ptr<BFloat16>(),
+        a.numel(),
+        rtol,
+        atol);
   } else {
     // Non-floating-point types can be compared bitwise.
     return memcmp(a.mutable_data_ptr(), b.mutable_data_ptr(), a.nbytes()) == 0;
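For context, a data_is_close-style helper typically performs an element-wise tolerance check. The sketch below is hypothetical, not the actual implementation defined earlier in op_allclose.cpp; it assumes T (e.g. Half/BFloat16) converts to double, and it treats matching NaNs as close, which in the real op would depend on an equal_nan-style flag:

#include <cmath>
#include <cstddef>

// Hypothetical sketch of an element-wise closeness check; the real
// data_is_close may differ (e.g. in NaN handling).
template <typename T>
bool data_is_close_sketch(
    const T* a,
    const T* b,
    size_t numel,
    double rtol,
    double atol) {
  for (size_t i = 0; i < numel; ++i) {
    // Assumes T converts to double (Half/BFloat16 wrappers typically do).
    double x = static_cast<double>(a[i]);
    double y = static_cast<double>(b[i]);
    if (std::isnan(x) && std::isnan(y)) {
      continue; // assumption: matching NaNs count as close
    }
    if (std::fabs(x - y) > atol + rtol * std::fabs(y)) {
      return false;
    }
  }
  return true;
}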