diff --git a/example.py b/example.py index 617c1d4..98b7404 100644 --- a/example.py +++ b/example.py @@ -1,7 +1,8 @@ from models.vicuna_v1_7b_fp16 import Module as Vicuna -from mlc_bench.extraction import extract_from_relax +from mlc_bench.extraction import extract_from_relax, extract if __name__ == "__main__": - extract_from_relax( - mod=Vicuna, model_name="vicuna_v1_7b_fp_16", file_path="./extracted" - ) + # extract_from_relax( + # mod=Vicuna, model_name="vicuna_v1_7b_fp_16", file_path="./extracted" + # ) + extract(mod=Vicuna) diff --git a/mlc_bench/extraction.py b/mlc_bench/extraction.py index f7e5bbb..61a8e32 100644 --- a/mlc_bench/extraction.py +++ b/mlc_bench/extraction.py @@ -142,6 +142,16 @@ def extract_from_relax(mod: tvm.ir.IRModule, model_name: str, file_path: str): ) +def extract_shape(arg) -> List: + if isinstance(arg, (tuple, list, tvm.relax.Tuple)): + r = [] + for a in arg: + r.extend(extract_shape(a)) + return r + else: + return [arg.struct_info] + + def prim_func_usage_gen(mod: tvm.ir.IRModule): for gv, func in mod.functions.items(): if isinstance(func, tvm.relax.Function): @@ -152,12 +162,31 @@ def prim_func_usage_gen(mod: tvm.ir.IRModule): functor = raw_args[0] if isinstance(functor, tvm.ir.GlobalVar): if isinstance(mod.functions[functor], tvm.tir.PrimFunc): - args = [arg.struct_info for arg in raw_args[1:]] + [ - binding.value.struct_info - ] + args = extract_shape(raw_args[1:]) + extract_shape(binding.value) yield gv, functor, args +def extract_dynamic_var(func_dict): + dynamic_var_dict = {} + for gv in func_dict: + print(gv) + dynamic_var_dict[gv] = set() + for functor in func_dict[gv]: + for arg_list, _ in func_dict[gv][functor]: + for arg in arg_list: + if isinstance(arg, tvm.relax.TensorStructInfo): + for v in arg.shape.values: + if isinstance(v, tvm.tir.Var): + dynamic_var_dict[gv].add((v.name, v.dtype)) + elif isinstance(arg, tvm.relax.ShapeStructInfo): + for v in arg.values: + if isinstance(v, tvm.tir.Var): + dynamic_var_dict[gv].add((v.name, v.dtype)) + else: + raise NotImplementedError + return dynamic_var_dict + + def extract(mod: tvm.ir.IRModule): def update_records(records, new_args): for i, (args, count) in enumerate(records): @@ -178,8 +207,10 @@ def update_records(records, new_args): prim_func_dict[functor] = [] update_records(prim_func_dict[functor], args) update_records(relax_func_dict[gv][functor], args) + dynamic_var_dict = extract_dynamic_var(relax_func_dict) print(prim_func_dict) print(relax_func_dict) + print(dynamic_var_dict) if __name__ == "__main__":