From b94f3578b6933e4a4b086f9dd05e304311beed1a Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Mon, 13 Jan 2025 14:18:06 -0500 Subject: [PATCH] feat: simple pretty-printer for KLR This change adds two related things: a pretty-printer for KLR terms, and tensor names. Tensor names make the pretty printing nicer, but have a second purpose. By naming all of the tensors, we can scan a KLR kernel to collect up all of the input, output, and intermediate tensors that will be needed to run the kernel. For argument tensors, the generated tensor names are changed to the argument variable names; this is just for readability. --- NKL/FFI.lean | 4 ++- NKL/KLR/Basic.lean | 22 +++++++----- NKL/KLR/Encode.lean | 6 ++-- NKL/KLR/Pretty.lean | 80 +++++++++++++++++++++++++++++++++++++++++++ NKL/Trace/Python.lean | 13 +++++-- NKL/Trace/Tensor.lean | 10 +++--- 6 files changed, 115 insertions(+), 20 deletions(-) create mode 100644 NKL/KLR/Pretty.lean diff --git a/NKL/FFI.lean b/NKL/FFI.lean index 764d939..d8e7fe1 100644 --- a/NKL/FFI.lean +++ b/NKL/FFI.lean @@ -4,10 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ import Lean +import NKL.KLR.Pretty import NKL.Python import NKL.Trace namespace NKL +open NKL.KLR local instance : MonadLift (Except String) IO where monadLift @@ -19,4 +21,4 @@ def parse_json (s : String) : IO Unit := do let kernel <- Python.Parsing.parse s let stmts <- NKL.Trace.runNKIKernel kernel for s in stmts do - IO.println s!"{repr s}" + IO.println (" " ++ Lean.format s) --s!"{s}\n{repr s}" diff --git a/NKL/KLR/Basic.lean b/NKL/KLR/Basic.lean index 7537cd4..fab5989 100644 --- a/NKL/KLR/Basic.lean +++ b/NKL/KLR/Basic.lean @@ -14,16 +14,22 @@ portable format, a.k.a. Kernel Language Representation (KLR). namespace NKL.KLR --- TODO switch to tensor lib +-- TODO switch to TensorLib's version of these types --export TensorLib (Tensor Dtype Shape) --- Mostly, NKL deals with empty tensors, so just check dtype and shape --- TODO: talk to Sean about a more general BEq for Tensor ---instance : BEq Tensor where --- beq t₁ t₂ := t₁.dtype == t₂.dtype && t₁.shape == t₂.shape abbrev Dtype := String abbrev Shape := List Int -structure Tensor where + +/- +A TensorName is essentially a typed variable, where the type +must be a tensor type. When we flush out Typ below we may replace +this with `Expr.var name (Typ.tensor dtype shape)`. For now, this +only refers to dynamic tensors, or compile-time tensors, not +trace-time tensors. +-/ + +structure TensorName where + name : String dtype : Dtype shape : Shape deriving Repr, BEq @@ -71,7 +77,7 @@ def toInt : Const -> Except String Int end Const --- This correspondes to the "Quasi-Affine Expressions" in Neuron. +-- This corresponds to the "Quasi-Affine Expressions" in Neuron. -- Note, `floor` is the usual integer division. inductive IndexExpr where | var (name : String) @@ -94,7 +100,7 @@ inductive Index where inductive Expr where | var (x : String) | const (c : Const) - | tensor (t : Tensor) + | tensor (t : TensorName) | access (t : Expr) (ix : List Index) | call (f : Expr) (args : List Expr) (kwargs : List (String × Expr)) deriving Repr, BEq diff --git a/NKL/KLR/Encode.lean b/NKL/KLR/Encode.lean index e1cc148..c5336a1 100644 --- a/NKL/KLR/Encode.lean +++ b/NKL/KLR/Encode.lean @@ -265,7 +265,7 @@ private def chkIndex (i : Index) : Bool := partial def encExpr : Expr -> ByteArray | .var s => tag 0x30 [encString s] - | .tensor t => tag 0x31 [encString t.dtype, encList encInt t.shape] + | .tensor t => tag 0x31 [encString t.name, encString t.dtype, encList encInt t.shape] | .const c => tag 0x32 [encConst c] | .access e ix => tag 0x33 [encExpr e, encList encIndex ix] | .call f ax kw => tag 0x34 [encExpr f, encList encExpr ax, encList encKeyword kw] @@ -276,7 +276,7 @@ where partial def decExpr : DecodeM Expr := do match (<- next) with | 0x30 => return .var (<- decString) - | 0x31 => return .tensor $ .mk (<- decString) (<- decList decInt) + | 0x31 => return .tensor $ .mk (<- decString) (<- decString) (<- decList decInt) | 0x32 => return .const (<- decConst) | 0x33 => return .access (<- decExpr) (<- decList decIndex) | 0x34 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword) @@ -293,7 +293,7 @@ private def ixz := Index.coord (IndexExpr.int 0) #guard chkExpr nil #guard chkExpr (.var "var") -#guard chkExpr (.tensor $ .mk "float32" [1,2,3]) +#guard chkExpr (.tensor $ .mk "t" "float32" [1,2,3]) #guard chkExpr (.const (.int 1)) #guard chkExpr (.access nil [ixz, ixz, ixz]) #guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)]) diff --git a/NKL/KLR/Pretty.lean b/NKL/KLR/Pretty.lean new file mode 100644 index 0000000..9e2e43a --- /dev/null +++ b/NKL/KLR/Pretty.lean @@ -0,0 +1,80 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +import NKL.KLR.Basic + +namespace NKL.KLR +open Std + +/- +This is a simple pretty printer for KLR terms. At some point, we may want to +make this output valid python syntax that would parse and elaborate to the same +KLR kernel. At the moment, there are too many unknowns to spend time on this. +The format here is just for ease of debugging, feel free to modify as you wish. +-/ + +private def abracket (f : Format) : Format := + Format.bracket "<" f ">" + +private def ppArgs [ToFormat a] (l : List a) : Format := + Format.joinSep l "," + +def ppTensor (t : TensorName) : Format := + "%" ++ t.name ++ abracket (t.dtype ++ ":" ++ ppArgs t.shape) + +def ppConst : Const -> Format + | .none => "None" + | .bool true => "True" + | .bool false => "False" + | .int i => format i + | .float f => format f + | .string s => "\"" ++ s.push '"' + +private def addParens : Nat -> Format -> Format + | 0, f => f + | _, f => f.paren + +def ppIndexExpr (n : Nat) : IndexExpr -> Format + | .var x => x + | .int i => format i + | .neg e => "-" ++ ppIndexExpr (n+1) e + | .add l r => addParens n $ ppIndexExpr 1 l ++ "+" ++ ppIndexExpr 1 r + | .mul i e => addParens n $ format i ++ "*" ++ ppIndexExpr 1 e + | .floor e i => addParens n $ ppIndexExpr 1 e ++ "/" ++ format i + | .ceil e i => "ceil" ++ Format.paren (ppIndexExpr 0 e ++","++ format i) + | .mod e i => addParens n $ ppIndexExpr 1 e ++ "%" ++ format i + +def ppIndexExpr? : Option IndexExpr -> Format + | none => "None" + | some e => ppIndexExpr 0 e + +def ppIndex : Index -> Format + | .ellipsis => "..." + | .coord e => ppIndexExpr? e + | .slice l u s => .joinSep ([l,u,s].map ppIndexExpr?) ":" + +partial def ppExpr : Expr -> Format + | .var x => x + | .const c => ppConst c + | .tensor t => ppTensor t + | .access t ix => .fill (ppExpr t ++ .sbracket (.joinSep (ix.map ppIndex) ",")) + | .call f args kwargs => + let args := args.map ppExpr + let kwargs := kwargs.map fun (x,e) => x ++ "=" ++ ppExpr e + .fill (ppExpr f ++ .paren (ppArgs (args ++ kwargs))) + +def ppStmt : Stmt -> Format + | .pass => "pass" + | .expr e => ppExpr e + | .ret e => "ret" ++ ppExpr e + | .assign x e => x ++ " = " ++ ppExpr e + | .loop _ _ _ _ _ => "" + +instance : ToFormat TensorName where format := ppTensor +instance : ToFormat Const where format := ppConst +instance : ToFormat IndexExpr where format := ppIndexExpr 0 +instance : ToFormat Index where format := ppIndex +instance : ToFormat Expr where format := ppExpr +instance : ToFormat Stmt where format := ppStmt diff --git a/NKL/Trace/Python.lean b/NKL/Trace/Python.lean index 87162ad..47667e9 100644 --- a/NKL/Trace/Python.lean +++ b/NKL/Trace/Python.lean @@ -116,7 +116,8 @@ partial def expr' : Expr' -> Tracer Item | .const c => return .term (<- const c) | .tensor s dty => do let shape <- s.mapM integer - return .term (.expr (.tensor ⟨ dty, shape ⟩) (.tensor dty shape)) + let name <- genName "t".toName + return .term (.expr (.tensor ⟨ name.toString, dty, shape ⟩) (.tensor dty shape)) | .name id _ => lookup_item id.toName | .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id) | .tuple l _ => return .term (.tuple (<- l.mapM term)) @@ -190,6 +191,7 @@ partial def stmt' : Stmt' -> Tracer Unit partial def bind_args (f : Fun) (args : List Term) (kwargs : List (String × Term)) + (rename : Bool := false) : Tracer (List (String × Term)) := do if f.args.vararg != none || f.args.kwarg != none then throw "var args not supported" @@ -208,7 +210,13 @@ partial def bind_args (f : Fun) return (x, <- term' e) else throw s!"argument {x} not supplied" + -- rename tensors if asked to + let argmap := if rename then argmap.map renameTensors else argmap return argmap +where + renameTensors : String × Term -> String × Term + | (s, .expr (.tensor t) ty) => (s, .expr (.tensor {t with name := s}) ty) + | other => other -- For a function call, first evaluate the argument in the current environment. -- Then enter a new environment and evaluate the function statements. @@ -216,8 +224,7 @@ partial def function_call (f : Fun) (args : List Term) (kwargs : List (String × Term)) : Tracer Unit := do - let args <- bind_args f args kwargs - --let args <- args.mapM fun (x,e) => return (x, e) + let args <- bind_args f args kwargs (rename:=true) withSrc f.source $ enterFun $ do args.forM fun (x,e) => do extend x.toName e f.body.forM stmt diff --git a/NKL/Trace/Tensor.lean b/NKL/Trace/Tensor.lean index c361a33..f9ce83c 100644 --- a/NKL/Trace/Tensor.lean +++ b/NKL/Trace/Tensor.lean @@ -23,21 +23,21 @@ private def tensor_call (op : String) (args : List Expr) : Term := -- Unary operations on tensors -def tensor_op (op : String) (t : Tensor) : TraceM Term := +def tensor_op (op : String) (t : TensorName) : TraceM Term := return tensor_call op [.tensor t] -- Binary operations on tensors / scalars -def tensor_tensor (op : String) (l r : Tensor) : TraceM Term := +def tensor_tensor (op : String) (l r : TensorName) : TraceM Term := return tensor_call op [.tensor l, .tensor r] -private def broadcast (t : Tensor) (c : Const) : Expr := +private def broadcast (t : TensorName) (c : Const) : Expr := let args := t.shape.map fun i => Expr.const (.int i) let args := .const c :: args .call (.var "broadcast") args [] -def tensor_scalar (op : String) (t : Tensor) (c : Const) : TraceM Term := +def tensor_scalar (op : String) (t : TensorName) (c : Const) : TraceM Term := return tensor_call op [ .tensor t, broadcast t c] -def scalar_tensor (op : String) (c : Const) (t : Tensor) : TraceM Term := +def scalar_tensor (op : String) (c : Const) (t : TensorName) : TraceM Term := return tensor_call op [ .tensor t, broadcast t c]