Show more custom information of tensor and gradient #82

Open · wants to merge 4 commits into master
39 changes: 32 additions & 7 deletions torchviz/dot.py
@@ -1,5 +1,7 @@
from collections import namedtuple
from distutils.version import LooseVersion

+import numpy as np
from graphviz import Digraph
import torch
from torch.autograd import Variable
@@ -10,6 +12,7 @@
# Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*`
SAVED_PREFIX = "_saved_"


def get_fn_name(fn, show_attrs, max_attr_chars):
    name = str(type(fn).__name__)
    if not show_attrs:
@@ -32,13 +35,13 @@ def get_fn_name(fn, show_attrs, max_attr_chars):
    col1width = max(len(k) for k in attrs.keys())
    col2width = min(max(len(str(v)) for v in attrs.values()), max_attr_chars)
    sep = "-" * max(col1width + col2width + 2, len(name))
-    attrstr = '%-' + str(col1width) + 's: %' + str(col2width)+ 's'
+    attrstr = '%-' + str(col1width) + 's: %' + str(col2width) + 's'
    truncate = lambda s: s[:col2width - 3] + "..." if len(s) > col2width else s
    params = '\n'.join(attrstr % (k, truncate(str(v))) for (k, v) in attrs.items())
    return name + '\n' + sep + '\n' + params


-def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50):
+def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50, display_var_set=set()):
""" Produces Graphviz representation of PyTorch autograd graph.

If a node represents a backward function, it is gray. Otherwise, the node
@@ -61,9 +64,10 @@ def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_char
            present, are always displayed. (Requires PyTorch version >= 1.9)
        max_attr_chars: if show_attrs is `True`, sets max number of characters
            to display for any given attribute.
+        display_var_set: set of str; attributes of the variable/tensor to display on its node
    """
    if LooseVersion(torch.__version__) < LooseVersion("1.9") and \
            (show_attrs or show_saved):
        warnings.warn(
            "make_dot: showing grad_fn attributes and saved variables"
            " requires PyTorch version >= 1.9. (This does NOT apply to"
@@ -91,6 +95,30 @@ def size_to_str(size):
    def get_var_name(var, name=None):
        if not name:
            name = param_map[id(var)] if id(var) in param_map else ''
+        if len(display_var_set) > 0:
+            attrs = dict()
+            for attr in dir(var):
+                if attr in display_var_set:
+                    val = getattr(var, attr)
+                    if isinstance(val, torch.Tensor):
+                        attrs[attr] = np.array2string(val.cpu().numpy(), max_line_width=max_attr_chars, threshold=max_attr_chars, precision=4)
+                    else:
+                        attrs[attr] = str(val)
+            col1width = max(len(k) for k in attrs.keys())
+            for attr, val in attrs.items():
+                res = []
+                if '\n' in val:
+                    vals = val.split('\n')
+                    res.append(vals[0])
+                    if len(vals) > 1:
+                        res.extend([f'%-{col1width}s' % '' + v for v in vals[1:]])
+                    attrs[attr] = '\n'.join(res)
+            col2width = min(max(len(str(v)) if '\n' not in str(v) else max(len(x) + 1 for x in v.split('\n')) for v in attrs.values()), max_attr_chars)
+            sep = "-" * max(col1width + col2width + 2, len(name))
+            attrstr = '%' + str(col1width) + 's: %-' + str(col2width) + 's'
+            truncate = lambda s: s[:col2width - 3] + "..." if max(len(x) for x in s.split('\n')) > col2width else s
+            p = '\n'.join(attrstr % (k, truncate(str(v))) for (k, v) in attrs.items())
+            return '%s\n %s' % (name, size_to_str(var.size())) + '\n' + sep + '\n' + p
        return '%s\n %s' % (name, size_to_str(var.size()))

    def add_nodes(fn):
@@ -138,11 +166,9 @@ def add_nodes(fn):
        # also note that this still works for custom autograd functions
        if hasattr(fn, 'saved_tensors'):
            for t in fn.saved_tensors:
-                seen.add(t)
-                dot.edge(str(id(t)), str(id(fn)), dir="none")
+                dot.edge(str(id(t)), str(id(fn)))
                dot.node(str(id(t)), get_var_name(t), fillcolor='orange')


    def add_base_tensor(var, color='darkolivegreen1'):
        if var in seen:
            return
@@ -155,7 +181,6 @@ def add_base_tensor(var, color='darkolivegreen1'):
            add_base_tensor(var._base, color='darkolivegreen3')
            dot.edge(str(id(var._base)), str(id(var)), style="dotted")


    # handle multiple outputs
    if isinstance(var, tuple):
        for v in var:
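For reference, a minimal usage sketch of the new display_var_set argument (assuming this branch of torchviz is installed; the variable names, the chosen attributes 'requires_grad' and 'grad', and the output file name are only illustrative). Requested names are matched against dir(tensor), and tensor-valued attributes are rendered with np.array2string, truncated to max_attr_chars:

import torch
from torchviz import make_dot

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()
y.backward(retain_graph=True)  # populate x.grad so it can appear on the node

# 'requires_grad' and 'grad' are example attribute names; any name in dir(tensor) works
dot = make_dot(y, params={'x': x}, display_var_set={'requires_grad', 'grad'})
dot.format = 'png'
dot.render('example_graph')

Note that if the set contains only names that do not exist on the tensor, attrs stays empty and the max() call in get_var_name would raise, so at least one valid attribute name should be passed.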