Updating signature and arg matching to use FX Graph information instead of PyDot graph labels. Fixes ResNet50 failures with Convolution, BatchNorm, MaxPool
parthmannan committed Jan 24, 2023
1 parent 274da50 commit 384cf76
Showing 1 changed file with 63 additions and 32 deletions.
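For context, the core of the change is reading each node's call information (target, args, kwargs, and the target's _schema) directly off the FX graph instead of re-parsing the string labels produced by the pydot export. A minimal sketch of that access pattern, assuming a graph traced with make_fx so that node targets are aten OpOverloads exposing a _schema (the demo function here is hypothetical, not from the commit):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def demo(x):
    return torch.relu(x) + 1.0

gm = make_fx(demo)(torch.randn(4))

for node in gm.graph.nodes:
    if node.op == "call_function":
        # node.target is the live aten OpOverload; its schema and the actual
        # argument objects are available without any label-string parsing.
        print(node.name, node.target._schema)
        print("  args:", node.args, "kwargs:", node.kwargs)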
95 changes: 63 additions & 32 deletions fx_graph_converter_patch
@@ -1,5 +1,5 @@
--- /opt/pytorch/pytorch/torch/_functorch/aot_autograd.py 2023-01-17 00:39:30.000000000 -0800
+++ aot_autograd.py 2023-01-19 17:16:42.225709000 -0800
+++ aot_autograd.py 2023-01-23 20:32:14.164081000 -0800
@@ -86,6 +86,7 @@
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
@@ -8,11 +8,12 @@

KNOWN_TYPES = tuple(
[torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
@@ -1656,6 +1657,247 @@
@@ -1656,6 +1657,278 @@
log.debug(f"====== Joint graph {aot_config.aot_id} ======")
log.debug(fx_g.print_readable(print_output=False))

+ next(FX_CONVERT_COUNTER)
+ print("[FX Graph to NetworkX Exporter]: Starting Graph Export {}".format(FX_CONVERT_COUNTER))
+ from torch.fx.passes.graph_drawer import FxGraphDrawer as fgd
+ g = fgd(fx_g, 'fx_graph_extraction')
+ x = g.get_main_dot_graph()
@@ -39,14 +40,17 @@
+
+ #label_str = label.replace(':','=') #This is a workaround to make to_pydot work for drawing. Look at https://github.com/pydot/pydot/issues/258
+
+ fx_graph_node = [x for x in fx_g.graph.nodes if x.name==source][0]
+ arg_dict = dict()
+ if 'args' in label_dict.keys():
+ #import pdb; pdb.set_trace()
+ #if 'args' in label_dict.keys():
+ if getattr(fx_graph_node, 'args', None):
+ #The code block below attempts to map the args received for a node from the FX Graph exactly onto the op signature evaluated from the target
+ #This is still experimental, as capturing all cases requires a much deeper understanding of the PyTorch codebase and of individual ops
+ if 'target' in label_dict.keys():
+ #if 'target' in label_dict.keys():
+ if getattr(fx_graph_node, 'target', None):
+ try:
+ node_call_sig = eval(label_dict['target'])._schema.__str__()
+ #node_call_sig = eval(label_dict['target'])._schema.__str__()
+ node_call_sig = fx_graph_node.target._schema.__str__()
+ call_sigs.append(node_call_sig)
+ arg_dict = dict()
+
@@ -87,11 +91,11 @@
+ opt_s_arg_type, opt_s_arg_def = opt_s_arg.split(' ')
+ opt_s_arg_key, opt_s_arg_val = opt_s_arg_def.split('=')
+ opt_s_arg_dict.update({opt_s_arg_key:opt_s_arg_val})
+
+ #node_args = eval(label_dict['args'])
+ node_args = fx_graph_node.args
+
+ #node_args = re.split(',', label_dict['args'].replace('(', '').replace(')', ''))
+ #node_args.remove('')
+ node_args = eval(label_dict['args'])
+
+ #FIXME: Update kwargs to use FX Graph information as well instead of label data
+ #Below code block is to make the kwargs str work with eval.
+ if 'kwargs' in label_dict.keys():
+ kwargs_str = label_dict['kwargs']
@@ -103,13 +107,9 @@
+ for node_kwarg in node_kwargs.keys():
+ if node_kwarg in opt_s_arg_dict.keys():
+ arg_dict.update({node_kwarg:node_kwargs[node_kwarg]})
+ opt_s_arg_dict.pop(node_kwarg)
+ else:
+ node_kwargs = None
+
+ #Add remaining default opt_s_args here
+ for opt_s_arg_key in opt_s_arg_dict.keys():
+ if opt_s_arg_key not in arg_dict.keys():
+ arg_dict.update({opt_s_arg_key:opt_s_arg_dict[opt_s_arg_key]})
+
+ '''
+ #This section was originally written to process any definite arguments such as Tensor other=mul_1
@@ -129,9 +129,9 @@
+
+ '''
+
+ assert len(req_s_args) >= len(node_args), \
+ "Recieved more args {} for node {} than the signature args evaluated {} from target{}".format(node_args,source,req_s_args,node_call_sig)
+ #This is to remove self args from the signature where they exist so that the other args can be mapped correctly
+ #TODO: Is it always true that self/input will show up in the first arg of signature if the node args don't have it?
+ #TODO Contd: Check if these cases even happen since I moved from label_dict to fx graph args.
+ if len(req_s_args) > len(node_args) and ('self' in req_s_args[0] or 'input' in req_s_args[0]):
+ del req_s_args[0]
+ #Case where op accepts more optional inputs but they weren't given
@@ -141,11 +141,24 @@
+
+ #Case where multiple optional inputs were added
+ if len(node_args) > len(req_s_args) and '*' in req_s_args:
+ req_s_args.extend(['*' for x in range(len(node_args) - len(req_s_args))])
+ assert len(req_s_args) == len(node_args), "Node args {} received do not match signature evaluated {}".format(node_args, req_s_args)
+ req_s_args.extend(['*' for x in range(len(node_args) - len(req_s_args))])
+ assert len(req_s_args) == len(node_args), "Node args {} received do not match signature evaluated {}".format(node_args, req_s_args)
+
+ #Assign node args to req args until they run out. In most cases all node args should also get consumed here.
+ for idx, sig_arg in enumerate(req_s_args):
+ arg_dict.update({sig_arg:node_args[idx]})
+ node_args = node_args[len(req_s_args):]
+
+ #Continue assigning to opt args until node args run out. Remove opt_s_arg_dict key once assigned.
+ opt_s_args_to_assign = list(opt_s_arg_dict.keys())[:len(node_args)]
+ for sig_arg, n_arg in zip(opt_s_args_to_assign, node_args):
+ arg_dict.update({sig_arg:n_arg})
+ opt_s_arg_dict.pop(sig_arg)
+
+ for sig_arg, node_arg in zip(req_s_args, node_args):
+ arg_dict.update({sig_arg:node_arg})
+ #Add remaining default opt_s_args here
+ for opt_s_arg_key in opt_s_arg_dict.keys():
+ if opt_s_arg_key not in arg_dict.keys():
+ arg_dict.update({opt_s_arg_key:opt_s_arg_dict[opt_s_arg_key]})
+
+ except Exception:
+ print("[WARNING - FX Graph to NetworkX] Failed to map args. Target signature {} for op {}".format(label_dict['target'], source))
@@ -155,21 +168,33 @@
+ arg_dict = dict((k, label_dict[k]) for k in ['args', 'kwargs'] if k in label_dict.keys())
+
+ arg_str = str(arg_dict).replace(':', '=') #This is a workaround to make to_pydot work for drawing. Look at https://github.com/pydot/pydot/issues/258
+ #import pdb; pdb.set_trace()
+ arg_str = arg_str.replace('{','')
+ arg_str = arg_str.replace('}','')
+ lbl = 'name=' + str(label_dict['name']) + '\ncall_args_info=' + arg_str
+ #node_dict = {'node_call_info' : arg_str, 'name' : label_dict['name'], 'label': lbl}
+ node_dict = {'node_call_info' : arg_str, 'label': lbl}
+ node_attr_dict.update({source:node_dict})
+
+ def _get_tensor_info(fxg_node):
+ tensor_info = None
+ if hasattr(fxg_node, "meta") and "tensor_meta" in fxg_node.meta:
+ tensor_info = fxg_node.meta['tensor_meta']
+ elif hasattr(fxg_node, "meta") and "val" in fxg_node.meta:
+ tensor_info = fxg_node.meta['val']
+ return tensor_info
+
+ edge_names = [(source,t,k) for t in nx_g[source].keys() for k in nx_g[source][t].keys()]
+ fx_graph_node = [x for x in fx_g.graph.nodes if x.name==source][0]
+ tensor_info = None
+ if hasattr(fx_graph_node, "meta") and "tensor_meta" in fx_graph_node.meta:
+ tensor_info = fx_graph_node.meta['tensor_meta']
+ elif hasattr(fx_graph_node, "meta") and "val" in fx_graph_node.meta:
+ tensor_info = fx_graph_node.meta['val']
+
+ tensor_info = _get_tensor_info(fx_graph_node)
+ if tensor_info is None:
+ #This means that tensor metadata was not present for this node. This is undesirable, and it may be worth filing bugs on PyTorch about it.
+ #For now, there are cases we can handle when the output is managed by getitem nodes
+ #It may be desirable to ALWAYS match tensor metadata when multiple getitem nodes are users, but that's a future enhancement.
+ if all(['getitem' in x.name for x in fx_graph_node.users.keys()]):
+ for dest_node in fx_graph_node.users.keys():
+ dest_tensor_info = _get_tensor_info(dest_node)
+ assert dest_tensor_info is not None, "getitem node {} does not have tensor info. This fails assumptions made in the exporter.".format(dest_node)
+ tensor_info = tensor_info + [dest_tensor_info] if tensor_info else [dest_tensor_info]
+
+ if tensor_info is not None:
+ #There are underlying assumptions based on my conversation with Horace at Meta
@@ -182,7 +207,11 @@
+ tensor_info = [tensor_info]
+ if len(tensor_info) < len(edge_names) and len(tensor_info) == 1:
+ tensor_info = [tensor_info[0] for idx in range(len(edge_names))]
+ assert len(tensor_info) == len(edge_names), "The number of tensors available in the FX Graph does not match the edges present in the exported NetworkX Graph for op {}".format(source)
+
+ if len(edge_names) == 0:
+ assert len(fx_graph_node.users) == 0, "There are no edges in the exported graph for op {} but FX Graph has users {}".format(source,fx_graph_node.users)
+ else:
+ assert len(tensor_info) == len(edge_names), "The number of tensors available in the FX Graph does not match the edges present in the exported NetworkX Graph for op {}".format(source)
+ if len(edge_names) == 1:
+ edge_dict = dict()
+ edge_dict.update({'shape':tensor_info[0].shape})
@@ -220,8 +249,10 @@
+ #fx_g_extracted = pickle.load(open('fx_graph_extracted.pickle', 'rb'))
+
+ #Drawing
+ edited_dot = nx.nx_pydot.to_pydot(nx_g)
+ getattr(edited_dot, 'write_png')(gen_name + '.png')
+ #print("[FX Graph to NetworkX Exporter]: Starting to draw extracted graph {}".format(FX_CONVERT_COUNTER))
+ #edited_dot = nx.nx_pydot.to_pydot(nx_g)
+ #getattr(edited_dot, 'write_png')(gen_name + '.png')
+ #print("[FX Graph to NetworkX Exporter]: Finished drawing extracted graph {}".format(FX_CONVERT_COUNTER))
+
+ '''
+ new_nx_g = nx.Graph(nx_g)
@@ -252,7 +283,7 @@
+
+ plt.savefig("fx_extracted_nx.png", format="PNG")
+ '''
+
+
with torch.no_grad():
with track_graph_compiling(aot_config, "joint"):
num_inner_fwd_outputs = _num_mutated_inputs + _num_outputs + _fw_metadata.num_intermediate_bases

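To make the signature handling in the patch concrete: the _schema string of an aten op can be split into required arguments and optional (defaulted) arguments before node args are matched against it. A standalone sketch of that parsing, assuming the simple "name(type arg, type arg=default) -> ret" shape; list-valued defaults that contain commas would need more care, which is partly why the patch wraps its mapping logic in a try/except (split_schema_args is our name, not from the commit):

def split_schema_args(schema_str):
    # e.g. "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    arg_block = schema_str.split('(', 1)[1].rsplit(') ->', 1)[0]
    required, optional = [], {}
    for raw in arg_block.split(','):
        raw = raw.strip()
        if not raw:
            continue
        if '=' in raw:
            decl, default = raw.split('=', 1)
            optional[decl.split(' ')[-1]] = default  # keyword name -> default value
        else:
            required.append(raw)  # includes the bare '*' keyword-only marker
    return required, optional

req, opt = split_schema_args(
    "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
print(req)  # ['Tensor self', 'Tensor other', '*']
print(opt)  # {'alpha': '1'}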
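The matching step then consumes positional node args against the required slots first, spills any leftovers into the optional slots in declaration order, and lets untouched optionals keep their schema defaults, mirroring the loops in the patch. A sketch of that consumption order under the same assumptions (map_args is our name):

def map_args(req_s_args, opt_s_arg_dict, node_args):
    arg_dict = {}
    node_args = list(node_args)
    # Required slots consume positional args first.
    for sig_arg, n_arg in zip(req_s_args, node_args):
        arg_dict[sig_arg] = n_arg
    node_args = node_args[len(req_s_args):]
    # Leftover positionals spill into optional slots, in declaration order.
    for sig_arg, n_arg in zip(list(opt_s_arg_dict), node_args):
        arg_dict[sig_arg] = n_arg
        opt_s_arg_dict.pop(sig_arg)
    # Anything still unassigned keeps its schema default.
    arg_dict.update(opt_s_arg_dict)
    return arg_dict

print(map_args(['Tensor self'], {'alpha': '1'}, ('x',)))
# {'Tensor self': 'x', 'alpha': '1'}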
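For the metadata lookup, shape information lives in node.meta: under 'tensor_meta' when shape propagation has run, or as the fake-tensor 'val' entry that proxy tracing records; and, as the patch notes, multi-output ops may only carry metadata on their getitem users. A sketch of that fallback chain, assuming a make_fx-style graph (get_tensor_info is our name):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def get_tensor_info(node):
    # Prefer ShapeProp-style metadata, then the fake-tensor value.
    if "tensor_meta" in node.meta:
        return node.meta["tensor_meta"]
    if "val" in node.meta:
        return node.meta["val"]
    # Multi-output ops may only expose metadata on their getitem users.
    if node.users and all("getitem" in u.name for u in node.users):
        return [get_tensor_info(u) for u in node.users]
    return None

gm = make_fx(lambda x: x * 2.0)(torch.randn(3, 3))
for node in gm.graph.nodes:
    info = get_tensor_info(node)
    print(node.name, getattr(info, "shape", None))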
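Finally, the export path itself: FxGraphDrawer renders the FX graph to a pydot graph, which networkx can ingest, and the patch then rewrites node and edge attributes on the resulting networkx graph. A round-trip sketch, assuming pydot and networkx are installed:

import torch
import networkx as nx
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.graph_drawer import FxGraphDrawer

fx_g = make_fx(lambda x: torch.relu(x) + 1.0)(torch.randn(4))

drawer = FxGraphDrawer(fx_g, "fx_graph_extraction")
dot_graph = drawer.get_main_dot_graph()

# pydot -> networkx; dot attributes survive as node/edge data that the
# exporter can then overwrite with FX-derived information.
nx_g = nx.nx_pydot.from_pydot(dot_graph)
print(type(nx_g).__name__, nx_g.number_of_nodes(), nx_g.number_of_edges())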