Skip to content

Commit

Permalink
Some python binding improvements (#775)
Browse files Browse the repository at this point in the history
* Enable user-defined rtol for result checking

* View bfloat16 data as np.int16, as ml_dtypes.bfloat16 isn't supported by xrt
  • Loading branch information
erwei-xilinx authored Nov 15, 2024
1 parent c60d7bd commit f01244c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
5 changes: 5 additions & 0 deletions python/air/backend/xrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import pyxrt as xrt
import os

from ml_dtypes import bfloat16


class XRTCompileArtifact:
"""A class encompassing information on the artifacts produced by compilation for the NPU/XRT"""
Expand Down Expand Up @@ -190,6 +192,9 @@ def invoker(*args):

self.bo_instr.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
for i, a in enumerate(args):
if a.dtype == bfloat16:
# store bfloat16 in binary as int16
a = a.view(np.int16)
bos[i].write(a, 0)
bos[i].sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)

Expand Down
18 changes: 13 additions & 5 deletions python/air/backend/xrt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def run_test(
mlir_module: np.ndarray,
inputs: List[np.ndarray],
expected_outputs: List[np.ndarray],
rtol: float = 1e-3,
):
if self.verbose:
print("Running module: ")
Expand All @@ -93,7 +94,7 @@ def run_test(
# Remove input slots from the received outputs
actual_outputs = actual_outputs[len(inputs) :]

if self._check_outputs(actual_outputs, expected_outputs):
if self._check_outputs(actual_outputs, expected_outputs, rtol=rtol):
print("PASS!")
return_code = 0
else:
Expand All @@ -102,7 +103,10 @@ def run_test(
return return_code

def _check_outputs(
self, actual_outputs: List[np.ndarray], expected_outputs: List[np.ndarray]
self,
actual_outputs: List[np.ndarray],
expected_outputs: List[np.ndarray],
rtol: float = 1e-3,
):
assert len(actual_outputs) == len(
expected_outputs
Expand All @@ -125,16 +129,20 @@ def _check_outputs(
print(actual)

if expected.dtype in [np.float16, np.float32, np.float64, bfloat16]:
if not np.allclose(actual, expected, rtol=1e-3):
if not np.allclose(actual, expected, rtol=rtol):
print(f"ERROR: Output {i} does not meet expected output.")
print(actual)
print("Expected: ")
print(expected)
print("Actual: ")
print(actual)
return False
else:
if not np.array_equal(actual, expected):
print(f"ERROR: Output {i} does not meet expected output.")
print(actual)
print("Expected: ")
print(expected)
print("Actual: ")
print(actual)
return False

return True

0 comments on commit f01244c

Please sign in to comment.