Skip to content

Commit

Permalink
[Refactor] Format; Simplify PackedFunc Registration; Unify parameter …
Browse files Browse the repository at this point in the history
…order of `alloc_tensor` (#65)

* [Refactor] Simplify the global registration logic

* Address comments; Pass tests; Format stuff

* Address comments
  • Loading branch information
junrushao committed Oct 14, 2022
1 parent 4a8d252 commit ab3dd5a
Show file tree
Hide file tree
Showing 33 changed files with 1,727 additions and 342 deletions.
12 changes: 6 additions & 6 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,12 @@ class RelayExprNode : public BaseExprNode {
mutable Type checked_type_ = Type(nullptr);

/*!
* \brief Stores the result of static shape analysis. It must be a RelayExpr
* and ObjectRef is used here to avoid cyclic typing.
*
* \note The value will be optional if a static shape can not be inferred.
* use .shape() instead to access an always defined shape expression.
*/
* \brief Stores the result of static shape analysis. It must be a RelayExpr
* and ObjectRef is used here to avoid cyclic typing.
*
* \note The value will be optional if a static shape can not be inferred.
* use .shape() instead to access an always defined shape expression.
*/
mutable Optional<ObjectRef> shape_ = Optional<ObjectRef>();

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
#define TVM_IR_TYPE_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/relax/type.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relax/type.h>

#include <string>
#include <utility>
Expand Down
4 changes: 1 addition & 3 deletions include/tvm/relax/attrs/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(AllocTensorAttrs, "relax.attrs.AllocTensorAttrs") {
TVM_ATTR_FIELD(offset)
.describe("Storage offset to allocate the tensor.")
.set_default(0);
TVM_ATTR_FIELD(offset).describe("Storage offset to allocate the tensor.").set_default(0);
TVM_ATTR_FIELD(dtype)
.describe("The dtype of the tensor to allocate.")
.set_default(DataType::Float(32, 1));
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
#define TVM_RELAX_BLOCK_BUILDER_H_

#include <tvm/ir/expr.h>
#include <tvm/relax/utils.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/utils.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
Expand Down
6 changes: 2 additions & 4 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ class MatchShapeNode : public BindingNode {
}

bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
return equal(value, other->value) && equal(pattern, other->pattern)
&& equal(var, other->var);
return equal(value, other->value) && equal(pattern, other->pattern) && equal(var, other->var);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -236,8 +235,7 @@ class MatchShapeNode : public BindingNode {

class MatchShape : public Binding {
public:
TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern,
Var var, Span span = Span());
TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern, Var var, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode);
};
Expand Down
16 changes: 5 additions & 11 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
}
};


/*!
* \brief A simple visitor wrapper around ExprFunctor.
* Recursively visit the content.
Expand Down Expand Up @@ -192,7 +191,7 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
virtual void VisitVarDef_(const DataflowVarNode* var);

virtual void VisitType(const Type& t);
virtual void VisitSpan(const Span& span);
virtual void VisitSpan(const Span& span);
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
Expand All @@ -206,9 +205,7 @@ void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
*/
class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
public:
ExprMutator() {
builder_ = BlockBuilder::Create();
}
ExprMutator() { builder_ = BlockBuilder::Create(); }

Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expand Down Expand Up @@ -293,12 +290,9 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
}

/*!
* \brief Create a new var with specified shape and type if the original var's shape or type does not
* match with the specified ones.
* \param var The var to be updated.
* \param shape The specified shape.
* \param type The specified type.
* \return The var filled with \p shape and \p type.
* \brief Create a new var with specified shape and type if the original var's shape or type does
* not match with the specified ones. \param var The var to be updated. \param shape The specified
* shape. \param type The specified type. \return The var filled with \p shape and \p type.
*/
Var WithShapeAndType(Var var, Optional<ObjectRef> shape, Type type);

Expand Down
17 changes: 4 additions & 13 deletions include/tvm/relax/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,9 @@ namespace relax {

class ShapeTypeNode : public TypeNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}

bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
return true;
}
bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { return true; }

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }

Expand Down Expand Up @@ -111,13 +106,9 @@ class DynTensorType : public Type {

class DimTypeNode : public TypeNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }

bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const {
return true;
}
bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const { return true; }

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
#ifndef TVM_RELAX_UTILS_H_
#define TVM_RELAX_UTILS_H_

