-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Patch file for adding FX Graph converter tool to your PyTorch source
- Loading branch information
1 parent
9fac30d
commit 1b4b57a
Showing
1 changed file
with
258 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,258 @@ | ||
--- /opt/pytorch/pytorch/torch/_functorch/aot_autograd.py 2023-01-17 00:39:30.000000000 -0800 | ||
+++ aot_autograd.py 2023-01-19 17:16:42.225709000 -0800 | ||
@@ -86,6 +86,7 @@ | ||
# one counter is allocated per entire compiled block (but this block | ||
# may involve compiling multiple subgraphs; e.g., for forwards/backwards) | ||
AOT_COUNTER = itertools.count() | ||
+FX_CONVERT_COUNTER = itertools.count() | ||
|
||
KNOWN_TYPES = tuple( | ||
[torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types) | ||
@@ -1656,6 +1657,247 @@ | ||
log.debug(f"====== Joint graph {aot_config.aot_id} ======") | ||
log.debug(fx_g.print_readable(print_output=False)) | ||
|
||
+ next(FX_CONVERT_COUNTER) | ||
+ from torch.fx.passes.graph_drawer import FxGraphDrawer as fgd | ||
+ g = fgd(fx_g, 'fx_graph_extraction') | ||
+ x = g.get_main_dot_graph() | ||
+ import networkx as nx | ||
+ nx_g = nx.nx_pydot.from_pydot(x) | ||
+ source_nodes = [x.name for x in fx_g.graph.nodes if 'primals' in x.name] | ||
+ source_nodes += [x.name for x in fx_g.graph.nodes if 'tangents' in x.name] | ||
+ nx_edges = [x for x in nx.edge_bfs(nx_g, source=source_nodes)] | ||
+ | ||
+ edge_attr_dict = dict() | ||
+ node_attr_dict = dict() | ||
+ call_sigs = [] | ||
+ for source in nx_g.nodes: | ||
+ label = nx_g.nodes[source]['label'] | ||
+ label = label.replace('{','', 1) | ||
+ label = label.rsplit('}',1)[0] | ||
+ label = label.replace('\n', '').replace('\\n', '').replace('\l','').replace('\\l','').replace('\\','') | ||
+ labels = label.split('|') | ||
+ label_dict = {'shape':[], 'dtype':None} | ||
+ for x in labels: | ||
+ new_x = x.split('=') | ||
+ #new_x[1] = new_x[1].replace(':', '=') | ||
+ label_dict.update({new_x[0] : new_x[1]}) | ||
+ | ||
+ #label_str = label.replace(':','=') #This is a workaround to make to_pydot work for drawing. Look at https://github.com/pydot/pydot/issues/258 | ||
+ | ||
+ arg_dict = dict() | ||
+ if 'args' in label_dict.keys(): | ||
+ #import pdb; pdb.set_trace() | ||
+ #Below code block is an attempt to map args received for node from FX Graph exactly to the op signature evaluated using the target | ||
+ #This is still experimental as it requires a lot more fundamental understanding of the PyTorch codebase and individual ops to capture all cases | ||
+ if 'target' in label_dict.keys(): | ||
+ try: | ||
+ node_call_sig = eval(label_dict['target'])._schema.__str__() | ||
+ call_sigs.append(node_call_sig) | ||
+ arg_dict = dict() | ||
+ | ||
+ ''' | ||
+ #Simple code block to just add args received as is. | ||
+ #Can comment out the experimental code block if processing the args and signature is not working | ||
+ #========= | ||
+ arg_dict = {'target_signature': node_call_sig, 'args': label_dict['args']} | ||
+ if 'kwargs' in label_dict.keys(): | ||
+ arg_dict.update({'kwargs': label_dict['kwargs']}) | ||
+ #========= | ||
+ ''' | ||
+ | ||
+ #Experimental code block --> | ||
+ import re | ||
+ #Match for anything between () as long as ')' is not inside the paranthesis. | ||
+ #This is to protect signatures like 'aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)' | ||
+ #sig_match = re.search(r'(\([^\)]+\))', node_call_sig).group() | ||
+ #sig_args = re.split(',', sig_match.replace('(', '').replace(')', '')) | ||
+ sig_match = re.search(r'(\(.+\)\s->)', node_call_sig).group() | ||
+ sig_match = sig_match.replace(' ->', '') | ||
+ sig_match = sig_match.replace('(','', 1) | ||
+ sig_match = sig_match.rsplit(')',1)[0] | ||
+ sig_args = re.split(',', sig_match) | ||
+ | ||
+ req_s_args = [] | ||
+ opt_s_args = [] | ||
+ for s_arg in sig_args: | ||
+ #Replace the empty white space if any at the beginning of the arg | ||
+ s_arg = re.sub(r'^\s', '', s_arg) | ||
+ if '=' in s_arg: | ||
+ opt_s_args.append(s_arg) | ||
+ else: | ||
+ req_s_args.append(s_arg) | ||
+ | ||
+ opt_s_arg_dict = dict() | ||
+ for opt_s_arg in opt_s_args: | ||
+ opt_s_arg_type, opt_s_arg_def = opt_s_arg.split(' ') | ||
+ opt_s_arg_key, opt_s_arg_val = opt_s_arg_def.split('=') | ||
+ opt_s_arg_dict.update({opt_s_arg_key:opt_s_arg_val}) | ||
+ | ||
+ #node_args = re.split(',', label_dict['args'].replace('(', '').replace(')', '')) | ||
+ #node_args.remove('') | ||
+ node_args = eval(label_dict['args']) | ||
+ | ||
+ #Below code block is to make the kwargs str work with eval. | ||
+ if 'kwargs' in label_dict.keys(): | ||
+ kwargs_str = label_dict['kwargs'] | ||
+ for opt_s_arg_key in opt_s_arg_dict.keys(): | ||
+ if opt_s_arg_key in kwargs_str: | ||
+ kwargs_str = kwargs_str.replace(str(opt_s_arg_key), '"' + str(opt_s_arg_key) + '"') | ||
+ | ||
+ node_kwargs = eval(kwargs_str) | ||
+ for node_kwarg in node_kwargs.keys(): | ||
+ if node_kwarg in opt_s_arg_dict.keys(): | ||
+ arg_dict.update({node_kwarg:node_kwargs[node_kwarg]}) | ||
+ else: | ||
+ node_kwargs = None | ||
+ | ||
+ #Add remaining default opt_s_args here | ||
+ for opt_s_arg_key in opt_s_arg_dict.keys(): | ||
+ if opt_s_arg_key not in arg_dict.keys(): | ||
+ arg_dict.update({opt_s_arg_key:opt_s_arg_dict[opt_s_arg_key]}) | ||
+ | ||
+ ''' | ||
+ #This section was originally written to process any definite arguments such as Tensor other=mul_1 | ||
+ #This is commented out because I have not seen a case where a definite argument is present in the args | ||
+ #And kwargs are handled separately above. | ||
+ #FIXME: Add check to see if args ever get a '=' definite argument | ||
+ | ||
+ # definite_node_args = [n_arg for n_arg in node_args if '=' in n_arg] | ||
+ | ||
+ # for def_n_arg in definite_node_args: | ||
+ # arg_key, arg_val = node_arg.split('=') | ||
+ # assert arg_key in sig_args,\ | ||
+ # "Found definite arg {} with value {} in FX Graph but does not match found signature {}".format(arg_key, arg_val, node_call_sig) | ||
+ # arg_dict.update({arg_key:arg_val}) | ||
+ # sig_args.remove(arg_key) | ||
+ # node_args.remove(def_n_arg) | ||
+ | ||
+ ''' | ||
+ | ||
+ assert len(req_s_args) >= len(node_args), \ | ||
+ "Recieved more args {} for node {} than the signature args evaluated {} from target{}".format(node_args,source,req_s_args,node_call_sig) | ||
+ #This to remove self args from the signature where it exists so that other args can be mapped correctly | ||
+ if len(req_s_args) > len(node_args) and ('self' in req_s_args[0] or 'input' in req_s_args[0]): | ||
+ del req_s_args[0] | ||
+ #Case where op accepts more optional inputs but they weren't given | ||
+ if len(req_s_args) > len(node_args) and ('*' in req_s_args): | ||
+ req_s_args.remove('*') | ||
+ assert len(req_s_args) == len(node_args), "Node args {} received do not match signature evaluated {}".format(node_args, sig_args) | ||
+ | ||
+ #Case where multiple optional inputs were added | ||
+ if len(node_args) > len(req_s_args) and '*' in req_s_args: | ||
+ req_s_args.extend(['*' for x in range(len(node_args) - len(req_s_args))]) | ||
+ assert len(req_s_args) == len(node_args), "Node args {} received do not match signature evaluated {}".format(node_args, sig_args) | ||
+ | ||
+ for sig_arg, node_arg in zip(req_s_args, node_args): | ||
+ arg_dict.update({sig_arg:node_arg}) | ||
+ | ||
+ except: | ||
+ print("[WARNING - FX Graph to NetworkX] Failed to map args. Target signature {} for op {}".format(label_dict['target'], source)) | ||
+ arg_dict = dict((k, label_dict[k]) for k in ['args', 'kwargs'] if k in label_dict.keys()) | ||
+ else: | ||
+ #Does this case ever exist where there are args provided but no target function call? Not sure. | ||
+ arg_dict = dict((k, label_dict[k]) for k in ['args', 'kwargs'] if k in label_dict.keys()) | ||
+ | ||
+ arg_str = str(arg_dict).replace(':', '=') #This is a workaround to make to_pydot work for drawing. Look at https://github.com/pydot/pydot/issues/258 | ||
+ #import pdb; pdb.set_trace() | ||
+ arg_str = arg_str.replace('{','') | ||
+ arg_str = arg_str.replace('}','') | ||
+ lbl = 'name=' + str(label_dict['name']) + '\ncall_args_info=' + arg_str | ||
+ #node_dict = {'node_call_info' : arg_str, 'name' : label_dict['name'], 'label': lbl} | ||
+ node_dict = {'node_call_info' : arg_str, 'label': lbl} | ||
+ node_attr_dict.update({source:node_dict}) | ||
+ | ||
+ edge_names = [(source,t,k) for t in nx_g[source].keys() for k in nx_g[source][t].keys()] | ||
+ fx_graph_node = [x for x in fx_g.graph.nodes if x.name==source][0] | ||
+ tensor_info = None | ||
+ if hasattr(fx_graph_node, "meta") and "tensor_meta" in fx_graph_node.meta: | ||
+ tensor_info = fx_graph_node.meta['tensor_meta'] | ||
+ elif hasattr(fx_graph_node, "meta") and "val" in fx_graph_node.meta: | ||
+ tensor_info = fx_graph_node.meta['val'] | ||
+ | ||
+ if tensor_info is not None: | ||
+ #There are underlying assumption based on my conversation with Horace at Meta | ||
+ #1. When an op returns multiple tensors, it is unpacked by subsequent getitem operator calls. | ||
+ #Look at assumption in len(edge_names) > 1 | ||
+ #2. When an op has a tensor that is connected to multiple users (other nodes), then the same tensor is on each edge | ||
+ #Look at case where where len(tensor_info) < len(edge_names) and len(tensor_info) == 1: | ||
+ #Remember - Assumption is if there are multiple tensors (len(tensor_info) > 1), they will always go to individual getitem calls | ||
+ if not isinstance(tensor_info, list): | ||
+ tensor_info = [tensor_info] | ||
+ if len(tensor_info) < len(edge_names) and len(tensor_info) == 1: | ||
+ tensor_info = [tensor_info[0] for idx in range(len(edge_names))] | ||
+ assert len(tensor_info) == len(edge_names), "The length of tensors available in FX Graph do not match with the edges present in exported NetworkX Graph for op {}".format(source) | ||
+ if len(edge_names) == 1: | ||
+ edge_dict = dict() | ||
+ edge_dict.update({'shape':tensor_info[0].shape}) | ||
+ edge_dict.update({'dtype':tensor_info[0].dtype}) | ||
+ edge_dict.update({'label':'shape=' + str(tensor_info[0].shape) + '\ndtype=' + str(tensor_info[0].dtype)}) | ||
+ edge_attr_dict.update({edge_names[0]:edge_dict}) | ||
+ elif len(edge_names) > 1: | ||
+ #Here we are hoping that the order of edges in FX Graph tensor_info and corresponding NetworkX Graph edges match up correctly. | ||
+ #Perhaps we could match with users from FX Graph and targets in source,target,key to match edges. | ||
+ #TODO: Explore the above matching | ||
+ for edge_idx in range(len(edge_names)): | ||
+ edge_dict = dict() | ||
+ edge_dict.update({'shape':tensor_info[edge_idx].shape}) | ||
+ edge_dict.update({'dtype':tensor_info[edge_idx].dtype}) | ||
+ edge_dict.update({'label':'shape=' + str(tensor_info[edge_idx].shape) + '\ndtype=' + str(tensor_info[edge_idx].dtype)}) | ||
+ edge_attr_dict.update({edge_names[edge_idx]:edge_dict}) | ||
+ else: | ||
+ if len(edge_names) == 0: | ||
+ assert fx_graph_node.op == 'output', "No edges found for non-output op {}".format(source) | ||
+ else: | ||
+ #This case should hopefully not be used if FXGraph has been constructed properly | ||
+ print("[WARNING: FX Graph to NetworkX] Edge tensor information was not found in the FX Graph for op {}. Using label from PyDot export".format(source)) | ||
+ edge_dict = dict((k, label_dict[k]) for k in ('shape', 'dtype')) | ||
+ edge_dict.update({'label':'shape=' + str(label_dict['shape']) + '\ndtype=' + str(label_dict['dtype'])}) | ||
+ edge_attr_dict.update({edge_names[0]:edge_dict}) | ||
+ | ||
+ nx.set_edge_attributes(nx_g, edge_attr_dict) | ||
+ nx.set_node_attributes(nx_g, node_attr_dict) | ||
+ | ||
+ import pickle | ||
+ gen_name = 'fx_graph_extracted_id_' + str(FX_CONVERT_COUNTER) | ||
+ pickle.dump(nx_g, open(gen_name + '.pickle', 'wb')) | ||
+ | ||
+ #Note: To read this pickle dump, use | ||
+ #fx_g_extracted = pickle.load(open('fx_graph_extracted.pickle', 'rb')) | ||
+ | ||
+ #Drawing | ||
+ edited_dot = nx.nx_pydot.to_pydot(nx_g) | ||
+ getattr(edited_dot, 'write_png')(gen_name + '.png') | ||
+ | ||
+ ''' | ||
+ new_nx_g = nx.Graph(nx_g) | ||
+ | ||
+ import matplotlib.pyplot as plt | ||
+ | ||
+ #Draw Nodes | ||
+ pos = nx.spring_layout(new_nx_g) | ||
+ nx.draw_networkx_nodes(new_nx_g, pos, node_size=900) | ||
+ | ||
+ import pdb; pdb.set_trace() | ||
+ #Create minimal label for drawing | ||
+ name_lbls = nx.get_node_attributes(new_nx_g, 'name') | ||
+ call_lbls = nx.get_node_attributes(new_nx_g, 'node_call_info') | ||
+ for n in new_nx_g.nodes: | ||
+ lbl = 'name: ' + str(name_lbls[n]) + '\ncall_args_info:' + str(call_lbls[n]) | ||
+ node_draw_lbl = {n:lbl} | ||
+ nx.draw_networkx_labels(new_nx_g, pos, labels=node_draw_lbl) | ||
+ | ||
+ #Draw Edges | ||
+ nx.draw_networkx_edges(new_nx_g, pos) | ||
+ shape_lbls = nx.get_edge_attributes(new_nx_g, 'shape') | ||
+ dtype_lbls = nx.get_edge_attributes(new_nx_g, 'dtype') | ||
+ for e in new_nx_g.edges(): | ||
+ lbl = 'shape: ' + str(shape_lbls[e]) + '\dtype:' + str(dtype_lbls[e]) | ||
+ edge_draw_lbl = {e:lbl} | ||
+ nx.draw_networkx_edge_labels(new_nx_g, pos, edge_labels=edge_draw_lbl) | ||
+ | ||
+ plt.savefig("fx_extracted_nx.png", format="PNG") | ||
+ ''' | ||
+ | ||
with torch.no_grad(): | ||
with track_graph_compiling(aot_config, "joint"): | ||
num_inner_fwd_outputs = _num_mutated_inputs + _num_outputs + _fw_metadata.num_intermediate_bases |