Skip to content

Commit

Permalink
test: add some test kernels for tensor_scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
govereau committed Feb 13, 2025
1 parent 3e17315 commit c8dbe11
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 7 deletions.
14 changes: 9 additions & 5 deletions KLR/Trace/Tensor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def Expr.inferTensor : Expr -> Err (TensorName × List Index)
match <- inferTensor t with
| (t, [.ellipsis]) => return (t, ix)
| _ => throw "unsupported tensor expression"
| _ => throw "expecting tensor"
| _ => throw "expecting tensor expression"

def Term.inferTensor : Term -> Err (TensorName × List Index)
| .expr e (.tensor _ _) => Expr.inferTensor e
| .expr e _ => Expr.inferTensor e
| _ => throw "expecting tensor"

-- This only handles the simple cases
Expand Down Expand Up @@ -84,10 +84,14 @@ def store_expr (tag : String)
(dtype : Dtype) (memory : Memory) (src : Term)
: TraceM Term := do
match src with
| .expr e (.tensor _ shape) => do
| .expr e (.tensor dtype shape) => do
let dst <- declare tag dtype shape memory
return .store dst [.ellipsis] e
| _ => throw "expecting tensor"
| .expr e _ => do
let shape <- Expr.inferShape e
let dst <- declare tag dtype shape memory
return .store dst [.ellipsis] e
| _ => throw "expecting tensor in store"

-- APIs

Expand Down Expand Up @@ -188,7 +192,7 @@ def store : GlobalFn :=
fun
| [.expr dst (.tensor _ s₁), .expr src (.tensor _ s₂)] => do
if s₁ != s₂ then
throw "incompatible shapes {s₁} {s₂}"
throw s!"incompatible shapes {s₁} {s₂}"
let (t₁, i₁) <- Expr.inferTensor dst
let (t₂, i₂) <- Expr.inferTensor src
let src := Expr.access (.tensor t₂) i₂
Expand Down
4 changes: 2 additions & 2 deletions KLR/Trace/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def Term.format : Term -> Lean.Format
| .list l => .text s!"list<{l.length}>"
| .ellipsis => .text "ellipsis"
| .slice a b c => .text s!"slice({a},{b},{c})"
| .store t ix e => Lean.format (Stmt.store t ix e)
| .expr e ty => Lean.format e ++ ":" ++ repr ty
| .store t ix e => repr (Stmt.store t ix e)
| .expr e ty => repr e ++ ":" ++ repr ty

instance : Repr Term where reprPrec b _ := b.format

Expand Down
60 changes: 60 additions & 0 deletions interop/test/test_nki_isa_tensor_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
Copyright (C) 2025, Amazon.com. All Rights Reserved
"""
import unittest

import nki.isa as nisa
import nki.language as nl
import numpy as np

"""
Unit tests for tensor_scalar.
If using these tests for development, you can generate the NKI.json from
the top-level like so:
PYTHONPATH=interop:interop/test ./bin/gather test_nki_isa_tensor_scalar.kernel1 > kernel1.json
and then, e.g.
lake exe klr trace kernel1.json
lake exe klr compile kernel1.json
"""

# utility function - allocate memory in DRAM
def alloc_like(t):
return nl.ndarray(t.shape, dtype=t.dtype, buffer=nl.shared_hbm)

# utility function - allocate memory in DRAM and copy SBUF tile to it
def dram_tile(a):
b = alloc_like(a)
nl.store(b, a)
return b

# test kernel 1 : t - 1.0 with no access pattern
def kernel1(a):
a_tile = nl.load(a)
b_tile = nisa.tensor_scalar(a_tile, np.subtract, 1.0)
return dram_tile(b_tile)

# test kernel 2 : t - 1.0 with ellipsis access pattern
def kernel2(a):
a_tile = nl.load(a[...])
b_tile = nisa.tensor_scalar(a_tile, np.subtract, 1.0)
return dram_tile(b_tile)

# test kernel 2 : t - 1.0 with simple tile access pattern
def kernel3(a):
a_tile = nl.load(a[0:128,0:512])
b_tile = nisa.tensor_scalar(a_tile, np.subtract, 1.0)
return dram_tile(b_tile)

# The above example will fail tracing with:
# nl.store(b, b_tile)
# ^-- incompatible shapes [10, 10] [128, 512]
# This is because inferArguments is very dumb.
# You can use the kernel below for testing to get proper arguments.
def kernel3b():
a = nl.ndarray((128,512), dtype="float32", buffer=nl.shared_hbm)
return kernel3(a)

0 comments on commit c8dbe11

Please sign in to comment.