diff --git a/visualdl/component/graph/graph_component.py b/visualdl/component/graph/graph_component.py index 7360b647..f2039472 100644 --- a/visualdl/component/graph/graph_component.py +++ b/visualdl/component/graph/graph_component.py @@ -464,7 +464,30 @@ def get_sub_ops(op, op_name, all_ops, all_vars): all_ops[sub_op_name]['is_leaf_node'] = True now_var = utils.gen_var_name(sub_op.results()) for source in sub_op.operands_source(): + if not source.type(): + # if source.type() == Value().type(): + continue input_name = utils.gen_var_name(source) + if input_name not in all_vars.keys(): + all_vars[input_name] = {} + all_vars[input_name]['name'] = input_name + try: + attrs = source.results()[0].get_defining_op().attrs() + if 'place' in attrs: + attrs['place'] = str(attrs['place']) + attrs['dtype'] = safe_get_dtype(source) + except Exception: + attrs = {} + + all_vars[input_name]['shape'] = safe_get_shape(source) + all_vars[input_name]['type'] = safe_get_type(source) + all_vars[input_name]['dtype'] = safe_get_dtype(source) + all_vars[input_name]['value'] = [] + all_vars[input_name]['persistable'] = safe_get_persistable(source) + all_vars[input_name]['attrs'] = attrs + all_vars[input_name]['from_node'] = '' + all_vars[input_name]['to_nodes'] = [] + if sub_op.name() == "pd_op.increment_": all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node']) all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var] @@ -633,7 +656,30 @@ def analyse_pir(program): all_ops[op_name]['is_leaf_node'] = True now_var = utils.gen_var_name(op.results()) for source in op.operands_source(): + if not source.type(): + # if source.type() == Value().type(): + continue input_name = utils.gen_var_name(source) + if input_name not in all_vars.keys(): + all_vars[input_name] = {} + all_vars[input_name]['name'] = input_name + try: + attrs = source.results()[0].get_defining_op().attrs() + if 'place' in attrs: + attrs['place'] = str(attrs['place']) + attrs['dtype'] = safe_get_dtype(source) + except Exception: + attrs = {} + + all_vars[input_name]['shape'] = safe_get_shape(source) + all_vars[input_name]['type'] = safe_get_type(source) + all_vars[input_name]['dtype'] = safe_get_dtype(source) + all_vars[input_name]['value'] = [] + all_vars[input_name]['persistable'] = safe_get_persistable(source) + all_vars[input_name]['attrs'] = attrs + all_vars[input_name]['from_node'] = '' + all_vars[input_name]['to_nodes'] = [] + if op.name() == "pd_op.increment_": all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node']) all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var]