Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A simple pretty-printer for KLR #21

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.KLR.Pretty
import NKL.Python
import NKL.Trace

namespace NKL
open NKL.KLR

local instance : MonadLift (Except String) IO where
monadLift
Expand All @@ -19,4 +21,4 @@ def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let stmts <- NKL.Trace.runNKIKernel kernel
for s in stmts do
IO.println s!"{repr s}"
IO.println (" " ++ Lean.format s) --s!"{s}\n{repr s}"
22 changes: 14 additions & 8 deletions NKL/KLR/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@ portable format, a.k.a. Kernel Language Representation (KLR).

namespace NKL.KLR

-- TODO switch to tensor lib
-- TODO switch to TensorLib's version of these types
--export TensorLib (Tensor Dtype Shape)
-- Mostly, NKL deals with empty tensors, so just check dtype and shape
-- TODO: talk to Sean about a more general BEq for Tensor
--instance : BEq Tensor where
-- beq t₁ t₂ := t₁.dtype == t₂.dtype && t₁.shape == t₂.shape

abbrev Dtype := String
abbrev Shape := List Int
structure Tensor where

/-
A TensorName is essentially a typed variable, where the type
must be a tensor type. When we flush out Typ below we may replace
this with `Expr.var name (Typ.tensor dtype shape)`. For now, this
only refers to dynamic tensors, or compile-time tensors, not
trace-time tensors.
-/

structure TensorName where
name : String
govereau marked this conversation as resolved.
Show resolved Hide resolved
dtype : Dtype
shape : Shape
deriving Repr, BEq
Expand Down Expand Up @@ -71,7 +77,7 @@ def toInt : Const -> Except String Int

end Const

-- This correspondes to the "Quasi-Affine Expressions" in Neuron.
-- This corresponds to the "Quasi-Affine Expressions" in Neuron.
-- Note, `floor` is the usual integer division.
inductive IndexExpr where
| var (name : String)
Expand All @@ -94,7 +100,7 @@ inductive Index where
inductive Expr where
| var (x : String)
| const (c : Const)
| tensor (t : Tensor)
| tensor (t : TensorName)
| access (t : Expr) (ix : List Index)
| call (f : Expr) (args : List Expr) (kwargs : List (String × Expr))
deriving Repr, BEq
Expand Down
6 changes: 3 additions & 3 deletions NKL/KLR/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ private def chkIndex (i : Index) : Bool :=

partial def encExpr : Expr -> ByteArray
| .var s => tag 0x30 [encString s]
| .tensor t => tag 0x31 [encString t.dtype, encList encInt t.shape]
| .tensor t => tag 0x31 [encString t.name, encString t.dtype, encList encInt t.shape]
| .const c => tag 0x32 [encConst c]
| .access e ix => tag 0x33 [encExpr e, encList encIndex ix]
| .call f ax kw => tag 0x34 [encExpr f, encList encExpr ax, encList encKeyword kw]
Expand All @@ -276,7 +276,7 @@ where
partial def decExpr : DecodeM Expr := do
match (<- next) with
| 0x30 => return .var (<- decString)
| 0x31 => return .tensor $ .mk (<- decString) (<- decList decInt)
| 0x31 => return .tensor $ .mk (<- decString) (<- decString) (<- decList decInt)
govereau marked this conversation as resolved.
Show resolved Hide resolved
| 0x32 => return .const (<- decConst)
| 0x33 => return .access (<- decExpr) (<- decList decIndex)
| 0x34 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword)
Expand All @@ -293,7 +293,7 @@ private def ixz := Index.coord (IndexExpr.int 0)

#guard chkExpr nil
#guard chkExpr (.var "var")
#guard chkExpr (.tensor $ .mk "float32" [1,2,3])
#guard chkExpr (.tensor $ .mk "t" "float32" [1,2,3])
#guard chkExpr (.const (.int 1))
#guard chkExpr (.access nil [ixz, ixz, ixz])
#guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)])
Expand Down
80 changes: 80 additions & 0 deletions NKL/KLR/Pretty.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR.Basic

namespace NKL.KLR
open Std

/-
This is a simple pretty printer for KLR terms. At some point, we may want to
make this output valid python syntax that would parse and elaborate to the same
KLR kernel. At the moment, there are too many unknowns to spend time on this.
The format here is just for ease of debugging, feel free to modify as you wish.
-/

private def abracket (f : Format) : Format :=
Format.bracket "<" f ">"

private def ppArgs [ToFormat a] (l : List a) : Format :=
Format.joinSep l ","

def ppTensor (t : TensorName) : Format :=
"%" ++ t.name ++ abracket (t.dtype ++ ":" ++ ppArgs t.shape)

def ppConst : Const -> Format
| .none => "None"
| .bool true => "True"
| .bool false => "False"
| .int i => format i
| .float f => format f
| .string s => "\"" ++ s.push '"'

private def addParens : Nat -> Format -> Format
| 0, f => f
| _, f => f.paren

