From 9ffeee196e55613780c4c05b0266807179bae1f8 Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Thu, 9 Jan 2025 13:05:00 -0500 Subject: [PATCH] feat: tracing for python source functions This patch adds basic tracing for user python functions. The main code is in Python.lean, and depends on definitions in Basic.lean and NKI.lean, which are incomplete. As more primitives are implemented, more user kernels will be supported. --- NKL/FFI.lean | 13 +- NKL/Python.lean | 24 ++-- NKL/Trace.lean | 13 +- NKL/Trace/Basic.lean | 3 - NKL/Trace/NKI.lean | 49 +++++++ NKL/Trace/Python.lean | 264 +++++++++++++++++++++++++++++++++++++ interop/test/test_basic.py | 93 +++++++++++++ 7 files changed, 432 insertions(+), 27 deletions(-) create mode 100644 NKL/Trace/NKI.lean create mode 100644 NKL/Trace/Python.lean create mode 100644 interop/test/test_basic.py diff --git a/NKL/FFI.lean b/NKL/FFI.lean index 6f59d36..764d939 100644 --- a/NKL/FFI.lean +++ b/NKL/FFI.lean @@ -5,6 +5,7 @@ Authors: Paul Govereau -/ import Lean import NKL.Python +import NKL.Trace namespace NKL @@ -16,12 +17,6 @@ local instance : MonadLift (Except String) IO where @[export parse_json] def parse_json (s : String) : IO Unit := do let kernel <- Python.Parsing.parse s - let names := kernel.funcs.map fun x => x.fst - let names := String.intercalate "," names - IO.println s!"Found functions: {names}" - for x in kernel.args do - IO.println s!"arg: {repr x}" - for x in kernel.kwargs do - IO.println s!"arg: {repr x}" - for x in kernel.globals do - IO.println s!"global: {repr x}" + let stmts <- NKL.Trace.runNKIKernel kernel + for s in stmts do + IO.println s!"{repr s}" diff --git a/NKL/Python.lean b/NKL/Python.lean index ededec5..b2aedd0 100644 --- a/NKL/Python.lean +++ b/NKL/Python.lean @@ -106,11 +106,11 @@ then the structure will be populated with: defaults = [1, 2] vararg = "args" kwonlyargs = [d, e] - kw_defaults = [None, 3] + kw_defaults = [("e", 3)] kwarg = "kwargs" -Note, defaults and kw_defaults are inconsistent in 
how they treat -missing arguments, but this is just how it works in the python AST. +Note, this is slightly different from the official Python AST, which +encodes the kw_defaults as a list with None for missing defaults. -/ structure Args where posonlyargs : List String @@ -122,17 +122,13 @@ structure Args where kwarg : Option String deriving Repr -def Args.names (ax : Args) : List String := - let xs := ax.posonlyargs.append ax.args - let xs := match ax.vararg with | none => xs | some x => xs.append [x] - let xs := xs.append ax.kwonlyargs - let xs := match ax.kwarg with | none => xs | some x => xs.append [x] - xs - -def Args.all_defaults (ax : Args) : List (String × Expr') := - let args := ax.posonlyargs ++ ax.args - let dflt := args.reverse.zip ax.defaults.reverse - dflt ++ ax.kw_defaults +def Args.names (args : Args) : List String := + args.posonlyargs ++ args.args ++ args.kwonlyargs + +def Args.all_defaults (args : Args) : List (String × Expr') := + let pargs := args.posonlyargs ++ args.args + let dflt := pargs.reverse.zip args.defaults.reverse + dflt ++ args.kw_defaults structure Fun where source : String diff --git a/NKL/Trace.lean b/NKL/Trace.lean index 8f23ca4..481a552 100644 --- a/NKL/Trace.lean +++ b/NKL/Trace.lean @@ -3,7 +3,18 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. 
Authors: Paul Govereau -/ +import NKL.KLR +import NKL.Python import NKL.Trace.Types import NKL.Trace.Basic import NKL.Trace.Builtin ---import NKL.Trace.Python +import NKL.Trace.Python +import NKL.Trace.NKI + +namespace NKL.Trace + +def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) := + tracer ⟨ .ofList NKIEnv, #[] ⟩ do + traceKernel k + let g <- get + return g.body.toList diff --git a/NKL/Trace/Basic.lean b/NKL/Trace/Basic.lean index a1369b0..59d1f7d 100644 --- a/NKL/Trace/Basic.lean +++ b/NKL/Trace/Basic.lean @@ -118,11 +118,8 @@ def unOp : String -> Term -> TraceM Term | op, _ => throw s!"unimp {op}" -- Comparison operators --- TODO: need to think about comparison of tensors, in NKI this is object equality, --- but I suspect this doesn't make sense and may be a source of bugs. def cmpOp : String -> Term -> Term -> TraceM Bool - | "Eq", .expr l _, .expr r _ => return l == r | s, l, r => throw s!"unsupported comparison operator {s} {repr l} {repr r}" def compare : Term -> List String -> List Term -> TraceM Term diff --git a/NKL/Trace/NKI.lean b/NKL/Trace/NKI.lean new file mode 100644 index 0000000..328b9d0 --- /dev/null +++ b/NKL/Trace/NKI.lean @@ -0,0 +1,49 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +import NKL.KLR +import NKL.Trace.Types +import NKL.Trace.Builtin + +/- +# NKI built-ins + +This module defines the builtin constants used by tracing for NKI kernels. +-/ +namespace NKL.Trace +open NKL.KLR + +private def module (s : String) : Name × Item := + let name := s.toName + (name, .module name) + +private def const_var (s : String) : Name × Item := + let name := s.toName + (name, .term (.expr (.var s) (.any name))) + +/- +Note: this object contains a bunch of architecture parameters that +need to be set according to which HW we are compiling for. +TODO: figure out the mechanism for this. 
+-/ +def tile_size : Global := + let name := "nki.language.tile_size".toName + { name := name + , attr := attrs + , call := uncallable name + } +where + attrs : GlobalAttr + | "pmax" => return .expr (.const $ .int 128) .int + | a => throw s!"unsupported attribute {a}" + +def NKIEnv : List (Name × Item) := + [ module "nki" + , module "nki.language" + , const_var "nki.language.add" + , const_var "nki.language.load" + , const_var "nki.language.store" + , ("nki.language.tile_size".toName, .global tile_size) + ] diff --git a/NKL/Trace/Python.lean b/NKL/Trace/Python.lean new file mode 100644 index 0000000..87162ad --- /dev/null +++ b/NKL/Trace/Python.lean @@ -0,0 +1,264 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +import Lean +import NKL.KLR +import NKL.Python +import NKL.Trace.Types +import NKL.Trace.Basic + +namespace NKL.Trace +open NKL.Python + +def const : Const -> ErrorM Term + | .none => return .expr (.const $ .none) .none + | .bool b => return .expr (.const $ .bool b) .bool + | .int i => return .expr (.const $ .int i) .int + | .float f => return .expr (.const $ .float f) .float + | .string s => return .expr (.const $ .string s) .string + | .ellipsis => throw "unsupported use of ellipsis" + +/- +Evaluating index expressions. + +An index expression occurs only within a subscript expression. For example, in +the expression: + + t[1,1:10,None,x+9] + +all of 1, 1:10, None, and x+9 are indexes. Note None may also be written as +np.newaxis. Also, a None or a slice (or ellipsis) may only occur at the +outer-most level of an index: if you write, e.g. + + t[x+None] + +then the None is interpreted as an integer and not as a new axis. If you write, + + t[(1:2) + 3] + t[... * 8] + +these are syntax errors in python. +-/ + +mutual +-- top-level index expressions: None (a.k.a. np.newaxis) or IndexExpr +def indexExpr? 
: Option Expr -> Tracer (Option KLR.IndexExpr) + | none => return none + | some (.exprPos (.const .none) _) => return none + | some e => indexExpr e + +-- general sub-expressions +def indexExpr : Expr -> Tracer KLR.IndexExpr + | .exprPos e' p => withPos p (indexExpr' e') + +def indexExpr' : Expr' -> Tracer KLR.IndexExpr + | .const (.int i) => return .int i + | .name id _ => return .var id + | .binOp op l r => return <- indexBinOp op (<- indexExpr l) (<- indexExpr r) + | .unaryOp op e => return <- indexUnOp op (<- indexExpr e) + | _ => throw "invalid index expression" + +-- top-level index: slice, ellipsis, or indexExpr? +-- TODO: get rid of ... +def index : Expr -> Tracer KLR.Index + | .exprPos (.const .ellipsis) _ => return .ellipsis + | .exprPos (.slice l u s) p => withPos p do + return (.slice (<- indexExpr? l) (<- indexExpr? u) (<- indexExpr? s)) + | e => return (.coord (<- indexExpr? e)) +end + +-- Note, a list index can be negative, which means index from end of list. +def list_access (l : List Term) : List KLR.Index -> TraceM Term + | [.coord (some (.int i))] => do + let i := if i < 0 then l.length + i else i + if i < 0 then throw "index out of bounds" + let n := i.toNat + if h:l.length > n then return l.get (Fin.mk n h) + else throw "index out of bounds" + |_ => throw "unsupported __subscript__" + +def access : Term -> List KLR.Index -> TraceM Term + | .object _, _ => throw "builtin object __subscript__ not supported" + | .tuple l, ix + | .list l, ix => list_access l ix + | .expr e _, ix => return .expr (.access e ix) (.any "?".toName) + + +mutual +partial def expr : Expr -> Tracer Item + | .exprPos e' p => withPos p (expr' e') + +partial def term (e : Expr) : Tracer Term := do + match (<- expr e) with + | .module n => return .expr (.var n.toString) (.any "?".toName) + | .global g => return .expr (.var g.name.toString) (.any "?".toName) + | .source _ => throw "invalid use of source function" + | .term t => return t + +partial def term' (e : Expr') : Tracer 
Term := do + term (.exprPos e (<- getPos)) + +partial def klr (e : Expr) : Tracer KLR.Expr := do + match (<- term e) with + | .object obj => return .var obj.name.toString + | .tuple _ => throw "tuple cannot be converted to a KLR term" + | .list _ => throw "list cannot be converted to a KLR term" + | .expr e _ => return e + +partial def integer (e : Expr) : Tracer Int := do + match (<- term e) with + | .expr (.const c) _ => return (<- c.toInt) + | _ => throw "invalid tensor dimension" + +partial def expr' : Expr' -> Tracer Item + | .const c => return .term (<- const c) + | .tensor s dty => do + let shape <- s.mapM integer + return .term (.expr (.tensor ⟨ dty, shape ⟩) (.tensor dty shape)) + | .name id _ => lookup_item id.toName + | .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id) + | .tuple l _ => return .term (.tuple (<- l.mapM term)) + | .list l _ => return .term (.list (<- l.mapM term)) + | .subscript t [ .exprPos (.tuple ix _) _ ] _ + | .subscript t ix _ => return .term (<- access (<- term t) (<- ix.mapM index)) + | .slice _ _ _ => throw "syntax error" + | .boolOp op xs => return .term (<- boolOp op (<- xs.mapM term)) + | .binOp op l r => return .term (<- binOp op (<- term l) (<- term r)) + | .unaryOp op e => return .term (<- unOp op (<- term e)) + | .compare l ops cs => return .term (<- compare (<- term l) ops (<- cs.mapM term)) + | .ifExp tst tru fls => do + let tst <- (<- term tst).isTrue + let tru <- expr tru -- eagerly evaluate both branches + let fls <- expr fls -- to report errors to user + return if tst then tru else fls + | .call f args kws => do + match <- expr f with + | .module n => throw s!"module {n} not callable" + | .global g => return .term (<- g.call (<- args.mapM term) (<- kws.mapM (keyword term))) + | .term t => return .term (<- t.call (<- args.mapM klr) (<- kws.mapM (keyword klr))) + | .source f => do + function_call f (<- args.mapM term) (<- kws.mapM (keyword term)) + return .term (.expr (.const .none) .none) + +partial def 
keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a) + | .keyword id e p => withPos p do return (id, (<- f e)) + +-- When looking for a variable we rely on the store attribute +-- from the Python parser to check if it is a defining use. +partial def var : Expr -> Tracer String + | .exprPos (.name id .store) _ => return id + | _ => throw "expecting variable" + +-- When we perform an assignment, we will either add to the environment +-- the term found on the RHS, or the variable itself. The latter case +-- allows us to lookup and find the variable without substituting +-- its definition. +partial def assign (xs : List Expr) (e : Expr) : Tracer Unit := do + let xs <- xs.mapM var + let e <- term e + match e with + | .expr (.const _) _ => xs.forM fun x => extend x.toName e + | .expr e ty => xs.forM fun x => do + extend x.toName (.expr (.var xs[0]!) ty) + add_stmt (KLR.Stmt.assign x e) + | t => xs.forM fun x => extend x.toName t + +partial def stmt : Stmt -> Tracer Unit + | .stmtPos s' p => withPos p (stmt' s') + +partial def stmt' : Stmt' -> Tracer Unit + | .expr (.exprPos (.const _) _) => return () + | .expr e => do + match <- term e with + | .expr e _ => add_stmt (.expr e) + | _ => return () -- effects are done, can be removed from KLR + | .assert e => do + let t <- term e + if (<- t.isFalse) then throw "assertion failed" + | .assign xs e => assign xs e + | .augAssign x op e => do + stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos))) + | .annAssign _ _ .none => return () + | .annAssign x _ (.some e) => stmt' (.assign [x] e) + | _s => throw "not yet implemented" --s!"unimp {repr s}" + +-- Bind positional and keyword arguments to a Python function based on its +-- signature. 
+ +partial def bind_args (f : Fun) + (args : List Term) + (kwargs : List (String × Term)) + : Tracer (List (String × Term)) := do + if f.args.vararg != none || f.args.kwarg != none then + throw "var args not supported" + if args.length < f.args.posonlyargs.length then + throw "not enough arguments" + let dflts := f.args.all_defaults + let names := f.args.names + if args.length + kwargs.length > names.length then + throw "too many arguments supplied (varargs not supported)" + let argmap <- f.args.names.enum.mapM fun (i,x) => do + if h:args.length > i then + return (x, args.get (Fin.mk i h)) + else if let some v := kwargs.lookup x then + return (x, v) + else if let some e := dflts.lookup x then + return (x, <- term' e) + else + throw s!"argument {x} not supplied" + return argmap + +-- For a function call, first evaluate the argument in the current environment. +-- Then enter a new environment and evaluate the function statements. +partial def function_call (f : Fun) + (args : List Term) + (kwargs : List (String × Term)) + : Tracer Unit := do + let args <- bind_args f args kwargs + --let args <- args.mapM fun (x,e) => return (x, e) + withSrc f.source $ enterFun $ do + args.forM fun (x,e) => do extend x.toName e + f.body.forM stmt + +end + +/- +Evaluate each global in the current environment, skipping any globals that are +already defined. We do not redefine globals because, we may have picked up +functions with dummy implementations, e.g., nki.language.add is defined as: + + def add(x,y): pass + +in some versions of the code. We do not want this to shadow a built-in +definition of add. If we have an internal definition, we will use this over +anything found during parsing. 
+-/ + +private def globals (k : Kernel) : Tracer Unit := do + let s <- get + for (n, f) in k.funcs do + let n := n.toName + if not (s.env.contains n) then + extend_global n (.source f) + for (n,e) in k.globals do + let n := n.toName + if not (s.env.contains n) then + extend_global n (<- expr' e) + +-- Call the top-level kernel function +def traceKernel (k : Kernel) : Tracer Unit := do + globals k + match k.funcs.lookup k.entry with + | none => throw s!"function {k.entry} not found" + | some f => do + let args <- k.args.mapM term' + let kwargs <- k.kwargs.mapM fun (x,e) => return (x, <- term' e) + function_call f args kwargs + +def runKernel (k : Kernel) : Except String (List KLR.Stmt) := + tracer ⟨ ∅, #[] ⟩ do + traceKernel k + let g <- get + return g.body.toList diff --git a/interop/test/test_basic.py b/interop/test/test_basic.py new file mode 100644 index 0000000..71c64be --- /dev/null +++ b/interop/test/test_basic.py @@ -0,0 +1,93 @@ +import numpy as np +import nki +import pytest + +from nkl.parser import Parser + +# Success cases +# (these functions should load and trace to KLR) + +def const_stmt(t): + "this will be ignored because it has no effect" + 1 # so will this, it is a simple constant + 1.0 # so will this + False # and this + None # and this + (1,2) # and this + [1,2] # and this + +string = "a string" +integer = -3 +floating = 1.23 +boolean = True +nothing = None +triple = (1, floating, False) +list3 = [string, triple, nki] + +def expr_name(t): + # these names will end up in the global environment after parsing + # they will be eliminated after substitution during tracing + string, integer, floating, boolean, nothing + # constant tuples are also OK + triple + # as are constant lists + list3 + # as are module references + nki + +def expr_tuple(t): + assert (1,False,"hello") + +def expr_list(t): + assert [1,2,nki] + assert not [] + +def expr_subscript(t): + t[1] + t[1,2,3] + t[1:2:3] + t[1:2] + t[1:] + t[1::] + t[1::2] + t[1:2:None] + t[1:None:2] + t[:] + 
t[:,:] + t[...] + t[1,...] + t[:,None] + t[1] + +def expr_bool_op(t): + True and 1 and [1] and [] and True # evals to [] + False or None or [] or 1 # evals to 1 + 1 or None # evals to 1 + (False,) or 1 # evals to (False,) + +@pytest.mark.parametrize("f", [ + const_stmt, + expr_name, + expr_tuple, + expr_list, + expr_subscript, + expr_bool_op, + ]) +def test_succeed(f): + t = np.ndarray(10) + F = Parser(f) + F(t) + +# Failing cases +# (These functions are expected to fail elaboration to KLR) + +def name_not_found(): + return x + +@pytest.mark.parametrize("f", [ + name_not_found, + ]) +def test_fails(f): + F = Parser(f) + with pytest.raises(Exception): + F()