diff --git a/NKL/FFI.lean b/NKL/FFI.lean index 6f59d36..764d939 100644 --- a/NKL/FFI.lean +++ b/NKL/FFI.lean @@ -5,6 +5,7 @@ Authors: Paul Govereau -/ import Lean import NKL.Python +import NKL.Trace namespace NKL @@ -16,12 +17,6 @@ local instance : MonadLift (Except String) IO where @[export parse_json] def parse_json (s : String) : IO Unit := do let kernel <- Python.Parsing.parse s - let names := kernel.funcs.map fun x => x.fst - let names := String.intercalate "," names - IO.println s!"Found functions: {names}" - for x in kernel.args do - IO.println s!"arg: {repr x}" - for x in kernel.kwargs do - IO.println s!"arg: {repr x}" - for x in kernel.globals do - IO.println s!"global: {repr x}" + let stmts <- NKL.Trace.runNKIKernel kernel + for s in stmts do + IO.println s!"{repr s}" diff --git a/NKL/Python.lean b/NKL/Python.lean index ededec5..b2aedd0 100644 --- a/NKL/Python.lean +++ b/NKL/Python.lean @@ -106,11 +106,11 @@ then the structure will be populated with: defaults = [1, 2] vararg = "args" kwonlyargs = [d, e] - kw_defaults = [None, 3] + kw_defaults = [("e", 3)] kwarg = "kwargs" -Note, defaults and kw_defaults are inconsistent in how they treat -missing arguments, but this is just how it works in the python AST. +Note, this is slightly different from the official Python AST, which +encodes the kw_defaults as a list with None for missing defaults. -/ structure Args where posonlyargs : List String @@ -122,17 +122,13 @@ structure Args where kwarg : Option String deriving Repr -def Args.names (ax : Args) : List String := - let xs := ax.posonlyargs.append ax.args - let xs := match ax.vararg with | none => xs | some x => xs.append [x] - let xs := xs.append ax.kwonlyargs - let xs := match ax.kwarg with | none => xs | some x => xs.append [x] - xs - -def Args.all_defaults (ax : Args) : List (String × Expr') := - let args := ax.posonlyargs ++ ax.args - let dflt := args.reverse.zip ax.defaults.reverse - dflt ++ ax.kw_defaults +def Args.names (args : Args) : List String := + args.posonlyargs ++ args.args ++ args.kwonlyargs + +def Args.all_defaults (args : Args) : List (String × Expr') := + let pargs := args.posonlyargs ++ args.args + let dflt := pargs.reverse.zip args.defaults.reverse + dflt ++ args.kw_defaults structure Fun where source : String diff --git a/NKL/Trace.lean b/NKL/Trace.lean index 8f23ca4..481a552 100644 --- a/NKL/Trace.lean +++ b/NKL/Trace.lean @@ -3,7 +3,18 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ +import NKL.KLR +import NKL.Python import NKL.Trace.Types import NKL.Trace.Basic import NKL.Trace.Builtin ---import NKL.Trace.Python +import NKL.Trace.Python +import NKL.Trace.NKI + +namespace NKL.Trace + +def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) := + tracer ⟨ .ofList NKIEnv, #[] ⟩ do + traceKernel k + let g <- get + return g.body.toList diff --git a/NKL/Trace/Basic.lean b/NKL/Trace/Basic.lean index a1369b0..59d1f7d 100644 --- a/NKL/Trace/Basic.lean +++ b/NKL/Trace/Basic.lean @@ -118,11 +118,8 @@ def unOp : String -> Term -> TraceM Term | op, _ => throw s!"unimp {op}" -- Comparison operators --- TODO: need to think about comparison of tensors, in NKI this is object equality, --- but I suspect this doesn't make sense and may be a source of bugs. def cmpOp : String -> Term -> Term -> TraceM Bool - | "Eq", .expr l _, .expr r _ => return l == r | s, l, r => throw s!"unsupported comparison operator {s} {repr l} {repr r}" def compare : Term -> List String -> List Term -> TraceM Term diff --git a/NKL/Trace/NKI.lean b/NKL/Trace/NKI.lean new file mode 100644 index 0000000..328b9d0 --- /dev/null +++ b/NKL/Trace/NKI.lean @@ -0,0 +1,49 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +import NKL.KLR +import NKL.Trace.Types +import NKL.Trace.Builtin + +/- +# NKI built-ins + +This module defines the builtin constants used by tracing for NKI kernels. +-/ +namespace NKL.Trace +open NKL.KLR + +private def module (s : String) : Name × Item := + let name := s.toName + (name, .module name) + +private def const_var (s : String) : Name × Item := + let name := s.toName + (name, .term (.expr (.var s) (.any name))) + +/- +Note: this object contains a bunch of architecture parameters that +need to be set according to which HW we are compiling for. +TODO: figure out the mechanism for this. +-/ +def tile_size : Global := + let name := "nki.langauge.tile_size".toName + { name := name + , attr := attrs + , call := uncallable name + } +where + attrs : GlobalAttr + | "pmax" => return .expr (.const $ .int 128) .int + | a => throw s!"unsupported attribute {a}" + +def NKIEnv : List (Name × Item) := + [ module "nki" + , module "nki.language" + , const_var "nki.language.add" + , const_var "nki.language.load" + , const_var "nki.language.store" + , ("nki.language.tile_size".toName, .global tile_size) + ] diff --git a/NKL/Trace/Python.lean b/NKL/Trace/Python.lean new file mode 100644 index 0000000..87162ad --- /dev/null +++ b/NKL/Trace/Python.lean @@ -0,0 +1,264 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +import Lean +import NKL.KLR +import NKL.Python +import NKL.Trace.Types +import NKL.Trace.Basic + +namespace NKL.Trace +open NKL.Python + +def const : Const -> ErrorM Term + | .none => return .expr (.const $ .none) .none + | .bool b => return .expr (.const $ .bool b) .bool + | .int i => return .expr (.const $ .int i) .int + | .float f => return .expr (.const $ .float f) .float + | .string s => return .expr (.const $ .string s) .string + | .ellipsis => throw "unsupported use of ellipsis" + +/- +Evaluating index expressions. + +An index expression occurs only within a subscript expression. For example, in +the expression: + + t[1,1:10,None,x+9] + +all of 1, 1:10, None, and x+9 are indexes. Note None may also be written as +np.newaxis. Also, a None or a slice (or ellipsis) may only occur at the +outer-most level of an index: if you write, e.g. + + t[x+None] + +then the None is interpreted as an integer and not as a new axis. If you write, + + t[(1:2) + 3] + t[... * 8] + +these are syntax errors in python. +-/ + +mutual +-- top-level index expressions: None (a.k.a. np.newaxis) or IndexExpr +def indexExpr? : Option Expr -> Tracer (Option KLR.IndexExpr) + | none => return none + | some (.exprPos (.const .none) _) => return none + | some e => indexExpr e + +-- general sub-expressions +def indexExpr : Expr -> Tracer KLR.IndexExpr + | .exprPos e' p => withPos p (indexExpr' e') + +def indexExpr' : Expr' -> Tracer KLR.IndexExpr + | .const (.int i) => return .int i + | .name id _ => return .var id + | .binOp op l r => return <- indexBinOp op (<- indexExpr l) (<- indexExpr r) + | .unaryOp op e => return <- indexUnOp op (<- indexExpr e) + | _ => throw "invalid index expression" + +-- top-level index: slice, ellipsis, or indexExpr? +-- TODO: get rid of ... +def index : Expr -> Tracer KLR.Index + | .exprPos (.const .ellipsis) _ => return .ellipsis + | .exprPos (.slice l u s) p => withPos p do + return (.slice (<- indexExpr? l) (<- indexExpr? u) (<- indexExpr? s)) + | e => return (.coord (<- indexExpr? e)) +end + +-- Note, a list index can be negative, which means index from end of list. +def list_access (l : List Term) : List KLR.Index -> TraceM Term + | [.coord (some (.int i))] => do + let i := if i < 0 then l.length + i else i + if i < 0 then throw "index out of bounds" + let n := i.toNat + if h:l.length > n then return l.get (Fin.mk n h) + else throw "index out of bounds" + |_ => throw "unsupported __subscript__" + +def access : Term -> List KLR.Index -> TraceM Term + | .object _, _ => throw "builtin object __subscript__ not supported" + | .tuple l, ix + | .list l, ix => list_access l ix + | .expr e _, ix => return .expr (.access e ix) (.any "?".toName) + + +mutual +partial def expr : Expr -> Tracer Item + | .exprPos e' p => withPos p (expr' e') + +partial def term (e : Expr) : Tracer Term := do + match (<- expr e) with + | .module n => return .expr (.var n.toString) (.any "?".toName) + | .global g => return .expr (.var g.name.toString) (.any "?".toName) + | .source _ => throw "invalid use of source function" + | .term t => return t + +partial def term' (e : Expr') : Tracer Term := do + term (.exprPos e (<- getPos)) + +partial def klr (e : Expr) : Tracer KLR.Expr := do + match (<- term e) with + | .object obj => return .var obj.name.toString + | .tuple _ => throw "tuple cannot be converted to a KLR term" + | .list _ => throw "list cannot be converted to a KLR term" + | .expr e _ => return e + +partial def integer (e : Expr) : Tracer Int := do + match (<- term e) with + | .expr (.const c) _ => return (<- c.toInt) + | _ => throw "invalid tensor dimension" + +partial def expr' : Expr' -> Tracer Item + | .const c => return .term (<- const c) + | .tensor s dty => do + let shape <- s.mapM integer + return .term (.expr (.tensor ⟨ dty, shape ⟩) (.tensor dty shape)) + | .name id _ => lookup_item id.toName + | .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id) + | .tuple l _ => return .term (.tuple (<- l.mapM term)) + | .list l _ => return .term (.list (<- l.mapM term)) + | .subscript t [ .exprPos (.tuple ix _) _ ] _ + | .subscript t ix _ => return .term (<- access (<- term t) (<- ix.mapM index)) + | .slice _ _ _ => throw "syntax error" + | .boolOp op xs => return .term (<- boolOp op (<- xs.mapM term)) + | .binOp op l r => return .term (<- binOp op (<- term l) (<- term r)) + | .unaryOp op e => return .term (<- unOp op (<- term e)) + | .compare l ops cs => return .term (<- compare (<- term l) ops (<- cs.mapM term)) + | .ifExp tst tru fls => do + let tst <- (<- term tst).isTrue + let tru <- expr tru -- eagerly evaluate both branches + let fls <- expr fls -- to report errors to user + return if tst then tru else fls + | .call f args kws => do + match <- expr f with + | .module n => throw s!"module {n} not callable" + | .global g => return .term (<- g.call (<- args.mapM term) (<- kws.mapM (keyword term))) + | .term t => return .term (<- t.call (<- args.mapM klr) (<- kws.mapM (keyword klr))) + | .source f => do + function_call f (<- args.mapM term) (<- kws.mapM (keyword term)) + return .term (.expr (.const .none) .none) + +partial def keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a) + | .keyword id e p => withPos p do return (id, (<- f e)) + +-- When looking for a variable we rely on the store attribute +-- from the Python parser to check if it is a defining use. +partial def var : Expr -> Tracer String + | .exprPos (.name id .store) _ => return id + | _ => throw "expecting variable" + +-- When we perform an assignment, we will either add to the environment +-- the term found on the RHS, or the variable itself. The latter case +-- allows us to lookup and find the variable without substituting +-- its definition. +partial def assign (xs : List Expr) (e : Expr) : Tracer Unit := do + let xs <- xs.mapM var + let e <- term e + match e with + | .expr (.const _) _ => xs.forM fun x => extend x.toName e + | .expr e ty => xs.forM fun x => do + extend x.toName (.expr (.var xs[0]!) ty) + add_stmt (KLR.Stmt.assign x e) + | t => xs.forM fun x => extend x.toName t + +partial def stmt : Stmt -> Tracer Unit + | .stmtPos s' p => withPos p (stmt' s') + +partial def stmt' : Stmt' -> Tracer Unit + | .expr (.exprPos (.const _) _) => return () + | .expr e => do + match <- term e with + | .expr e _ => add_stmt (.expr e) + | _ => return () -- effects are done, can be removed from KLR + | .assert e => do + let t <- term e + if (<- t.isFalse) then throw "assertion failed" + | .assign xs e => assign xs e + | .augAssign x op e => do + stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos))) + | .annAssign _ _ .none => return () + | .annAssign x _ (.some e) => stmt' (.assign [x] e) + | _s => throw "not yet implemented" --s!"unimp {repr s}" + +-- Bind positional and keyword arguments to a Python function based on its +-- signature. + +partial def bind_args (f : Fun) + (args : List Term) + (kwargs : List (String × Term)) + : Tracer (List (String × Term)) := do + if f.args.vararg != none || f.args.kwarg != none then + throw "var args not supported" + if args.length < f.args.posonlyargs.length then + throw "not enough arguments" + let dflts := f.args.all_defaults + let names := f.args.names + if args.length + kwargs.length > names.length then + throw "too many arguments supplied (varargs not supported)" + let argmap <- f.args.names.enum.mapM fun (i,x) => do + if h:args.length > i then + return (x, args.get (Fin.mk i h)) + else if let some v := kwargs.lookup x then + return (x, v) + else if let some e := dflts.lookup x then + return (x, <- term' e) + else + throw s!"argument {x} not supplied" + return argmap + +-- For a function call, first evaluate the argument in the current environment. +-- Then enter a new environment and evaluate the function statements. +partial def function_call (f : Fun) + (args : List Term) + (kwargs : List (String × Term)) + : Tracer Unit := do + let args <- bind_args f args kwargs + --let args <- args.mapM fun (x,e) => return (x, e) + withSrc f.source $ enterFun $ do + args.forM fun (x,e) => do extend x.toName e + f.body.forM stmt + +end + +/- +Evaluate each global in the current environment, skipping any globals that are +already defined. We do not redefine globals because, we may have picked up +functions with dummy implementations, e.g., nki.language.add is defined as: + + def add(x,y): pass + +in some versions of the code. We do not want this to shadow a built-in +definition of add. If we have an internal definition, we will use this over +anything found during parsing. +-/ + +private def globals (k : Kernel) : Tracer Unit := do + let s <- get + for (n, f) in k.funcs do + let n := n.toName + if not (s.env.contains n) then + extend_global n (.source f) + for (n,e) in k.globals do + let n := n.toName + if not (s.env.contains n) then + extend_global n (<- expr' e) + +-- Call the top-level kernel function +def traceKernel (k : Kernel) : Tracer Unit := do + globals k + match k.funcs.lookup k.entry with + | none => throw s!"function {k.entry} not found" + | some f => do + let args <- k.args.mapM term' + let kwargs <- k.kwargs.mapM fun (x,e) => return (x, <- term' e) + function_call f args kwargs + +def runKernel (k : Kernel) : Except String (List KLR.Stmt) := + tracer ⟨ ∅, #[] ⟩ do + traceKernel k + let g <- get + return g.body.toList diff --git a/interop/test/test_basic.py b/interop/test/test_basic.py new file mode 100644 index 0000000..71c64be --- /dev/null +++ b/interop/test/test_basic.py @@ -0,0 +1,93 @@ +import numpy as np +import nki +import pytest + +from nkl.parser import Parser + +# Success cases +# (these functions should load and trace to KLR) + +def const_stmt(t): + "this will be ignored because it has no effect" + 1 # so will this, it is a simple constant + 1.0 # so will this + False # and this + None # and this + (1,2) # and this + [1,2] # and this + +string = "a string" +integer = -3 +floating = 1.23 +boolean = True +nothing = None +triple = (1, floating, False) +list3 = [string, triple, nki] + +def expr_name(t): + # these names will end up in the global environment after parsing + # they will be eliminated after substitution during tracing + string, integer, floating, boolean, nothing + # constant tuples are also OK + triple + # as are constant lists + list3 + # as are module references + nki + +def expr_tuple(t): + assert (1,False,"hello") + +def expr_list(t): + assert [1,2,nki] + assert not [] + +def expr_subscript(t): + t[1] + t[1,2,3] + t[1:2:3] + t[1:2] + t[1:] + t[1::] + t[1::2] + t[1:2:None] + t[1:None:2] + t[:] + t[:,:] + t[...] + t[1,...] + t[:,None] + t[1] + +def expr_bool_op(t): + True and 1 and [1] and [] and True # evals to [] + False or None or [] or 1 # evals to 1 + 1 or None # evals to 1 + (False,) or 1 # evals to (False,) + +@pytest.mark.parametrize("f", [ + const_stmt, + expr_name, + expr_tuple, + expr_list, + expr_subscript, + expr_bool_op, + ]) +def test_succeed(f): + t = np.ndarray(10) + F = Parser(f) + F(t) + +# Failing cases +# (These functions are expected to fail elaboration to KLR) + +def name_not_found(): + return x + +@pytest.mark.parametrize("f", [ + name_not_found, + ]) +def test_fails(f): + F = Parser(f) + with pytest.raises(Exception): + F()