def ppIndexExpr (n : Nat) : IndexExpr -> Format
| .var x => x
| .int i => format i
| .neg e => "-" ++ ppIndexExpr (n+1) e
| .add l r => addParens n $ ppIndexExpr 1 l ++ "+" ++ ppIndexExpr 1 r
| .mul i e => addParens n $ format i ++ "*" ++ ppIndexExpr 1 e
| .floor e i => addParens n $ ppIndexExpr 1 e ++ "/" ++ format i
| .ceil e i => "ceil" ++ Format.paren (ppIndexExpr 0 e ++","++ format i)
| .mod e i => addParens n $ ppIndexExpr 1 e ++ "%" ++ format i

def ppIndexExpr? : Option IndexExpr -> Format
| none => "None"
| some e => ppIndexExpr 0 e

def ppIndex : Index -> Format
| .ellipsis => "..."
| .coord e => ppIndexExpr? e
| .slice l u s => .joinSep ([l,u,s].map ppIndexExpr?) ":"

partial def ppExpr : Expr -> Format
| .var x => x
| .const c => ppConst c
| .tensor t => ppTensor t
| .access t ix => .fill (ppExpr t ++ .sbracket (.joinSep (ix.map ppIndex) ","))
| .call f args kwargs =>
let args := args.map ppExpr
let kwargs := kwargs.map fun (x,e) => x ++ "=" ++ ppExpr e
.fill (ppExpr f ++ .paren (ppArgs (args ++ kwargs)))

def ppStmt : Stmt -> Format
| .pass => "pass"
| .expr e => ppExpr e
| .ret e => "ret" ++ ppExpr e
| .assign x e => x ++ " = " ++ ppExpr e
| .loop _ _ _ _ _ => "<loop>"

instance : ToFormat TensorName where format := ppTensor
instance : ToFormat Const where format := ppConst
instance : ToFormat IndexExpr where format := ppIndexExpr 0
instance : ToFormat Index where format := ppIndex
instance : ToFormat Expr where format := ppExpr
instance : ToFormat Stmt where format := ppStmt
13 changes: 10 additions & 3 deletions NKL/Trace/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ partial def expr' : Expr' -> Tracer Item
| .const c => return .term (<- const c)
| .tensor s dty => do
let shape <- s.mapM integer
return .term (.expr (.tensor ⟨ dty, shape ⟩) (.tensor dty shape))
let name <- genName "t".toName
return .term (.expr (.tensor ⟨ name.toString, dty, shape ⟩) (.tensor dty shape))
| .name id _ => lookup_item id.toName
| .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id)
| .tuple l _ => return .term (.tuple (<- l.mapM term))
Expand Down Expand Up @@ -190,6 +191,7 @@ partial def stmt' : Stmt' -> Tracer Unit
partial def bind_args (f : Fun)
(args : List Term)
(kwargs : List (String × Term))
(rename : Bool := false)
: Tracer (List (String × Term)) := do
if f.args.vararg != none || f.args.kwarg != none then
throw "var args not supported"
Expand All @@ -208,16 +210,21 @@ partial def bind_args (f : Fun)
return (x, <- term' e)
else
throw s!"argument {x} not supplied"
-- rename tensors if asked to
let argmap := if rename then argmap.map renameTensors else argmap
return argmap
where
renameTensors : String × Term -> String × Term
| (s, .expr (.tensor t) ty) => (s, .expr (.tensor {t with name := s}) ty)
| other => other

-- For a function call, first evaluate the argument in the current environment.
-- Then enter a new environment and evaluate the function statements.
partial def function_call (f : Fun)
(args : List Term)
(kwargs : List (String × Term))
: Tracer Unit := do
let args <- bind_args f args kwargs
--let args <- args.mapM fun (x,e) => return (x, e)
let args <- bind_args f args kwargs (rename:=true)
withSrc f.source $ enterFun $ do
args.forM fun (x,e) => do extend x.toName e
f.body.forM stmt
Expand Down
10 changes: 5 additions & 5 deletions NKL/Trace/Tensor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ private def tensor_call (op : String) (args : List Expr) : Term :=

-- Unary operations on tensors

def tensor_op (op : String) (t : Tensor) : TraceM Term :=
def tensor_op (op : String) (t : TensorName) : TraceM Term :=
return tensor_call op [.tensor t]

-- Binary operations on tensors / scalars

def tensor_tensor (op : String) (l r : Tensor) : TraceM Term :=
def tensor_tensor (op : String) (l r : TensorName) : TraceM Term :=
return tensor_call op [.tensor l, .tensor r]

private def broadcast (t : Tensor) (c : Const) : Expr :=
private def broadcast (t : TensorName) (c : Const) : Expr :=
let args := t.shape.map fun i => Expr.const (.int i)
let args := .const c :: args
.call (.var "broadcast") args []

def tensor_scalar (op : String) (t : Tensor) (c : Const) : TraceM Term :=
def tensor_scalar (op : String) (t : TensorName) (c : Const) : TraceM Term :=
return tensor_call op [ .tensor t, broadcast t c]

def scalar_tensor (op : String) (c : Const) (t : Tensor) : TraceM Term :=
def scalar_tensor (op : String) (c : Const) (t : TensorName) : TraceM Term :=
return tensor_call op [ .tensor t, broadcast t c]
Loading