Skip to content

Commit

Permalink
some op (#65)
Browse files Browse the repository at this point in the history
* some `op`

* apply code review suggestion

* concise scoping in `ast.Expr`
  • Loading branch information
cyx-6 authored and junrushao committed Jul 13, 2022
1 parent 6937916 commit 3191b77
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 7 deletions.
24 changes: 23 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());

/*!
* \brief Check if x is nullptr.
* \param x The input data
* \param span The location of this operation in the source.
* \return The result expression.
*/
TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
/*!
* \brief Check if x is infinite.
* \param x The input data
Expand Down Expand Up @@ -879,6 +885,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
Span span = Span());

/*!
* \brief Returns the address of an element in the buffer
* \param buffer_load The input BufferLoad.
* \param span The location of this operation in the source.
* \return The address of an element in the buffer.
*/
TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());

/*!
* \brief Returns the param by name
* \param param_name The param name.
* \param span The location of this operation in the source.
* \return The handle of param.
*/
TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
preflattened_buffer,
prim_func,
)
from .var import Buffer, buffer_decl
from .var import Buffer, buffer_decl, var
from .stmt import (
Assert,
let,
Expand Down
25 changes: 23 additions & 2 deletions python/tvm/script/builder/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,24 @@ def wrapped(*args, **kwargs):
return wrapped


def dtype_forward(func):
def forwarded(*args, **kwargs):
if "dtype" in kwargs:
args = (kwargs.pop("dtype"),) + args
return func(*args, **kwargs)

return forwarded


abs = op_wrapper(op.abs)
acos = op_wrapper(op.acos)
acosh = op_wrapper(op.acosh)
address_of = op_wrapper(op.address_of)
asin = op_wrapper(op.asin)
asinh = op_wrapper(op.asinh)
atan = op_wrapper(op.atan)
atan2 = op_wrapper(op.atan2)
atanh = op_wrapper(op.atanh)
call_extern = op_wrapper(op.call_extern)
call_packed = op_wrapper(op.call_packed)
ceil = op_wrapper(op.ceil)
clz = op_wrapper(op.clz)
comm_reducer = op_wrapper(op.comm_reducer)
Expand All @@ -60,17 +68,22 @@ def wrapped(*args, **kwargs):
isfinite = op_wrapper(op.isfinite)
isinf = op_wrapper(op.isinf)
isnan = op_wrapper(op.isnan)
isnullptr = op_wrapper(op.isnullptr)
ldexp = op_wrapper(op.ldexp)
likely = op_wrapper(op.likely)
log = op_wrapper(op.log)
log1p = op_wrapper(op.log1p)
log2 = op_wrapper(op.log2)
log10 = op_wrapper(op.log10)
lookup_param = op_wrapper(op.lookup_param)
max_value = op_wrapper(op.max_value)
min_value = op_wrapper(op.min_value)
nearbyint = op_wrapper(op.nearbyint)
nextafter = op_wrapper(op.nextafter)
popcount = op_wrapper(op.popcount)
power = op_wrapper(op.power)
q_multiply_shift = op_wrapper(op.q_multiply_shift)
ret = op_wrapper(op.ret)
reinterpret = op_wrapper(op.reinterpret)
round = op_wrapper(op.round)
rsqrt = op_wrapper(op.rsqrt)
Expand All @@ -84,6 +97,14 @@ def wrapped(*args, **kwargs):
truncdiv = op_wrapper(op.truncdiv)
truncmod = op_wrapper(op.truncmod)

call_cpacked = dtype_forward(op.call_cpacked)
call_extern = dtype_forward(op.call_extern)
call_intrin = dtype_forward(op.call_intrin)
call_llvm_intrin = dtype_forward(op.call_llvm_intrin)
call_llvm_pure_intrin = dtype_forward(op.call_llvm_pure_intrin)
call_packed = dtype_forward(op.call_packed)
call_pure_extern = dtype_forward(op.call_pure_extern)

from . import _ffi_api


Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/builder/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def realize(


def attr(node: Object, attr_key: str, value: Union[PrimExpr, str]) -> AttrFrame:
if isinstance(node, str):
node = StringImm(node)
if isinstance(value, str):
value = StringImm(value)
return _ffi_api.AttrFrame(node, attr_key, value) # pylint: disable=no-member # type: ignore
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/builder/tir/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,8 @@ def buffer_decl(
)


def var(dtype) -> Var:
return Var("", dtype) # pylint: disable=no-member # type: ignore


Buffer = BufferProxy()
5 changes: 4 additions & 1 deletion python/tvm/script/parse/tir/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,7 @@ def visit_tvm_annotation(self: Parser, node: doc.expr):

@dispatch.register(token="tir", type_name="Expr")
def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
self.eval_expr(node.value)
res = self.eval_expr(node.value)
if isinstance(res, Frame):
res.add_callback(partial(res.__exit__, None, None, None))
res.__enter__()
5 changes: 3 additions & 2 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@
from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign
from .op import isnan, isfinite, isinf, copysign, isnullptr
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift
from .op import q_multiply_shift, address_of, lookup_param
from .op import likely

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

Expand Down
79 changes: 79 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,47 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)


def address_of(buffer_load, span=None):
"""Returns the address of an element in the buffer
Parameters
----------
dtype : str
The data type of the result.
buffer_load: BufferLoad
The buffer load.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.address_of", buffer_load)


def lookup_param(param_name, span=None):
"""Returns the param by name
Parameters
----------
param_name : str
The name of param.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.lookup_param", param_name)


def ret(val):
"""Create a tir return expression
Expand Down Expand Up @@ -1049,6 +1090,25 @@ def ldexp(x1, x2):
return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore


def likely(cond, span=None):
"""Mark condition as likely.
Parameters
----------
cond : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The marked expression.
"""
return _ffi_api.likely(cond, span) # type: ignore


def isnan(x, span=None):
"""Check if input value is Nan.
Expand All @@ -1068,6 +1128,25 @@ def isnan(x, span=None):
return _ffi_api.isnan(x, span) # type: ignore


def isnullptr(x, span=None):
"""Check if input value is nullptr.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return call_intrin("bool", "tir.isnullptr", x) # type: ignore


def isfinite(x, span=None):
"""Check if input value is finite.
Expand Down
18 changes: 18 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
{x, y, q, s}, span);
}

// address_of
PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
return tir::Call(DataType::Handle(), tir::builtin::address_of(), {buffer_load}, span);
}

// lookup_param
PrimExpr lookup_param(String param_name, Span span) {
return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {tir::StringImm(param_name)},
span);
}

// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
Expand Down Expand Up @@ -702,6 +713,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
}
}

// isnullptr
PrimExpr isnullptr(PrimExpr x, Span span) {
return tir::Call(DataType::Bool(1), tir::builtin::isnullptr(), {x}, span);
}

// isinf
PrimExpr isinf(PrimExpr x, Span span) {
DataType t = DataType::Bool(x.dtype().lanes());
Expand Down Expand Up @@ -931,6 +947,8 @@ TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);

TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);

TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);

TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);

TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
Expand Down

0 comments on commit 3191b77

Please sign in to comment.