feat: track tensor allocations during tracing
This patch adds the ability to track the tensor allocations needed during the
tracing process. To do this, a new `store` statement and a new `operator` term
expression are added. A more complete implementation of the tensor_scalar
operator is included as a test case for the new mechanism.
govereau committed Feb 12, 2025
1 parent f9b53cd commit 40ec18b
Showing 15 changed files with 665 additions and 92 deletions.
1 change: 1 addition & 0 deletions KLR.lean
@@ -5,3 +5,4 @@ Authors: Paul Govereau, Sean McLaughlin
-/
import KLR.Core
import KLR.Python
import KLR.Trace
5 changes: 4 additions & 1 deletion KLR/Core.lean
@@ -4,4 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau, Sean McLaughlin
-/
import KLR.Core.Basic
import KLR.Core.Encode
-- TODO: fix encoder
--import KLR.Core.Encode
import KLR.Core.FromToJson
import KLR.Core.Pretty
130 changes: 108 additions & 22 deletions KLR/Core/Basic.lean
@@ -3,47 +3,88 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau, Sean McLaughlin
-/
import Lean

/-!
# Abstract syntax of Core NKL language
This language is the result of "tracing", and is used as the
portable format, a.k.a. Kernel Language Representation (KLR).
-/

namespace KLR.Core

-- TODO switch to TensorLib's version of these types
--export TensorLib (Tensor Dtype Shape)
-- Compute Engines

abbrev Dtype := String
abbrev Shape := List Int
inductive Engine where
| unassigned
| pool
| act
| pe
| dma
| dve
| sp
deriving BEq, Repr

-- ALU operations
-- TODO organize these into groups
inductive AluOp where
| abs
| add
| arith_shift_left
| arith_shift_right
| average
| bitwise_and
| bitwise_not
| bitwise_or
| bitwise_xor
| bypass
| divide
| elemwise_mul
| is_equal
| is_ge
| is_gt
| is_le
| is_lt
| logical_and
| logical_or
| logical_shift_left
| logical_shift_right
| logical_xor
| max
| min
| mod
| mult
| not_equal
| pow
| rsqrt
| subtract
deriving BEq, Repr

/-
A TensorName is essentially a typed variable, where the type
must be a tensor type. When we flush out Typ below we may replace
this with `Expr.var name (Typ.tensor dtype shape)`. For now, this
only refers to dynamic tensors, or compile-time tensors, not
trace-time tensors.
A TensorName is essentially a typed variable, where the type must be a
tensor type. This only refers to dynamic tensors, or compile-time
tensors, not trace-time tensors.
-/
abbrev Dtype := String
abbrev Shape := List Nat

inductive Memory where
| dram | sbuf | pmem
deriving Repr, BEq

structure TensorName where
name : String
dtype : Dtype
shape : Shape
deriving Repr, BEq, Lean.ToJson

-- TODO
inductive Typ where
memory: Memory
deriving Repr, BEq
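
-- As a sketch, a value of the new TensorName (with the memory field this
-- commit adds) can be constructed as below; the name, dtype, shape, and
-- memory chosen here are illustrative only:
def exampleTile : TensorName :=
  { name := "t0", dtype := "float32", shape := [128, 512], memory := .sbuf }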

inductive Const where
| none
| bool (value : Bool)
| int (value : Int)
| float (value : Float)
| string (value : String)
deriving Repr, BEq, Lean.ToJson
deriving Repr, BEq

-- This corresponds to the "Quasi-Affine Expressions" in Neuron.
-- Note, `floor` is the usual integer division.
Expand All @@ -56,27 +97,72 @@ inductive IndexExpr where
| floor (expr : IndexExpr) (scalar : Int)
| ceil (expr : IndexExpr) (scalar : Int)
| mod (expr : IndexExpr) (scalar : Int)
deriving Repr, BEq, Lean.ToJson
deriving Repr, BEq

-- Note: `np.newaxis` is represented as `(.coord none)`
inductive Index where
| ellipsis
| coord (e : Option IndexExpr)
| slice (l u step : Option IndexExpr)
deriving Repr, BEq, Lean.ToJson
deriving Repr, BEq
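
-- Sketch: one plausible encoding of the access `t[0:64:2, i, ...]`,
-- assuming IndexExpr has `int` and `var` constructors (its full list of
-- constructors is collapsed in this diff):
def exampleAccess : List Index :=
  [ .slice (some (.int 0)) (some (.int 64)) (some (.int 2)),
    .coord (some (.var "i")),
    .ellipsis ]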

structure TensorScalar where
op0 : AluOp
const0 : Float
reverse0 : Bool
op1 : AluOp
const1 : Float
reverse1 : Bool
deriving Repr, BEq

inductive Operator where
| tensorScalar : TensorScalar -> Operator
deriving Repr, BEq
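
-- Reading TensorScalar as the fused elementwise computation
-- (x op0 const0) op1 const1, with the reverse flags swapping operand
-- order, an operator for x * 2.0 + 1.0 would look like the sketch below.
-- This reading is ours, not confirmed by the commit:
def mulAddOp : Operator :=
  .tensorScalar {
    op0 := .mult, const0 := 2.0, reverse0 := false,
    op1 := .add,  const1 := 1.0, reverse1 := false
  }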

inductive Expr where
| var (x : String)
| const (c : Const)
| tensor (t : TensorName)
| access (t : Expr) (ix : List Index)
| operator (op : Operator)
| call (f : Expr) (args : List Expr) (kwargs : List (String × Expr))
deriving Repr, BEq, Lean.ToJson
deriving Repr, BEq

inductive Stmt where
| pass
| expr (v : Expr)
| ret (v : Expr)
| store (t : TensorName) (ix : List Index) (e : Expr)
| assign (x : String) (e : Expr)
| loop (x : String) (l u step : IndexExpr) (body : List Stmt)
deriving Repr, BEq, Lean.ToJson
deriving Repr, BEq
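
-- Sketch of the new store statement in use, writing an operator result
-- into out[0:64, :]. It assumes the IndexExpr `int` constructor noted
-- above, and that an operator is applied to its arguments via Expr.call;
-- neither assumption is confirmed by this commit:
def exampleStore (out x : TensorName) (ts : TensorScalar) : Stmt :=
  .store out
    [ .slice (some (.int 0)) (some (.int 64)) none,
      .slice none none none ]
    (.call (.operator (.tensorScalar ts)) [.tensor x] [])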

structure Kernel where
name : String
inputs : List TensorName
outputs : List TensorName
body : List Stmt
deriving Repr, BEq

-- TODO: not efficient
partial def Expr.tensors : Expr -> List TensorName :=
tensors []
where
tensors (l : List TensorName) : Expr -> List TensorName
| .var _ => l
| .const _ => l
| .tensor t => l.insert t
| .access t _ => tensors l t
| .operator _ => l
| .call f args kwargs =>
let l := tensors l f
let l := args.foldl tensors l
kwargs.foldl (fun l kw => tensors l kw.snd) l

partial def Stmt.tensors : Stmt -> List TensorName
| .pass => []
| .expr e => e.tensors
| .ret e => e.tensors
| .store t _ e => e.tensors.insert t
| .assign _ e => e.tensors
| .loop _ _ _ _ body => (body.map tensors).flatten.eraseDups

def Kernel.internal (k : Kernel) : List TensorName :=
let ts := (k.body.map Stmt.tensors).flatten.eraseDups
(ts.removeAll k.inputs).removeAll k.outputs
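
-- For example, in the sketch below a kernel routes its input through a
-- scratch tensor, and Kernel.internal reports exactly that scratch
-- allocation, which is what the tracer needs to allocate. The TensorName
-- arguments are assumed distinct and are illustrative:
def scratchKernel (x y tmp : TensorName) : Kernel :=
  { name := "copy_via_tmp",
    inputs := [x],
    outputs := [y],
    body := [ .store tmp [] (.tensor x), .store y [] (.tensor tmp) ] }

-- (scratchKernel x y tmp).internal = [tmp]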
88 changes: 88 additions & 0 deletions KLR/Core/FromToJson.lean
@@ -0,0 +1,88 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau, Sean McLaughlin
-/
import Lean
import KLR.Core.Basic

/-
# Instances of ToJson and FromJson for KLR
These instances are placed in a separate module because some of them must
be written or replaced by hand, and doing so could interfere with other
code.
-/

namespace KLR.Core
open Lean (FromJson ToJson)

/-
The tools we interact with use a different encoding of infinity and NaN
than the default instances in Lean.
-/

instance : ToJson Float where
toJson f :=
match Lean.JsonNumber.fromFloat? f with
| .inr n => .num n
| .inl "NaN" => .str "nan"
| .inl "Infinity" => .str "inf"
| .inl "-Infinity" => .str "-inf"
| _ => panic "internal error"

instance : FromJson Float where
fromJson?
| .str "inf" => return (1.0 / 0.0)
| .str "-inf" => return (-1.0 / 0.0)
| .str "nan" => return (0.0 / 0.0)
| .num jn => return jn.toFloat
| _ => throw "Expected a number or 'inf, '-inf, 'nan."
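
-- A sketch of the intended round-trip behavior; the expected results in
-- the comments are our reading of the instances above:
#eval ToJson.toJson (1.0 / 0.0 : Float)                        -- "inf"
#eval ToJson.toJson (0.0 / 0.0 : Float)                        -- "nan"
#eval (FromJson.fromJson? (.str "-inf") : Except String Float) -- ok -inf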

instance : ToJson Engine where
toJson
| .unassigned => .str "Unassigned"
| .pool => .str "Pool"
| .act => .str "Activation"
| .pe => .str "PE"
| .dma => .str "DMA"
| .dve => .str "DVE"
| .sp => .str "SP"

instance : FromJson Engine where
fromJson?
| .str "Unassigned" => return .unassigned
| .str "Pool" => return .pool
| .str "Activation" => return .act
| .str "PE" => return .pe
| .str "DMA" => return .dma
| .str "DVE" => return .dve
| .str "SP" => return .sp
| .str s => throw s!"unknown engine type {s}"
| _ => throw "expecting engine type"

deriving instance ToJson for AluOp
deriving instance ToJson for Memory
deriving instance ToJson for TensorName
deriving instance ToJson for Const
deriving instance ToJson for IndexExpr
deriving instance ToJson for Index

deriving instance FromJson for AluOp
deriving instance FromJson for Memory
deriving instance FromJson for TensorName
deriving instance FromJson for Const
deriving instance FromJson for IndexExpr
deriving instance FromJson for Index

deriving instance ToJson for TensorScalar
deriving instance ToJson for Operator
deriving instance ToJson for Expr
deriving instance ToJson for Stmt
deriving instance ToJson for Kernel

deriving instance FromJson for TensorScalar
deriving instance FromJson for Operator
deriving instance FromJson for Expr
deriving instance FromJson for Stmt
deriving instance FromJson for Kernel
42 changes: 39 additions & 3 deletions KLR/Core/Pretty.lean
@@ -21,8 +21,17 @@ private def abracket (f : Format) : Format :=
private def ppArgs [ToFormat a] (l : List a) : Format :=
Format.joinSep l ","

def ppMemory : Memory -> Format
| .dram => "dram"
| .sbuf => "sbuf"
| .pmem => "pmem"

def ppTensor (t : TensorName) : Format :=
"%" ++ t.name ++ abracket (t.dtype ++ ":" ++ ppArgs t.shape)
t.name ++ abracket (.joinSep [
format t.dtype,
.paren (.joinSep t.shape ","),
ppMemory t.memory
] ",")

def ppConst : Const -> Format
| .none => "None"
@@ -53,28 +62,55 @@ def ppIndexExpr? : Option IndexExpr -> Format
def ppIndex : Index -> Format
| .ellipsis => "..."
| .coord e => ppIndexExpr? e
| .slice none none none => ":"
| .slice none u none => "0:" ++ ppIndexExpr? u
| .slice s u none => .joinSep ([s,u].map ppIndexExpr?) ":"
| .slice l u s => .joinSep ([l,u,s].map ppIndexExpr?) ":"

private def ppList (f : a -> Format) : List a -> Format
| [] => .nil
| x :: xs => .append (f x) (ppList f xs)

partial def ppExpr : Expr -> Format
| .var x => x
| .const c => ppConst c
| .tensor t => ppTensor t
| .access t ix => .fill (ppExpr t ++ .sbracket (.joinSep (ix.map ppIndex) ","))
| .operator _ => "operator"
| .call f args kwargs =>
let args := args.map ppExpr
let kwargs := kwargs.map fun (x,e) => x ++ "=" ++ ppExpr e
.fill (ppExpr f ++ .paren (ppArgs (args ++ kwargs)))

def ppStmt : Stmt -> Format
| .pass => "pass"
| .expr e => ppExpr e
| .ret e => "ret" ++ ppExpr e
| .store t ix e => ppExpr (.access (.tensor t) ix) ++ " := " ++ ppExpr e
| .assign x e => x ++ " = " ++ ppExpr e
| .loop _ _ _ _ _ => "<loop>"

def ppFullTensor (t : TensorName) : Format :=
t.name ++ abracket (.joinSep [
format t.dtype,
.paren (.joinSep t.shape ","),
ppMemory t.memory
] ",")

def lines (l : List Format) := Format.joinSep l "\n"
def nest_lines (l : List Format) := Format.nest 2 (.align true ++ lines l)

def ppKernel (k : Kernel) : Format :=
lines [
Format.text k.name,
"inputs:", nest_lines (k.inputs.map ppFullTensor),
"outputs:", nest_lines (k.outputs.map ppFullTensor),
"internal:", nest_lines (k.internal.map ppFullTensor),
"body:", nest_lines (k.body.map ppStmt)
]
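
-- With the definitions above, a printed kernel should look roughly like
-- the following sketch; the tensor names and shapes are illustrative:
--   copy_via_tmp
--   inputs:
--     x<float32,(128,512),sbuf>
--   outputs:
--     y<float32,(128,512),sbuf>
--   internal:
--     tmp<float32,(128,512),sbuf>
--   body:
--     ...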

instance : ToFormat TensorName where format := ppTensor
instance : ToFormat Const where format := ppConst
instance : ToFormat IndexExpr where format := ppIndexExpr 0
instance : ToFormat Index where format := ppIndex
instance : ToFormat Expr where format := ppExpr
instance : ToFormat Stmt where format := ppStmt
instance : ToFormat Kernel where format := ppKernel
1 change: 1 addition & 0 deletions KLR/Python.lean
@@ -189,6 +189,7 @@ structure Kernel where
args : List Expr'
kwargs : List (String × Expr')
globals : List (String × Expr')
deriving Repr

/-
POC: try to guess suitable arguments if none supplied (see bin/gather).
10 changes: 5 additions & 5 deletions KLR/Trace.lean
@@ -10,11 +10,11 @@ import KLR.Trace.Basic
import KLR.Trace.Builtin
import KLR.Trace.Python
import KLR.Trace.NKI
import KLR.Trace.Numpy

namespace KLR.Trace

def runNKIKernel (k : KLR.Python.Kernel) : Err (List KLR.Core.Stmt) :=
tracer ⟨ .ofList NKIEnv, #[] ⟩ do
traceKernel k
let g <- get
return g.body.toList
def globalEnv := NKIEnv ++ NumpyEnv

def runNKIKernel (k : KLR.Python.Kernel) : Err KLR.Core.Kernel :=
tracer ⟨ .ofList globalEnv, #[] ⟩ (traceKernel k)
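
-- Sketch of a hypothetical driver using the new signature; it assumes
-- Err supports do-notation and uses ppKernel from KLR.Core.Pretty:
def compileNKI (pk : KLR.Python.Kernel) : Err Lean.Format := do
  let k <- runNKIKernel pk
  return KLR.Core.ppKernel k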