Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tracing for python source functions #20

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Paul Govereau
-/
import Lean
import NKL.Python
import NKL.Trace

namespace NKL

Expand All @@ -16,12 +17,6 @@ local instance : MonadLift (Except String) IO where
@[export parse_json]
def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let names := kernel.funcs.map fun x => x.fst
let names := String.intercalate "," names
IO.println s!"Found functions: {names}"
for x in kernel.args do
IO.println s!"arg: {repr x}"
for x in kernel.kwargs do
IO.println s!"arg: {repr x}"
for x in kernel.globals do
IO.println s!"global: {repr x}"
let stmts <- NKL.Trace.runNKIKernel kernel
for s in stmts do
IO.println s!"{repr s}"
24 changes: 10 additions & 14 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ then the structure will be populated with:
defaults = [1, 2]
vararg = "args"
kwonlyargs = [d, e]
kw_defaults = [None, 3]
kw_defaults = [("e", 3)]
kwarg = "kwargs"
Note, defaults and kw_defaults are inconsistent in how they treat
missing arguments, but this is just how it works in the python AST.
Note, this is slightly different from the official Python AST, which
encodes the kw_defaults as a list with None for missing defaults.
-/
structure Args where
posonlyargs : List String
Expand All @@ -122,17 +122,13 @@ structure Args where
kwarg : Option String
deriving Repr

def Args.names (ax : Args) : List String :=
let xs := ax.posonlyargs.append ax.args
let xs := match ax.vararg with | none => xs | some x => xs.append [x]
let xs := xs.append ax.kwonlyargs
let xs := match ax.kwarg with | none => xs | some x => xs.append [x]
xs

def Args.all_defaults (ax : Args) : List (String × Expr') :=
let args := ax.posonlyargs ++ ax.args
let dflt := args.reverse.zip ax.defaults.reverse
dflt ++ ax.kw_defaults
def Args.names (args : Args) : List String :=
args.posonlyargs ++ args.args ++ args.kwonlyargs

def Args.all_defaults (args : Args) : List (String × Expr') :=
let pargs := args.posonlyargs ++ args.args
let dflt := pargs.reverse.zip args.defaults.reverse
dflt ++ args.kw_defaults

structure Fun where
source : String
Expand Down
13 changes: 12 additions & 1 deletion NKL/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Python
import NKL.Trace.Types
import NKL.Trace.Basic
import NKL.Trace.Builtin
--import NKL.Trace.Python
import NKL.Trace.Python
import NKL.Trace.NKI

namespace NKL.Trace

def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) :=
tracer ⟨ .ofList NKIEnv, #[] ⟩ do
traceKernel k
seanmcl marked this conversation as resolved.
Show resolved Hide resolved
let g <- get
return g.body.toList
3 changes: 0 additions & 3 deletions NKL/Trace/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,8 @@ def unOp : String -> Term -> TraceM Term
| op, _ => throw s!"unimp {op}"

-- Comparison operators
-- TODO: need to think about comparison of tensors, in NKI this is object equality,
-- but I suspect this doesn't make sense and may be a source of bugs.

def cmpOp : String -> Term -> Term -> TraceM Bool
| "Eq", .expr l _, .expr r _ => return l == r
| s, l, r => throw s!"unsupported comparison operator {s} {repr l} {repr r}"

def compare : Term -> List String -> List Term -> TraceM Term
Expand Down
49 changes: 49 additions & 0 deletions NKL/Trace/NKI.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Trace.Types
import NKL.Trace.Builtin

/-
# NKI built-ins

This module defines the builtin constants used by tracing for NKI kernels.
-/
namespace NKL.Trace
open NKL.KLR

private def module (s : String) : Name × Item :=
let name := s.toName
(name, .module name)

private def const_var (s : String) : Name × Item :=
let name := s.toName
(name, .term (.expr (.var s) (.any name)))

/-
Note: this object contains a bunch of architecture parameters that
need to be set according to which HW we are compiling for.
TODO: figure out the mechanism for this.
-/
def tile_size : Global :=
let name := "nki.langauge.tile_size".toName
{ name := name
, attr := attrs
, call := uncallable name
}
where
attrs : GlobalAttr
| "pmax" => return .expr (.const $ .int 128) .int
govereau marked this conversation as resolved.
Show resolved Hide resolved
| a => throw s!"unsupported attribute {a}"

def NKIEnv : List (Name × Item) :=
[ module "nki"
, module "nki.language"
, const_var "nki.language.add"
, const_var "nki.language.load"
, const_var "nki.language.store"
, ("nki.language.tile_size".toName, .global tile_size)
]
Loading
Loading