From 847f159d389f04c08330fc87d6c99f79a21426b8 Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Tue, 7 Jan 2025 21:11:27 -0500 Subject: [PATCH] refactor: cleanup KLR definitions Remove some unnecessary parts of KLR, and reorganize the source files. --- NKL.lean | 1 - NKL/KLR.lean | 104 +----------------------------------- NKL/KLR/Basic.lean | 108 ++++++++++++++++++++++++++++++++++++++ NKL/{ => KLR}/Encode.lean | 38 +++++--------- 4 files changed, 123 insertions(+), 128 deletions(-) create mode 100644 NKL/KLR/Basic.lean rename NKL/{ => KLR}/Encode.lean (89%) diff --git a/NKL.lean b/NKL.lean index baf227c..8d90796 100644 --- a/NKL.lean +++ b/NKL.lean @@ -3,7 +3,6 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ -import NKL.Encode import NKL.FFI import NKL.KLR import NKL.Python diff --git a/NKL/KLR.lean b/NKL/KLR.lean index ed50052..d5a0544 100644 --- a/NKL/KLR.lean +++ b/NKL/KLR.lean @@ -3,105 +3,5 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ - - -/-! -# Abstract syntax of Core NKL language - -This language is the result of "tracing", and is used as the -portable format, a.k.a. Kernel Language Representation (KLR). --/ - -namespace NKL.KLR - --- TODO -inductive Ty where - -inductive Const where - | none - | bool (value : Bool) - | int (value : Int) - | float (value : Float) - | string (value : String) - deriving Repr, BEq - -namespace Const - --- Python-like rules for conversion to boolean -def isTrue : Const -> Bool - | .none => false - | .bool b => b - | .int i => i != 0 - | .float f => f != 0.0 - | .string s => s != "" - --- Python-like rules for conversion to integer -def toInt : Const -> Except String Int - | .none => throw "none cannot be converted to an integer" - | .bool true => return 1 - | .bool false => return 0 - | .int i => return i - | .float f => - -- Python is a bit strange here, it truncates both - -- positive and negative numbers toward zero - if f < 0.0 then - return (Int.ofNat (Float.floor (-f)).toUInt64.toBitVec.toNat).neg - else - return Int.ofNat (Float.floor f).toUInt64.toBitVec.toNat - | .string s => - match s.toInt? with - | .none => throw s!"string {s} cannot be converted to an integer" - | .some i => return i - -end Const - -inductive IndexExpr where - | var (name : String) - | int (i : Int) - | neg (expr : IndexExpr) - | add (left right : IndexExpr) - | mul (scalar : Int) (expr : IndexExpr) - | floor (expr : IndexExpr) (scalar : Int) - | ceil (expr : IndexExpr) (scalar : Int) - | mod (expr : IndexExpr) (scalar : Int) - deriving Repr, BEq - -inductive Index where - | ellipsis - | coord (e : Option IndexExpr) - | range (l u step : Option IndexExpr) - deriving Repr, BEq - -inductive Expr where - | var (x : String) - | const (c : Const) - | tensor (name : String) (shape : List Int) - | tuple (xs : List Expr) - | list (xs : List Expr) - | access (t : Expr) (ix : List Index) - | binop (op : String) (left right : Expr) - | unop (op : String) (e : Expr) - | call (f : Expr) (args : List Expr) (keywords : List (String × Expr)) - deriving Repr, BEq - -namespace Expr - --- TODO: Just a place-holder for now -def toAffine : Expr -> Except String IndexExpr - | .var v => return .var v - | .const (.int i) => return .int i - | e => throw s!"toAffine unimp {repr e}" - --- TODO: Just a place-holder for now -def simplify : Expr -> Expr := - fun x => x - -end Expr - -inductive Stmt where - | pass - | expr (v : Expr) - | ret (v : Expr) - | assign (x : String) (e : Expr) - | loop (x : String) (l u step : IndexExpr) (body : List Stmt) - deriving Repr, BEq +import NKL.KLR.Basic +import NKL.KLR.Encode diff --git a/NKL/KLR/Basic.lean b/NKL/KLR/Basic.lean new file mode 100644 index 0000000..7537cd4 --- /dev/null +++ b/NKL/KLR/Basic.lean @@ -0,0 +1,108 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +--import TensorLib.Tensor + +/-! +# Abstract syntax of Core NKL language + +This language is the result of "tracing", and is used as the +portable format, a.k.a. Kernel Language Representation (KLR). +-/ + +namespace NKL.KLR + +-- TODO switch to tensor lib +--export TensorLib (Tensor Dtype Shape) +-- Mostly, NKL deals with empty tensors, so just check dtype and shape +-- TODO: talk to Sean about a more general BEq for Tensor +--instance : BEq Tensor where +-- beq t₁ t₂ := t₁.dtype == t₂.dtype && t₁.shape == t₂.shape + +abbrev Dtype := String +abbrev Shape := List Int +structure Tensor where + dtype : Dtype + shape : Shape + deriving Repr, BEq + +-- TODO +inductive Typ where + +inductive Const where + | none + | bool (value : Bool) + | int (value : Int) + | float (value : Float) + | string (value : String) + deriving Repr, BEq + +namespace Const + +-- Python-like rules for conversion to boolean +def isTrue : Const -> Bool + | .none => false + | .bool b => b + | .int i => i != 0 + | .float f => f != 0.0 + | .string s => s != "" + +-- Python-like rules for conversion to integer +def toInt : Const -> Except String Int + | .none => throw "none cannot be converted to an integer" + | .bool true => return 1 + | .bool false => return 0 + | .int i => return i + | .float f => + -- Python is a bit strange here, it truncates both + -- positive and negative numbers toward zero + if f < 0.0 then + return (Int.ofNat (Float.floor (-f)).toUInt64.toNat).neg + else + return Int.ofNat (Float.floor f).toUInt64.toNat + | .string s => + -- Fortunately, Lean's String.toInt appears to be compatible + -- with Python's int(string) conversion. + match s.toInt? with + | .none => throw s!"string {s} cannot be converted to an integer" + | .some i => return i + +end Const + +-- This correspondes to the "Quasi-Affine Expressions" in Neuron. +-- Note, `floor` is the usual integer division. +inductive IndexExpr where + | var (name : String) + | int (i : Int) + | neg (expr : IndexExpr) + | add (left right : IndexExpr) + | mul (scalar : Int) (expr : IndexExpr) + | floor (expr : IndexExpr) (scalar : Int) + | ceil (expr : IndexExpr) (scalar : Int) + | mod (expr : IndexExpr) (scalar : Int) + deriving Repr, BEq + +-- Note: `np.newindex` is represented as `(.coord none)` +inductive Index where + | ellipsis + | coord (e : Option IndexExpr) + | slice (l u step : Option IndexExpr) + deriving Repr, BEq + +inductive Expr where + | var (x : String) + | const (c : Const) + | tensor (t : Tensor) + | access (t : Expr) (ix : List Index) + | call (f : Expr) (args : List Expr) (kwargs : List (String × Expr)) + deriving Repr, BEq + +inductive Stmt where + | pass + | expr (v : Expr) + | ret (v : Expr) + | assign (x : String) (e : Expr) + | loop (x : String) (l u step : IndexExpr) (body : List Stmt) + deriving Repr, BEq diff --git a/NKL/Encode.lean b/NKL/KLR/Encode.lean similarity index 89% rename from NKL/Encode.lean rename to NKL/KLR/Encode.lean index 706acf0..e1cc148 100644 --- a/NKL/Encode.lean +++ b/NKL/KLR/Encode.lean @@ -3,7 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ -import NKL.KLR +import NKL.KLR.Basic /-! # Serialization and Deserialization @@ -239,7 +239,7 @@ private def ie_var : IndexExpr := .var "s" def encIndex : Index -> ByteArray | .ellipsis => tag 0x20 [] | .coord e => tag 0x21 [enc e] - | .range l u s => tag 0x22 [enc l, enc u, enc s] + | .slice l u s => tag 0x22 [enc l, enc u, enc s] where enc := encOption encIndexExpr @@ -247,7 +247,7 @@ def decIndex : DecodeM Index := do match (<- next) with | 0x20 => return .ellipsis | 0x21 => return .coord (<- dec) - | 0x22 => return .range (<- dec) (<- dec) (<- dec) + | 0x22 => return .slice (<- dec) (<- dec) (<- dec) | t => throw s!"Unknown tag in Index {t}" where dec:= decOption decIndexExpr @@ -258,21 +258,17 @@ private def chkIndex (i : Index) : Bool := #guard chkIndex .ellipsis #guard chkIndex (.coord none) #guard chkIndex (.coord $ some ie_var) -#guard chkIndex (.range (some ie_var) none none) +#guard chkIndex (.slice (some ie_var) none none) ------------------------------------------------------------------------------ -- Expressions partial def encExpr : Expr -> ByteArray - | .var s => tag 0x30 [encString s] - | .tensor t s => tag 0x31 [encString t, encList encInt s] - | .const c => tag 0x32 [encConst c] - | .tuple es => tag 0x33 [encList encExpr es] - | .list es => tag 0x34 [encList encExpr es] - | .access e ix => tag 0x35 [encExpr e, encList encIndex ix] - | .binop op l r => tag 0x36 [encString op, encExpr l, encExpr r] - | .unop op e => tag 0x37 [encString op, encExpr e] - | .call f ax kw => tag 0x38 [encExpr f, encList encExpr ax, encList encKeyword kw] + | .var s => tag 0x30 [encString s] + | .tensor t => tag 0x31 [encString t.dtype, encList encInt t.shape] + | .const c => tag 0x32 [encConst c] + | .access e ix => tag 0x33 [encExpr e, encList encIndex ix] + | .call f ax kw => tag 0x34 [encExpr f, encList encExpr ax, encList encKeyword kw] where encKeyword : String × Expr -> ByteArray | (key, expr) => (encString key).append (encExpr expr) @@ -280,14 +276,10 @@ where partial def decExpr : DecodeM Expr := do match (<- next) with | 0x30 => return .var (<- decString) - | 0x31 => return .tensor (<- decString) (<- decList decInt) + | 0x31 => return .tensor $ .mk (<- decString) (<- decList decInt) | 0x32 => return .const (<- decConst) - | 0x33 => return .tuple (<- decList decExpr) - | 0x34 => return .list (<- decList decExpr) - | 0x35 => return .access (<- decExpr) (<- decList decIndex) - | 0x36 => return .binop (<- decString) (<- decExpr) (<- decExpr) - | 0x37 => return .unop (<- decString) (<- decExpr) - | 0x38 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword) + | 0x33 => return .access (<- decExpr) (<- decList decIndex) + | 0x34 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword) | t => throw s!"Unknown tag in Expr {t}" where decKeyword : DecodeM (String × Expr) := @@ -301,13 +293,9 @@ private def ixz := Index.coord (IndexExpr.int 0) #guard chkExpr nil #guard chkExpr (.var "var") -#guard chkExpr (.tensor "float32" [1,2,3]) +#guard chkExpr (.tensor $ .mk "float32" [1,2,3]) #guard chkExpr (.const (.int 1)) -#guard chkExpr (.tuple [nil, nil, nil]) -#guard chkExpr (.list [nil, nil, nil]) #guard chkExpr (.access nil [ixz, ixz, ixz]) -#guard chkExpr (.binop "op" nil nil) -#guard chkExpr (.unop "op" nil) #guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)]) ------------------------------------------------------------------------------