Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: basic tracing for Terms #19

Merged
merged 1 commit into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NKL/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.Trace.Types
--import NKL.Trace.Basic
import NKL.Trace.Basic
import NKL.Trace.Builtin
--import NKL.Trace.Python
161 changes: 161 additions & 0 deletions NKL/Trace/Basic.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Trace.Types
import NKL.Trace.Tensor

/-
# Basic tracing facilities

Basic tracing definitions only deal with Terms (not Python sources)
-/

namespace NKL.Trace
open NKL.KLR

-- Operators within index expressions

def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> ErrorM KLR.IndexExpr
| "Add" , l, r => return .add l r
seanmcl marked this conversation as resolved.
Show resolved Hide resolved
| "Sub" , l, r => return .add l r.neg
| "Mult", .int i, e
| "Mult", e, .int i => return .mul i e
| "Div" , e, .int i => return .floor e i
| "Mod" , e, .int i => return .mod e i
| _, _, _ => throw "invalid index expression"

def indexUnOp : String -> KLR.IndexExpr -> ErrorM KLR.IndexExpr
| "USub", e => return .neg e
| _, _ => throw "invalid index expresssion"

-- Truthiness of Terms following Python

def Term.isTrue : Term -> TraceM Bool
| .object _ => return true
| .tuple [] => return false
| .tuple _ => return true
| .list [] => return false
| .list _ => return true
| .expr (.const c) _ => return c.isTrue
| .expr _ _ => throw "non-constant expression"

def Term.isFalse (t : Term) : TraceM Bool :=
return not (<- t.isTrue)

-- Following Python semantics, boolean operators return
-- the first value that is convertible to True or False

def boolOp (op : String) (es : List Term) : TraceM Term := do
bop (<- bopFn op) es
where
bop fn : List Term -> TraceM Term
| [] => throw "invalid expression"
| [x] => return x
| x :: xs => do if (<- fn x) then return x else bop fn xs
bopFn : String -> TraceM (Term -> TraceM Bool)
| "Or" => return Term.isTrue
| "And" => return Term.isFalse
| s => throw s!"unsupported boolean operator {s}"

-- Binary Operators

-- Multiply a sequence (tuple, list, string) by a scalar
govereau marked this conversation as resolved.
Show resolved Hide resolved
-- It is tempting to use Const.toInt here, but that would be
-- more permissive than Python. The only allowed cases are:
-- [1,2] * 2 => [1,2,1,2]
-- [1,2] * 0 => []
-- [1,2] * -10 => []
-- [1,2] * True => [1,2]
-- [1,2] * False => []

private def mulseq (l : List a) : Const -> TraceM (List a)
| .bool false => return []
| .bool true => return l
| .int i => return List.flatten $ List.replicate i.toNat l
| _ => throw "invalid multiply"

-- Binary operators on constants
private def constOp : String -> Const -> Const -> TraceM Term
| "Add", .int l, .int r => return int (l + r)
| "Sub", .int l, .int r => return int (l - r)
| "Mult", .int l, .int r => return int (l * r)
| "Div", .int l, .int r => return int (l / r)
| _,_,_ => throw "unimp"
where
int (i : Int) : Term := .expr (.const (.int i)) .int

-- Binary operators on tensors (see Tensor.lean)
private def exprOp : String -> Expr -> Expr -> TraceM Term
-- tensors
| op, .tensor l, .tensor r => tensor_tensor op l r
| op, .tensor t, .const c => tensor_scalar op t c
| op, .const c, .tensor t => scalar_tensor op c t
| _ , .tensor _, _
| _ , _ , .tensor _ => throw "invalid tensor op"
-- constants
| op, .const l, .const r => constOp op l r
| _ , _ , _ => throw "non-constant expression"