#include <string>
#include <algorithm>
#include <string>
#include <unordered_map>

namespace tvm {
Expand Down
17 changes: 5 additions & 12 deletions include/tvm/relax/vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ namespace tvm {
namespace runtime {
namespace relax_vm {


/*!
* \brief The storage type for the bytecode in the VM.
*/
using ExecWord = int64_t;
using ExecWord = int64_t;

/*! \brief A register name. */
using RegName = ExecWord;
Expand All @@ -61,7 +60,6 @@ enum class Opcode {
If = 4U,
};


/*! \brief A single virtual machine instruction.
*
* The representation of the instruction is as
Expand Down Expand Up @@ -99,8 +97,7 @@ struct Instruction {
/*! \brief Construct from the kind and value. */
Arg(ArgKind kind, Index value) {
// TODO(ziheng): check value?
this->data = (static_cast<ExecWord>(kind) << kValueBit) |
(value & kValueMask);
this->data = (static_cast<ExecWord>(kind) << kValueBit) | (value & kValueMask);
}
/*!
* \brief Get the kind of argument.
Expand All @@ -114,16 +111,14 @@ struct Instruction {
* \brief Get the value of argument.
* \return The value of argument.
*/
ExecWord value() const {
return data & ((static_cast<ExecWord>(1) << kValueBit) - 1);
}
ExecWord value() const { return data & ((static_cast<ExecWord>(1) << kValueBit) - 1); }
/*! \brief The underlying stored data. */
ExecWord data;
};
/*! \brief The instruction opcode. */
Opcode op;
/*! \brief The destination register. */
RegName dst;
RegName dst;
union {
struct /* Call */ {
/*! \brief The index into the packed function table. */
Expand Down Expand Up @@ -160,9 +155,7 @@ struct Instruction {
* \param dst The destination register.
* \return The call instruction.
*/
static Instruction Call(Index func_idx, Index num_args,
Arg* args,
RegName dst);
static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst);
/*!
* \brief Construct a return instruction.
* \param result The register containing the return value.
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/relax/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

/*!
* \file tvm/relax/vm/executable.h
* \brief
* \brief
*/
#ifndef TVM_RELAX_VM_EXECUTABLE_H_
#define TVM_RELAX_VM_EXECUTABLE_H_

#include <tvm/ir/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>

#include "./bytecode.h"

namespace tvm {
Expand All @@ -43,7 +44,7 @@ class Executable;
*/
struct VMFunction {
/*! \brief The function's name. */
std::string name;
std::string name;
/*! \brief The start instruction index of the function. */
Index start_instr;
/*! \brief The number of arguments of the function. */
Expand Down Expand Up @@ -116,7 +117,7 @@ class ExecutableNode : public Object {
std::vector<ExecWord> instr_data;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.Executable";
static constexpr const char* _type_key = "relax.Executable";
TVM_DECLARE_FINAL_OBJECT_INFO(ExecutableNode, Object);

private:
Expand Down Expand Up @@ -168,7 +169,6 @@ class Executable : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Executable, ObjectRef, ExecutableNode);
};


} // namespace relax_vm
} // namespace runtime
} // namespace tvm
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[relax.Type

# annotation with type arguments/shape annotation
if isinstance(ty, ast.TypeApply):
if ty.id.name == "Tensor":
if ty.func_name.id.name == "Tensor":
# TODO(@altanh): forgetting dtype like "Tensor[(n, m)]" ends up getting parsed as
# Tensor[n, m] which makes correct errors difficult here...
if len(ty.params) != 2:
Expand Down Expand Up @@ -295,7 +295,7 @@ def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[relax.Type
)

return (relax.DynTensorType(rank=rank, dtype=dtype, span=span), shape)
elif ty.id.name == "Tuple":
elif ty.func_name.id.name == "Tuple":
field_types = []
field_shapes = []
for field in ty.params:
Expand Down
3 changes: 2 additions & 1 deletion src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool sh
String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)> annotate);

Doc AsTVMScriptDoc(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false, const PrimFunc& func = PrimFunc());
Doc AsTVMScriptDoc(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false,
const PrimFunc& func = PrimFunc());

} // namespace tir
} // namespace tvm
Expand Down
Loading

0 comments on commit ab3dd5a

Please sign in to comment.