Skip to content

Commit

Permalink
Patch file for adding FX Graph converter tool to your PyTorch source
Browse files Browse the repository at this point in the history
  • Loading branch information
parthmannan committed Jan 20, 2023
1 parent 9fac30d commit 1b4b57a
Showing 1 changed file with 258 additions and 0 deletions.
258 changes: 258 additions & 0 deletions fx_graph_converter_patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
--- /opt/pytorch/pytorch/torch/_functorch/aot_autograd.py 2023-01-17 00:39:30.000000000 -0800
+++ aot_autograd.py 2023-01-19 17:16:42.225709000 -0800
@@ -86,6 +86,7 @@
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
AOT_COUNTER = itertools.count()
+FX_CONVERT_COUNTER = itertools.count()

KNOWN_TYPES = tuple(
[torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
@@ -1656,6 +1657,247 @@
log.debug(f"====== Joint graph {aot_config.aot_id} ======")
log.debug(fx_g.print_readable(print_output=False))

+ next(FX_CONVERT_COUNTER)
+ from torch.fx.passes.graph_drawer import FxGraphDrawer as fgd
+ g = fgd(fx_g, 'fx_graph_extraction')
+ x = g.get_main_dot_graph()
+ import networkx as nx
+ nx_g = nx.nx_pydot.from_pydot(x)
+ source_nodes = [x.name for x in fx_g.graph.nodes if 'primals' in x.name]
+ source_nodes += [x.name for x in fx_g.graph.nodes if 'tangents' in x.name]
+ nx_edges = [x for x in nx.edge_bfs(nx_g, source=source_nodes)]
+
+ edge_attr_dict = dict()
+ node_attr_dict = dict()
+ call_sigs = []
+ for source in nx_g.nodes:
+ label = nx_g.nodes[source]['label']
+ label = label.replace('{','', 1)
+ label = label.rsplit('}',1)[0]
+ label = label.replace('\n', '').replace('\\n', '').replace('\l','').replace('\\l','').replace('\\','')
+ labels = label.split('|')
+ label_dict = {'shape':[], 'dtype':None}
+ for x in labels:
+ new_x = x.split('=')
+ #new_x[1] = new_x[1].replace(':', '=')
+ label_dict.update({new_x[0] : new_x[1]})
+
+ #label_str = label.replace(':','=') #This is a workaround to make to_pydot work for drawing. Look at https://github.com/pydot/pydot/issues/258
+
+ arg_dict = dict()
+ if 'args' in label_dict.keys():
+ #import pdb; pdb.set_trace()
+ #Below code block is an attempt to map args received for node from FX Graph exactly to the op signature evaluated using the target
+ #This is still experimental as it requires a lot more fundamental understanding of the PyTorch codebase and individual ops to capture all cases
+ if 'target' in label_dict.keys():
+ try:
+ node_call_sig = eval(label_dict['target'])._schema.__str__()
+ call_sigs.append(node_call_sig)
+ arg_dict = dict()
+
+ '''
+ #Simple code block to just add args received as is.
+ #Can comment out the experimental code block if processing the args and signature is not working
+ #=========
+ arg_dict = {'target_signature': node_call_sig, 'args': label_dict['args']}
+ if 'kwargs' in label_dict.keys():
+ arg_dict.update({'kwargs': label_dict['kwargs']})
+ #=========
+ '''
+
+ #Experimental code block -->
+ import re
+ #Match for anything between () as long as ')' is not inside the paranthesis.
+ #This is to protect signatures like 'aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)'
+ #sig_match = re.search(r'(\([^\)]+\))', node_call_sig).group()
+ #sig_args = re.split(',', sig_match.replace('(', '').replace(')', ''))
+ sig_match = re.search(r'(\(.+\)\s->)', node_call_sig).group()
+ sig_match = sig_match.replace(' ->', '')
+ sig_match = sig_match.replace('(','', 1)
+ sig_match = sig_match.rsplit(')',1)[0]
+ sig_args = re.split(',', sig_match)
+
+ req_s_args = []
+ opt_s_args = []
+ for s_arg in sig_args:
+ #Replace the empty white space if any at the beginning of the arg
+ s_arg = re.sub(r'^\s', '', s_arg)
+ if '=' in s_arg:
+ opt_s_args.append(s_arg)
+ else:
+ req_s_args.append(s_arg)
+
+ opt_s_arg_dict = dict()
+ for opt_s_arg in opt_s_args:
+ opt_s_arg_type, opt_s_arg_def = opt_s_arg.split(' ')
+ opt_s_arg_key, opt_s_arg_val = opt_s_arg_def.split('=')
+ opt_s_arg_dict.update({opt_s_arg_key:opt_s_arg_val})
+
+ #node_args = re.split(',', label_dict['args'].replace('(', '').replace(')', ''))
+ #node_args.remove('')
+ node_args = eval(label_dict['args'])
+
+ #Below code block is to make the kwargs str work with eval.
+ if 'kwargs' in label_dict.keys():
+ kwargs_str = label_dict['kwargs']
+ for opt_s_arg_key in opt_s_arg_dict.keys():
+ if opt_s_arg_key in kwargs_str:
+ kwargs_str = kwargs_str.replace(str(opt_s_arg_key), '"' + str(opt_s_arg_key) + '"')
+
+ node_kwargs = eval(kwargs_str)
+ for node_kwarg in node_kwargs.keys():
+ if node_kwarg in opt_s_arg_dict.keys():
+ arg_dict.update({node_kwarg:node_kwargs[node_kwarg]})
+ else:
+ node_kwargs = None
+
+ #Add remaining default opt_s_args here
+ for opt_s_arg_key in opt_s_arg_dict.keys():
+ if opt_s_arg_key not in arg_dict.keys():
+ arg_dict.update({opt_s_arg_key:opt_s_arg_dict[opt_s_arg_key]})
+
+ '''
+ #This section was originally written to process any definite arguments such as Tensor other=mul_1
+ #This is commented out because I have not seen a case where a definite argument is present in the args
+ #And kwargs are handled separately above.
+ #FIXME: Add check to see if args ever get a '=' definite argument
+
+ # definite_node_args = [n_arg for n_arg in node_args if '=' in n_arg]
+
+ # for def_n_arg in definite_node_args:
+ # arg_key, arg_val = node_arg.split('=')
+ # assert arg_key in sig_args,\
+ # "Found definite arg {} with value {} in FX Graph but does not match found signature {}".format(arg_key, arg_val, node_call_sig)
+ # arg_dict.update({arg_key:arg_val})
+ # sig_args.remove(arg_key)
+ # node_args.remove(def_n_arg)
+
+ '''
+
+ assert len(req_s_args) >= len(node_args), \
+ "Recieved more args {} for node {} than the signature args evaluated {} from target{}".format(node_args,source,req_s_args,node_call_sig)
+ #This to remove self args from the signature where it exists so that other args can be mapped correctly
+ if len(req_s_args) > len(node_args) and ('self' in req_s_args[0] or 'input' in req_s_args[0]):
+ del req_s_args[0]
+ #Case where op accepts more optional inputs but they weren't given
+ if len(req_s_args) > len(node_args) and ('*' in req_s_args):
+ req_s_args.remove('*')
+ assert len(req_s_args) == len(node_args), "Node args {} received do not match signature evaluated {}".format(node_args, sig_args)
+
+ #Case where multiple optional inputs were added
+ if len(node_args) > len(req_s_args) and '*' in req_s_args:
+ req_s_args.extend(['*' for x in range(len(node_args) - len(req_s_args))])
+ assert len(req_s_args) == len(node_args), "Node args {} received do not match signature evaluated {}".format(node_args, sig_args)
+
+ for sig_arg, node_arg in zip(req_s_args, node_args):
+ arg_dict.update({sig_arg:node_arg})
+
+ except:
+ print("[WARNING - FX Graph to NetworkX] Failed to map args. Target signature {} for op {}".format(label_dict['target'], source))
+ arg_dict = dict((k, label_dict[k]) for k in ['args', 'kwargs'] if k in label_dict.keys())
+ else:
+ #Does this case ever exist where there are args provided but no target function call? Not sure.
+ arg_dict = dict((k, label_dict[k]) for k in ['args', 'kwargs'] if k in label_dict.keys())
+
+ arg_str = str(arg_dict).replace(':', '=') #This is a workaround to make to_pydot work for drawing. Look at https://github.com/pydot/pydot/issues/258
+ #import pdb; pdb.set_trace()
+ arg_str = arg_str.replace('{','')
+ arg_str = arg_str.replace('}','')
+ lbl = 'name=' + str(label_dict['name']) + '\ncall_args_info=' + arg_str
+ #node_dict = {'node_call_info' : arg_str, 'name' : label_dict['name'], 'label': lbl}
+ node_dict = {'node_call_info' : arg_str, 'label': lbl}
+ node_attr_dict.update({source:node_dict})
+
+ edge_names = [(source,t,k) for t in nx_g[source].keys() for k in nx_g[source][t].keys()]
+ fx_graph_node = [x for x in fx_g.graph.nodes if x.name==source][0]
+ tensor_info = None
+ if hasattr(fx_graph_node, "meta") and "tensor_meta" in fx_graph_node.meta:
+ tensor_info = fx_graph_node.meta['tensor_meta']
+ elif hasattr(fx_graph_node, "meta") and "val" in fx_graph_node.meta:
+ tensor_info = fx_graph_node.meta['val']
+
+ if tensor_info is not None:
+ #There are underlying assumption based on my conversation with Horace at Meta
+ #1. When an op returns multiple tensors, it is unpacked by subsequent getitem operator calls.
+ #Look at assumption in len(edge_names) > 1
+ #2. When an op has a tensor that is connected to multiple users (other nodes), then the same tensor is on each edge
+ #Look at case where where len(tensor_info) < len(edge_names) and len(tensor_info) == 1:
+ #Remember - Assumption is if there are multiple tensors (len(tensor_info) > 1), they will always go to individual getitem calls
+ if not isinstance(tensor_info, list):
+ tensor_info = [tensor_info]
+ if len(tensor_info) < len(edge_names) and len(tensor_info) == 1:
+ tensor_info = [tensor_info[0] for idx in range(len(edge_names))]
+ assert len(tensor_info) == len(edge_names), "The length of tensors available in FX Graph do not match with the edges present in exported NetworkX Graph for op {}".format(source)
+ if len(edge_names) == 1:
+ edge_dict = dict()
+ edge_dict.update({'shape':tensor_info[0].shape})
+ edge_dict.update({'dtype':tensor_info[0].dtype})
+ edge_dict.update({'label':'shape=' + str(tensor_info[0].shape) + '\ndtype=' + str(tensor_info[0].dtype)})
+ edge_attr_dict.update({edge_names[0]:edge_dict})
+ elif len(edge_names) > 1:
+ #Here we are hoping that the order of edges in FX Graph tensor_info and corresponding NetworkX Graph edges match up correctly.
+ #Perhaps we could match with users from FX Graph and targets in source,target,key to match edges.
+ #TODO: Explore the above matching
+ for edge_idx in range(len(edge_names)):
+ edge_dict = dict()
+ edge_dict.update({'shape':tensor_info[edge_idx].shape})
+ edge_dict.update({'dtype':tensor_info[edge_idx].dtype})
+ edge_dict.update({'label':'shape=' + str(tensor_info[edge_idx].shape) + '\ndtype=' + str(tensor_info[edge_idx].dtype)})
+ edge_attr_dict.update({edge_names[edge_idx]:edge_dict})
+ else:
+ if len(edge_names) == 0:
+ assert fx_graph_node.op == 'output', "No edges found for non-output op {}".format(source)
+ else:
+ #This case should hopefully not be used if FXGraph has been constructed properly
+ print("[WARNING: FX Graph to NetworkX] Edge tensor information was not found in the FX Graph for op {}. Using label from PyDot export".format(source))
+ edge_dict = dict((k, label_dict[k]) for k in ('shape', 'dtype'))
+ edge_dict.update({'label':'shape=' + str(label_dict['shape']) + '\ndtype=' + str(label_dict['dtype'])})
+ edge_attr_dict.update({edge_names[0]:edge_dict})
+
+ nx.set_edge_attributes(nx_g, edge_attr_dict)
+ nx.set_node_attributes(nx_g, node_attr_dict)
+
+ import pickle
+ gen_name = 'fx_graph_extracted_id_' + str(FX_CONVERT_COUNTER)
+ pickle.dump(nx_g, open(gen_name + '.pickle', 'wb'))
+
+ #Note: To read this pickle dump, use
+ #fx_g_extracted = pickle.load(open('fx_graph_extracted.pickle', 'rb'))
+
+ #Drawing
+ edited_dot = nx.nx_pydot.to_pydot(nx_g)
+ getattr(edited_dot, 'write_png')(gen_name + '.png')
+
+ '''
+ new_nx_g = nx.Graph(nx_g)
+
+ import matplotlib.pyplot as plt
+
+ #Draw Nodes
+ pos = nx.spring_layout(new_nx_g)
+ nx.draw_networkx_nodes(new_nx_g, pos, node_size=900)
+
+ import pdb; pdb.set_trace()
+ #Create minimal label for drawing
+ name_lbls = nx.get_node_attributes(new_nx_g, 'name')
+ call_lbls = nx.get_node_attributes(new_nx_g, 'node_call_info')
+ for n in new_nx_g.nodes:
+ lbl = 'name: ' + str(name_lbls[n]) + '\ncall_args_info:' + str(call_lbls[n])
+ node_draw_lbl = {n:lbl}
+ nx.draw_networkx_labels(new_nx_g, pos, labels=node_draw_lbl)
+
+ #Draw Edges
+ nx.draw_networkx_edges(new_nx_g, pos)
+ shape_lbls = nx.get_edge_attributes(new_nx_g, 'shape')
+ dtype_lbls = nx.get_edge_attributes(new_nx_g, 'dtype')
+ for e in new_nx_g.edges():
+ lbl = 'shape: ' + str(shape_lbls[e]) + '\dtype:' + str(dtype_lbls[e])
+ edge_draw_lbl = {e:lbl}
+ nx.draw_networkx_edge_labels(new_nx_g, pos, edge_labels=edge_draw_lbl)
+
+ plt.savefig("fx_extracted_nx.png", format="PNG")
+ '''
+
with torch.no_grad():
with track_graph_compiling(aot_config, "joint"):
num_inner_fwd_outputs = _num_mutated_inputs + _num_outputs + _fw_metadata.num_intermediate_bases

0 comments on commit 1b4b57a

Please sign in to comment.