Skip to content

Commit

Permalink
Merge pull request #3 from cyx-6/nmk
Browse files Browse the repository at this point in the history
extract dynamic vars
  • Loading branch information
zxybazh authored Jun 20, 2023
2 parents d70a822 + b135d7a commit 0cf6f1c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
9 changes: 5 additions & 4 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from models.vicuna_v1_7b_fp16 import Module as Vicuna
from mlc_bench.extraction import extract_from_relax
from mlc_bench.extraction import extract_from_relax, extract

if __name__ == "__main__":
extract_from_relax(
mod=Vicuna, model_name="vicuna_v1_7b_fp_16", file_path="./extracted"
)
# extract_from_relax(
# mod=Vicuna, model_name="vicuna_v1_7b_fp_16", file_path="./extracted"
# )
extract(mod=Vicuna)
37 changes: 34 additions & 3 deletions mlc_bench/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ def extract_from_relax(mod: tvm.ir.IRModule, model_name: str, file_path: str):
)


def extract_shape(arg) -> List:
if isinstance(arg, (tuple, list, tvm.relax.Tuple)):
r = []
for a in arg:
r.extend(extract_shape(a))
return r
else:
return [arg.struct_info]


def prim_func_usage_gen(mod: tvm.ir.IRModule):
for gv, func in mod.functions.items():
if isinstance(func, tvm.relax.Function):
Expand All @@ -152,12 +162,31 @@ def prim_func_usage_gen(mod: tvm.ir.IRModule):
functor = raw_args[0]
if isinstance(functor, tvm.ir.GlobalVar):
if isinstance(mod.functions[functor], tvm.tir.PrimFunc):
args = [arg.struct_info for arg in raw_args[1:]] + [
binding.value.struct_info
]
args = extract_shape(raw_args[1:]) + extract_shape(binding.value)
yield gv, functor, args


def extract_dynamic_var(func_dict):
dynamic_var_dict = {}
for gv in func_dict:
print(gv)
dynamic_var_dict[gv] = set()
for functor in func_dict[gv]:
for arg_list, _ in func_dict[gv][functor]:
for arg in arg_list:
if isinstance(arg, tvm.relax.TensorStructInfo):
for v in arg.shape.values:
if isinstance(v, tvm.tir.Var):
dynamic_var_dict[gv].add((v.name, v.dtype))
elif isinstance(arg, tvm.relax.ShapeStructInfo):
for v in arg.values:
if isinstance(v, tvm.tir.Var):
dynamic_var_dict[gv].add((v.name, v.dtype))
else:
raise NotImplementedError
return dynamic_var_dict


def extract(mod: tvm.ir.IRModule):
def update_records(records, new_args):
for i, (args, count) in enumerate(records):
Expand All @@ -178,8 +207,10 @@ def update_records(records, new_args):
prim_func_dict[functor] = []
update_records(prim_func_dict[functor], args)
update_records(relax_func_dict[gv][functor], args)
dynamic_var_dict = extract_dynamic_var(relax_func_dict)
print(prim_func_dict)
print(relax_func_dict)
print(dynamic_var_dict)


if __name__ == "__main__":
Expand Down

0 comments on commit 0cf6f1c

Please sign in to comment.