-- Binary operators on terms
def binOp : String -> Term -> Term -> TraceM Term
-- lists and tuples
| "Add", .list l, .list r => return .list (l ++ r)
| "Add", .tuple l, .tuple r => return .tuple (l ++ r)
| "Mult", .list l, .expr (.const c) _
| "Mult", .expr (.const c) _, .list l => return .list (<- mulseq l c)
| "Mult", .tuple l, .expr (.const c) _
| "Mult", .expr (.const c) _, .tuple l => return .tuple (<- mulseq l c)
| op , .expr l _, .expr r _ => exprOp op l r
| _, _, _ => throw "unsupported operator"

-- Unary operators
def unOp : String -> Term -> TraceM Term
| op , .expr (.tensor t) _ => tensor_op op t
| "Not", t => return .expr (.const $ .bool (<- t.isFalse)) .bool
| op, _ => throw s!"unimp {op}"

-- Comparison operators
-- TODO: need to think about comparison of tensors, in NKI this is object equality,
-- but I suspect this doesn't make sense and may be a source of bugs.

def cmpOp : String -> Term -> Term -> TraceM Bool
| "Eq", .expr l _, .expr r _ => return l == r
seanmcl marked this conversation as resolved.
Show resolved Hide resolved
| s, l, r => throw s!"unsupported comparison operator {s} {repr l} {repr r}"

def compare : Term -> List String -> List Term -> TraceM Term
| x, [op], [y] => return bool (<- cmpOp op x y)
| x, op::ops, y::ys => do
if (<- cmpOp op x y)
then compare y ops ys
else return (bool false)
| _, _, _ => throw "invalid comparison"
where
bool b := .expr (.const $ .bool b) .bool

def Term.attr : Term -> String -> TraceM Term
| .object o, id => o.attr id
| .expr _ (.tensor d _), "dtype" => return (str d)
| .expr _ (.tensor _ s), "shape" => return (list s)
| .expr e _, id => throw s!"unsupported attribute {id} on {repr e}"
| t, id => throw s!"unsupported attribute {id} on {repr t}"
where
str s := .expr (.const $ .string s) .string
list l := .list $ l.map fun i => .expr (.const (.int i)) .int

def Item.attr : Item -> String -> Tracer Item
| .module n, id => lookup_global (n.str id)
| .global g, id => return .term (<- g.attr id)
| .source _, id => throw s!"unsupported attribute {id}"
| .term t, id => return .term (<- t.attr id)

def Term.call (f : Term)
(args : List Expr)
(kws : List (String × Expr)) : TraceM Term := do
match f with
| .object o => o.call args kws
| .tuple _ => throw "tuple is not a callable type"
seanmcl marked this conversation as resolved.
Show resolved Hide resolved
| .list _ => throw "list is not a callable type"
| .expr f _ => return .expr (.call f args kws) (.any "?".toName)
seanmcl marked this conversation as resolved.
Show resolved Hide resolved
43 changes: 43 additions & 0 deletions NKL/Trace/Tensor.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Trace.Types

/-
# Tracing for Tensor related operations

TODO: These are just place holders...
-/
namespace NKL.Trace
open NKL.KLR

private def tensor_call (op : String) (args : List Expr) : Term :=
let type := if let .tensor t :: _ := args
then TermType.tensor t.dtype t.shape
else TermType.any "?".toName
let name := Expr.var ("tensor_".append op)
.expr (.call name args []) type

-- Unary operations on tensors

def tensor_op (op : String) (t : Tensor) : TraceM Term :=
return tensor_call op [.tensor t]

-- Binary operations on tensors / scalars

def tensor_tensor (op : String) (l r : Tensor) : TraceM Term :=
return tensor_call op [.tensor l, .tensor r]

private def broadcast (t : Tensor) (c : Const) : Expr :=
let args := t.shape.map fun i => Expr.const (.int i)
let args := .const c :: args
.call (.var "broadcast") args []

def tensor_scalar (op : String) (t : Tensor) (c : Const) : TraceM Term :=
return tensor_call op [ .tensor t, broadcast t c]

def scalar_tensor (op : String) (c : Const) (t : Tensor) : TraceM Term :=
return tensor_call op [ .tensor t, broadcast t c]
seanmcl marked this conversation as resolved.
Show resolved Hide resolved
Loading