diff --git a/NKL/FFI.lean b/NKL/FFI.lean index 6f59d36..764d939 100644 --- a/NKL/FFI.lean +++ b/NKL/FFI.lean @@ -5,6 +5,7 @@ Authors: Paul Govereau -/ import Lean import NKL.Python +import NKL.Trace namespace NKL @@ -16,12 +17,6 @@ local instance : MonadLift (Except String) IO where @[export parse_json] def parse_json (s : String) : IO Unit := do let kernel <- Python.Parsing.parse s - let names := kernel.funcs.map fun x => x.fst - let names := String.intercalate "," names - IO.println s!"Found functions: {names}" - for x in kernel.args do - IO.println s!"arg: {repr x}" - for x in kernel.kwargs do - IO.println s!"arg: {repr x}" - for x in kernel.globals do - IO.println s!"global: {repr x}" + let stmts <- NKL.Trace.runNKIKernel kernel + for s in stmts do + IO.println s!"{repr s}" diff --git a/NKL/Python.lean b/NKL/Python.lean index ededec5..b2aedd0 100644 --- a/NKL/Python.lean +++ b/NKL/Python.lean @@ -106,11 +106,11 @@ then the structure will be populated with: defaults = [1, 2] vararg = "args" kwonlyargs = [d, e] - kw_defaults = [None, 3] + kw_defaults = [("e", 3)] kwarg = "kwargs" -Note, defaults and kw_defaults are inconsistent in how they treat -missing arguments, but this is just how it works in the python AST. +Note, this is slightly different from the official Python AST, which +encodes the kw_defaults as a list with None for missing defaults. -/ structure Args where posonlyargs : List String @@ -122,17 +122,13 @@ structure Args where kwarg : Option String deriving Repr -def Args.names (ax : Args) : List String := - let xs := ax.posonlyargs.append ax.args - let xs := match ax.vararg with | none => xs | some x => xs.append [x] - let xs := xs.append ax.kwonlyargs - let xs := match ax.kwarg with | none => xs | some x => xs.append [x] - xs - -def Args.all_defaults (ax : Args) : List (String × Expr') := - let args := ax.posonlyargs ++ ax.args - let dflt := args.reverse.zip ax.defaults.reverse - dflt ++ ax.kw_defaults +def Args.names (args : Args) : List String := + args.posonlyargs ++ args.args ++ args.kwonlyargs + +def Args.all_defaults (args : Args) : List (String × Expr') := + let pargs := args.posonlyargs ++ args.args + let dflt := pargs.reverse.zip args.defaults.reverse + dflt ++ args.kw_defaults structure Fun where source : String diff --git a/NKL/Trace.lean b/NKL/Trace.lean index 8f23ca4..b6fe239 100644 --- a/NKL/Trace.lean +++ b/NKL/Trace.lean @@ -3,7 +3,18 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ +import NKL.KLR +import NKL.Python import NKL.Trace.Types import NKL.Trace.Basic import NKL.Trace.Builtin ---import NKL.Trace.Python +import NKL.Trace.Python +import NKL.Trace.NKI + +namespace NKL.Trace + +def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) := + tracer ⟨ .ofList nki_env, #[] ⟩ do + traceKernel k + let g <- get + return g.body.toList diff --git a/NKL/Trace/NKI.lean b/NKL/Trace/NKI.lean new file mode 100644 index 0000000..62987e8 --- /dev/null +++ b/NKL/Trace/NKI.lean @@ -0,0 +1,32 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +import NKL.KLR +import NKL.Trace.Types +import NKL.Trace.Builtin + +/- +# NKI built-ins + +This module defines the builtin constants used by tracing for NKI kernels. +-/ +namespace NKL.Trace +open NKL.KLR + +private def module (s : String) : Name × Item := + let name := s.toName + (name, .module name) + +private def const_var (s : String) : Name × Item := + let name := s.toName + (name, .term (.expr (.var s) (.any name))) + +def nki_env : List (Name × Item) := + [ module "nki" + , module "nki.language" + , const_var "nki.language.add" + , const_var "nki.language.load" + , const_var "nki.language.store" + ] diff --git a/NKL/Trace/Python.lean b/NKL/Trace/Python.lean new file mode 100644 index 0000000..107b762 --- /dev/null +++ b/NKL/Trace/Python.lean @@ -0,0 +1,214 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +import Lean +import NKL.KLR +import NKL.Python +import NKL.Trace.Types +import NKL.Trace.Basic + +namespace NKL.Trace +open NKL.Python + +def const : Const -> TraceM Term + | .none => return .expr (.const $ .none) .none + | .bool b => return .expr (.const $ .bool b) .bool + | .int i => return .expr (.const $ .int i) .int + | .float f => return .expr (.const $ .float f) .float + | .string s => return .expr (.const $ .string s) .string + | .ellipsis => throw "unsupported use of ellipsis" + +mutual +def indexExpr : Expr -> Tracer KLR.IndexExpr + | .exprPos e' p => withPos p (indexExpr' e') + +def indexExpr' : Expr' -> Tracer KLR.IndexExpr + | .const (.int i) => return .int i + | .const c => throw s!"invalid constant {repr c} in index expression" + | .name id _ => return .var id + | .binOp op l r => return <- indexBinOp op (<- indexExpr l) (<- indexExpr r) + | .unaryOp op e => return <- indexUnOp op (<- indexExpr e) + | _ => throw "invalid index expression" + +def indexExpr? : Option Expr -> Tracer (Option KLR.IndexExpr) + | none => return none + | some (.exprPos (.const .none) _) => return none + | some e => indexExpr e + +def index : Expr -> Tracer KLR.Index + | .exprPos (.const .ellipsis) _ => return .ellipsis + | .exprPos (.slice l u s) p => withPos p do + return (.slice (<- indexExpr? l) (<- indexExpr? u) (<- indexExpr? s)) + | e => return (.coord (<- indexExpr? e)) +end + + +mutual +partial def expr : Expr -> Tracer Item + | .exprPos e' p => withPos p (expr' e') + +partial def term (e : Expr) : Tracer Term := do + match (<- expr e) with + | .module n => return .expr (.var n.toString) (.any "?".toName) + | .global g => return .expr (.var g.name.toString) (.any "?".toName) + | .source _ => throw "invalid use of source function" + | .term t => return t + +partial def term' (e : Expr') : Tracer Term := do + term (.exprPos e (<- getPos)) + +partial def klr (e : Expr) : Tracer KLR.Expr := do + match (<- term e) with + | .object obj => return .var obj.name.toString + | .tuple _ => throw "tuple cannot be converted to a KLR term" + | .list _ => throw "list cannot be converted to a KLR term" + | .expr e _ => return e + +partial def integer (e : Expr) : Tracer Int := do + match (<- term e) with + | .expr (.const c) _ => return (<- c.toInt) + | _ => throw "invalid tensor dimension" + +partial def expr' : Expr' -> Tracer Item + | .const c => return .term (<- const c) + | .tensor s dty => do + let shape <- s.mapM integer + return .term (.expr (.tensor ⟨ dty, shape ⟩) (.tensor dty shape)) + | .name id _ => lookup_item id.toName + | .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id) + | .tuple l _ => return .term (.tuple (<- l.mapM term)) + | .list l _ => return .term (.list (<- l.mapM term)) + | .subscript t [ .exprPos (.tuple ix _) _ ] _ + | .subscript t ix _ => return .term (.expr (.access (<- klr t) (<- ix.mapM index)) (.any "?".toName)) + | .slice _ _ _ => throw "syntax error" + | .boolOp op xs => return .term (<- boolOp op (<- xs.mapM term)) + | .binOp op l r => return .term (<- binOp op (<- term l) (<- term r)) + | .unaryOp op e => return .term (<- unOp op (<- term e)) + | .compare l ops cs => return .term (<- compare (<- term l) ops (<- cs.mapM term)) + | .ifExp tst tru fls => do + let tst <- (<- term tst).isTrue + let tru <- expr tru -- eagerly evaluate both branches + let fls <- expr fls -- to report errors to user + return if tst then tru else fls + | .call f args kws => do + match <- expr f with + | .module n => throw s!"module {n} not callable" + | .global g => return .term (<- g.call (<- args.mapM term) (<- kws.mapM (keyword term))) + | .term t => return .term (<- t.call (<- args.mapM klr) (<- kws.mapM (keyword klr))) + | .source f => do + function_call f (<- args.mapM term) (<- kws.mapM (keyword term)) + return .term (.expr (.const .none) .none) + +partial def keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a) + | .keyword id e p => withPos p do return (id, (<- f e)) + + +partial def var (e : Expr) : Tracer String := do + match (<- klr e) with + | .var s => return s + | _ => throw "expecting variable" + +partial def assign (xs : List Expr) (e : Expr) : Tracer Unit := do + let xs <- xs.mapM var + let e <- term e + xs.forM fun x => extend x.toName e + if let .expr e _ := e then + xs.forM fun x => add_stmt (KLR.Stmt.assign x e) + +partial def stmt : Stmt -> Tracer Unit + | .stmtPos s' p => withPos p (stmt' s') + +partial def stmt' : Stmt' -> Tracer Unit + | .expr (.exprPos (.const _) _) => return () + | .expr e => do + match <- term e with + | .expr e _ => add_stmt (.expr e) + | _ => return () -- effects are done, can be removed from KLR + | .assert e => do + let t <- term e + if (<- t.isFalse) then throw "assertion failed" + | .assign xs e => assign xs e + | .augAssign x op e => do + stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos))) + | .annAssign _ _ .none => return () + | .annAssign x _ (.some e) => stmt' (.assign [x] e) + | _s => throw "not yet implemented" --s!"unimp {repr s}" + +-- Bind positional and keyword arguments to a Python function. +-- Note: default arguments should be evaluated in the global environment, +-- however we know that each source function begins with an empty local +-- environment, so it is OK to evaluate the default arguments in the +-- functions initial environment. + +partial def bind_args (f : Fun) + (args : List Term) + (kwargs : List (String × Term)) + : Tracer (List (String × Term)) := do + if f.args.vararg != none || f.args.kwarg != none then + throw "var args not supported" + if args.length < f.args.posonlyargs.length then + throw "not enough arguments" + let dflts := f.args.all_defaults + let names := f.args.names + if args.length + kwargs.length > names.length then + throw "too many arguments supplied (varargs not supported)" + let argmap <- f.args.names.enum.mapM fun (i,x) => do + if h:args.length > i then + return (x, args.get (Fin.mk i h)) + else if let some v := kwargs.lookup x then + return (x, v) + else if let some e := dflts.lookup x then + return (x, <- term' e) + else + throw s!"argument {x} not supplied" + return argmap + +-- For a function call, first evaluate the argument in the current environment. +-- Then enter a new environment and evaluate the function statements. +partial def function_call (f : Fun) + (args : List Term) + (kwargs : List (String × Term)) + : Tracer Unit := do + let args <- bind_args f args kwargs + let args <- args.mapM fun (x,e) => return (x, e) + withSrc f.source $ enterFun $ do + args.forM fun (x,e) => do extend x.toName e + f.body.forM stmt + +end + +-- Evaluate each global in the current environment, skipping any globals that +-- are already defined. Note, we may have globals or functions with dummy +-- implementations, e.g. +-- def add(x,y): pass +-- If we have an internal definition, we will use this over anything found +-- during parsing. + +private def globals (k : Kernel) : Tracer Unit := do + let s <- get + for (n, f) in k.funcs do + let n := n.toName + if not (s.env.contains n) then + extend_global n (.source f) + for (n,e) in k.globals do + let n := n.toName + if not (s.env.contains n) then + extend_global n (<- expr' e) + +-- Call the top-level kernel function +def traceKernel (k : Kernel) : Tracer Unit := do + globals k + match k.funcs.lookup k.entry with + | none => throw s!"function {k.entry} not found" + | some f => do + let args <- k.args.mapM term' + let kwargs <- k.kwargs.mapM fun (x,e) => return (x, <- term' e) + function_call f args kwargs + +def runKernel (k : Kernel) : Except String (List KLR.Stmt) := + tracer ⟨ ∅, #[] ⟩ do + traceKernel k + let g <- get + return g.body.toList diff --git a/interop/test/test_basic.py b/interop/test/test_basic.py new file mode 100644 index 0000000..71c64be --- /dev/null +++ b/interop/test/test_basic.py @@ -0,0 +1,93 @@ +import numpy as np +import nki +import pytest + +from nkl.parser import Parser + +# Success cases +# (these functions should load and trace to KLR) + +def const_stmt(t): + "this will be ignored because it has no effect" + 1 # so will this, it is a simple constant + 1.0 # so will this + False # and this + None # and this + (1,2) # and this + [1,2] # and this + +string = "a string" +integer = -3 +floating = 1.23 +boolean = True +nothing = None +triple = (1, floating, False) +list3 = [string, triple, nki] + +def expr_name(t): + # these names will end up in the global environment after parsing + # they will be eliminated after substitution during tracing + string, integer, floating, boolean, nothing + # constant tuples are also OK + triple + # as are constant lists + list3 + # as are module references + nki + +def expr_tuple(t): + assert (1,False,"hello") + +def expr_list(t): + assert [1,2,nki] + assert not [] + +def expr_subscript(t): + t[1] + t[1,2,3] + t[1:2:3] + t[1:2] + t[1:] + t[1::] + t[1::2] + t[1:2:None] + t[1:None:2] + t[:] + t[:,:] + t[...] + t[1,...] + t[:,None] + t[1] + +def expr_bool_op(t): + True and 1 and [1] and [] and True # evals to [] + False or None or [] or 1 # evals to 1 + 1 or None # evals to 1 + (False,) or 1 # evals to (False,) + +@pytest.mark.parametrize("f", [ + const_stmt, + expr_name, + expr_tuple, + expr_list, + expr_subscript, + expr_bool_op, + ]) +def test_succeed(f): + t = np.ndarray(10) + F = Parser(f) + F(t) + +# Failing cases +# (These functions are expected to fail elaboration to KLR) + +def name_not_found(): + return x + +@pytest.mark.parametrize("f", [ + name_not_found, + ]) +def test_fails(f): + F = Parser(f) + with pytest.raises(Exception): + F()