Skip to content

Commit

Permalink
some op
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 authored and junrushao committed Jul 4, 2022
1 parent f8146b8 commit 8c9c4de
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 6 deletions.
24 changes: 23 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());

/*!
* \brief Check if x is nullptr.
* \param x The input data
* \param span The location of this operation in the source.
* \return The result expression.
*/
TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
/*!
* \brief Check if x is infinite.
* \param x The input data
Expand Down Expand Up @@ -879,6 +885,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
Span span = Span());

/*!
* \brief Returns the address of an element in the buffer
* \param buffer_load The input BufferLoad.
* \param span The location of this operation in the source.
* \return The address of an element in the buffer.
*/
TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());

/*!
* \brief Returns the param by name
* \param param_name The param name.
* \param span The location of this operation in the source.
* \return The handle of param.
*/
TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
preflattened_buffer,
prim_func,
)
from .var import Buffer, buffer_decl
from .var import Buffer, buffer_decl, var
from .stmt import (
Assert,
let,
Expand Down
26 changes: 24 additions & 2 deletions python/tvm/script/builder/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@ def wrapped(*args, **kwargs):
return wrapped


def dtype_forward(func):
def forwarded(*args, **kwargs):
if "dtype" in kwargs:
args = (kwargs.get("dtype"),) + args
kwargs.pop("dtype")
return func(*args, **kwargs)

return forwarded


abs = op_wrapper(op.abs)
acos = op_wrapper(op.acos)
acosh = op_wrapper(op.acosh)
address_of = op_wrapper(op.address_of)
asin = op_wrapper(op.asin)
asinh = op_wrapper(op.asinh)
atan = op_wrapper(op.atan)
atan2 = op_wrapper(op.atan2)
atanh = op_wrapper(op.atanh)
call_extern = op_wrapper(op.call_extern)
call_packed = op_wrapper(op.call_packed)
ceil = op_wrapper(op.ceil)
clz = op_wrapper(op.clz)
comm_reducer = op_wrapper(op.comm_reducer)
Expand All @@ -60,17 +69,22 @@ def wrapped(*args, **kwargs):
isfinite = op_wrapper(op.isfinite)
isinf = op_wrapper(op.isinf)
isnan = op_wrapper(op.isnan)
isnullptr = op_wrapper(op.isnullptr)
ldexp = op_wrapper(op.ldexp)
likely = op_wrapper(op.likely)
log = op_wrapper(op.log)
log1p = op_wrapper(op.log1p)
log2 = op_wrapper(op.log2)
log10 = op_wrapper(op.log10)
lookup_param = op_wrapper(op.lookup_param)
max_value = op_wrapper(op.max_value)
min_value = op_wrapper(op.min_value)
nearbyint = op_wrapper(op.nearbyint)
nextafter = op_wrapper(op.nextafter)
popcount = op_wrapper(op.popcount)
power = op_wrapper(op.power)
q_multiply_shift = op_wrapper(op.q_multiply_shift)
ret = op_wrapper(op.ret)
reinterpret = op_wrapper(op.reinterpret)
round = op_wrapper(op.round)
rsqrt = op_wrapper(op.rsqrt)
Expand All @@ -84,6 +98,14 @@ def wrapped(*args, **kwargs):
truncdiv = op_wrapper(op.truncdiv)
truncmod = op_wrapper(op.truncmod)

call_cpacked = dtype_forward(op.call_cpacked)
call_extern = dtype_forward(op.call_extern)
call_intrin = dtype_forward(op.call_intrin)
call_llvm_intrin = dtype_forward(op.call_llvm_intrin)
call_llvm_pure_intrin = dtype_forward(op.call_llvm_pure_intrin)
call_packed = dtype_forward(op.call_packed)
call_pure_extern = dtype_forward(op.call_pure_extern)

from . import _ffi_api


Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/builder/tir/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,8 @@ def buffer_decl(
)


def var(dtype) -> Var:
return Var("", dtype) # pylint: disable=no-member # type: ignore


Buffer = BufferProxy()
5 changes: 3 additions & 2 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@
from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign
from .op import isnan, isfinite, isinf, copysign, isnullptr
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift
from .op import q_multiply_shift, address_of, lookup_param
from .op import likely

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

Expand Down
79 changes: 79 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,47 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)


def address_of(buffer_load, span=None):
"""Returns the address of an element in the buffer
Parameters
----------
dtype : str
The data type of the result.
buffer_load: BufferLoad
The buffer load.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.address_of", buffer_load)


def lookup_param(param_name, span=None):
"""Returns the param by name
Parameters
----------
param_name : str
The name of param.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.lookup_param", param_name)


def ret(val):
"""Create a tir return expression
Expand Down Expand Up @@ -1049,6 +1090,25 @@ def ldexp(x1, x2):
return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore


def likely(cond, span=None):
"""Mark condition as likely.
Parameters
----------
cond : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The marked expression.
"""
return _ffi_api.likely(cond, span) # type: ignore


def isnan(x, span=None):
"""Check if input value is Nan.
Expand All @@ -1068,6 +1128,25 @@ def isnan(x, span=None):
return _ffi_api.isnan(x, span) # type: ignore


def isnullptr(x, span=None):
"""Check if input value is nullptr.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return call_intrin("bool", "tir.isnullptr", x) # type: ignore


def isfinite(x, span=None):
"""Check if input value is finite.
Expand Down
18 changes: 18 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
{x, y, q, s}, span);
}

// address_of
PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
return tir::Call(DataType::Handle(), tir::builtin::address_of(), {buffer_load}, span);
}

// lookup_param
PrimExpr lookup_param(String param_name, Span span) {
return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {tir::StringImm(param_name)},
span);
}

// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
if (lhs.dtype() == rhs.dtype()) return;
Expand Down Expand Up @@ -700,6 +711,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
}
}

// isnullptr
PrimExpr isnullptr(PrimExpr x, Span span) {
return tir::Call(DataType::Bool(1), tir::builtin::isnullptr(), {x}, span);
}

// isinf
PrimExpr isinf(PrimExpr x, Span span) {
DataType t = DataType::Bool(x.dtype().lanes());
Expand Down Expand Up @@ -929,6 +945,8 @@ TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);

TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);

TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);

TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);

TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
Expand Down

0 comments on commit 8c9c4de

Please sign in to comment.