Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

some op #65

Merged
merged 3 commits into from
Jul 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());

/*!
* \brief Check if x is nullptr.
* \param x The input data
* \param span The location of this operation in the source.
* \return The result expression.
*/
TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
/*!
* \brief Check if x is infinite.
* \param x The input data
Expand Down Expand Up @@ -879,6 +885,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
Span span = Span());

/*!
* \brief Returns the address of an element in the buffer
* \param buffer_load The input BufferLoad.
* \param span The location of this operation in the source.
* \return The address of an element in the buffer.
*/
TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());

/*!
* \brief Returns the param by name
* \param param_name The param name.
* \param span The location of this operation in the source.
* \return The handle of param.
*/
TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
preflattened_buffer,
prim_func,
)
from .var import Buffer, buffer_decl
from .var import Buffer, buffer_decl, var
from .stmt import (
Assert,
let,
Expand Down
25 changes: 23 additions & 2 deletions python/tvm/script/builder/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,24 @@ def wrapped(*args, **kwargs):
return wrapped


def dtype_forward(func):
def forwarded(*args, **kwargs):
if "dtype" in kwargs:
args = (kwargs.pop("dtype"),) + args
return func(*args, **kwargs)

return forwarded


abs = op_wrapper(op.abs)
acos = op_wrapper(op.acos)
acosh = op_wrapper(op.acosh)
address_of = op_wrapper(op.address_of)
asin = op_wrapper(op.asin)
asinh = op_wrapper(op.asinh)
atan = op_wrapper(op.atan)
atan2 = op_wrapper(op.atan2)
atanh = op_wrapper(op.atanh)
call_extern = op_wrapper(op.call_extern)
call_packed = op_wrapper(op.call_packed)
ceil = op_wrapper(op.ceil)
clz = op_wrapper(op.clz)
comm_reducer = op_wrapper(op.comm_reducer)
Expand All @@ -60,17 +68,22 @@ def wrapped(*args, **kwargs):
isfinite = op_wrapper(op.isfinite)
isinf = op_wrapper(op.isinf)
isnan = op_wrapper(op.isnan)
isnullptr = op_wrapper(op.isnullptr)
ldexp = op_wrapper(op.ldexp)
likely = op_wrapper(op.likely)
log = op_wrapper(op.log)
log1p = op_wrapper(op.log1p)
log2 = op_wrapper(op.log2)
log10 = op_wrapper(op.log10)
lookup_param = op_wrapper(op.lookup_param)
max_value = op_wrapper(op.max_value)
min_value = op_wrapper(op.min_value)
nearbyint = op_wrapper(op.nearbyint)
nextafter = op_wrapper(op.nextafter)
popcount = op_wrapper(op.popcount)
power = op_wrapper(op.power)
q_multiply_shift = op_wrapper(op.q_multiply_shift)
ret = op_wrapper(op.ret)
reinterpret = op_wrapper(op.reinterpret)
round = op_wrapper(op.round)
rsqrt = op_wrapper(op.rsqrt)
Expand All @@ -84,6 +97,14 @@ def wrapped(*args, **kwargs):
truncdiv = op_wrapper(op.truncdiv)
truncmod = op_wrapper(op.truncmod)

call_cpacked = dtype_forward(op.call_cpacked)
call_extern = dtype_forward(op.call_extern)
call_intrin = dtype_forward(op.call_intrin)
call_llvm_intrin = dtype_forward(op.call_llvm_intrin)
call_llvm_pure_intrin = dtype_forward(op.call_llvm_pure_intrin)
call_packed = dtype_forward(op.call_packed)
call_pure_extern = dtype_forward(op.call_pure_extern)

from . import _ffi_api


Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/builder/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def realize(


def attr(node: Object, attr_key: str, value: Union[PrimExpr, str]) -> AttrFrame:
if isinstance(node, str):
node = StringImm(node)
if isinstance(value, str):
value = StringImm(value)
return _ffi_api.AttrFrame(node, attr_key, value) # pylint: disable=no-member # type: ignore
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/builder/tir/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,8 @@ def buffer_decl(
)


def var(dtype) -> Var:
return Var("", dtype) # pylint: disable=no-member # type: ignore


Buffer = BufferProxy()
5 changes: 4 additions & 1 deletion python/tvm/script/parse/tir/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,7 @@ def visit_tvm_annotation(self: Parser, node: doc.expr):

@dispatch.register(token="tir", type_name="Expr")
def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
self.eval_expr(node.value)
res = self.eval_expr(node.value)
if isinstance(res, Frame):
res.add_callback(partial(res.__exit__, None, None, None))
res.__enter__()
5 changes: 3 additions & 2 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@
from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign
from .op import isnan, isfinite, isinf, copysign, isnullptr
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift
from .op import q_multiply_shift, address_of, lookup_param
from .op import likely

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

Expand Down
79 changes: 79 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,47 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)


def address_of(buffer_load, span=None):
"""Returns the address of an element in the buffer

Parameters
----------
dtype : str
The data type of the result.

buffer_load: BufferLoad
The buffer load.

span : Optional[Span]
The location of this operator in the source code.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.address_of", buffer_load)


def lookup_param(param_name, span=None):
"""Returns the param by name

Parameters
----------
param_name : str
The name of param.

span : Optional[Span]
The location of this operator in the source code.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.lookup_param", param_name)


def ret(val):
"""Create a tir return expression

Expand Down Expand Up @@ -1049,6 +1090,25 @@ def ldexp(x1, x2):
return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore


def likely(cond, span=None):
"""Mark condition as likely.

Parameters
----------
cond : PrimExpr
Input argument.

span : Optional[Span]
The location of this operator in the source code.

Returns
-------
y : PrimExpr
The marked expression.
"""
return _ffi_api.likely(cond, span) # type: ignore


def isnan(x, span=None):
"""Check if input value is Nan.

Expand All @@ -1068,6 +1128,25 @@ def isnan(x, span=None):
return _ffi_api.isnan(x, span) # type: ignore


def isnullptr(x, span=None):
"""Check if input value is nullptr.

Parameters
----------
x : PrimExpr
Input argument.

span : Optional[Span]
The location of this operator in the source code.

Returns
-------
y : PrimExpr
The result.
"""
return call_intrin("bool", "tir.isnullptr", x) # type: ignore


def isfinite(x, span=None):
"""Check if input value is finite.

Expand Down
18 changes: 18 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
{x, y, q, s}, span);
}

// address_of
PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
return tir::Call(DataType::Handle(), tir::builtin::address_of(), {buffer_load}, span);
}

// lookup_param
PrimExpr lookup_param(String param_name, Span span) {
return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {tir::StringImm(param_name)},
span);
}

// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
if (lhs.dtype() == rhs.dtype()) return;
Expand Down Expand Up @@ -700,6 +711,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
}
}

// isnullptr
PrimExpr isnullptr(PrimExpr x, Span span) {
return tir::Call(DataType::Bool(1), tir::builtin::isnullptr(), {x}, span);
}

// isinf
PrimExpr isinf(PrimExpr x, Span span) {
DataType t = DataType::Bool(x.dtype().lanes());
Expand Down Expand Up @@ -929,6 +945,8 @@ TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);

TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);

TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);

TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);

TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
Expand Down