Skip to content

Commit

Permalink
[FX] Support tensor-valued constants (#43666)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch/pytorch#43666

Test Plan: Imported from OSS

Reviewed By: dzhulgakov

Differential Revision: D23359110

Pulled By: jamesr66a

fbshipit-source-id: 8569a2db0ef081ea7d8e81d7ba26a92bc12ed423
  • Loading branch information
James Reed authored and facebook-github-bot committed Sep 1, 2020
1 parent 06c277f commit 73f7d63
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
33 changes: 33 additions & 0 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,39 @@ def forward(self, a, b):
self.assertTrue(hasattr(n, 'tag'))
self.assertEqual(n.tag, 'foo')

def test_tensor_attribute(self):
class TensorAttribute(torch.nn.Module):
def __init__(self):
super().__init__()
self.tensor = torch.rand(3, 4)

def forward(self, x):
return torch.nn.functional.linear(x, self.tensor)

ta = TensorAttribute()
traced = symbolic_trace(ta)
traced(torch.rand(4, 4))

class WrapperForQualname(torch.nn.Module):
def __init__(self):
super().__init__()
self.ta = TensorAttribute()

def forward(self, x):
return torch.nn.functional.linear(x, self.ta.tensor)

wfq = WrapperForQualname()
traced2 = symbolic_trace(wfq)
traced2(torch.rand(4, 4))

def test_tensor_constant(self):
class ConstTensor(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.linear(x, torch.zeros(3, 4))

ct = ConstTensor()
traced = symbolic_trace(ct)
traced(torch.rand(4, 4))

# Standalone execution: delegate to the shared PyTorch test runner.
if __name__ == '__main__':
    run_tests()
40 changes: 40 additions & 0 deletions torch/fx/symbolic_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,46 @@ def create_arg(self, a: Any) -> Argument:
if a is p:
return self.get_param(n)
raise NameError('parameter is not a member of this module')
# Tensors do not have a reliable string repr() from which they can be
# constructed (and we probably don't want to rely on that, either), so
# for any constant Tensor values we encounter, first search for if they
# are an attribute of some module in the module hierarchy. If so, emit
# a get_param to retrieve that tensor. Otherwise, we'll store away the
# tensor value into a special attribute on the Module s.t. we can
# retrieve it with a get_param.
if isinstance(a, torch.Tensor):
# TODO: slow
def search_for_tensor(m : torch.nn.Module) -> Optional[List[str]]:
    """Depth-first search for the closed-over tensor ``a`` within the
    module hierarchy rooted at ``m``.

    Returns the list of attribute-name atoms forming the qualified path
    from ``m`` down to the tensor (e.g. ``['child', 'tensor']``), or
    ``None`` if ``a`` is not an attribute anywhere under ``m``.
    Matching is by object identity (``is``), not value equality.
    """
    # Check this module's own (non-submodule) attributes first.
    # NOTE(review): only __dict__ is scanned here, so tensors registered
    # as buffers (stored in m._buffers) would not be found — confirm
    # this is the intended scope.
    direct_hits = [name for name, value in m.__dict__.items() if value is a]
    if direct_hits:
        return direct_hits[:1]
    # Not here: recurse into each child, prefixing the child's name.
    for child_name, child in m.named_children():
        suffix = search_for_tensor(child)
        if suffix:
            return [child_name, *suffix]
    return None
# Retrieve the qualname for an existing Tensor attribute
qualname_atoms : Optional[List[str]] = search_for_tensor(self.root)
qualname = '.'.join(qualname_atoms) if qualname_atoms else None

# Tensor was not found in the Module hierarchy, stow it away in a
# special attribute and set the qualname to refer to that
if not qualname:
i = 0
while True:
qualname = f'__tensor_constant{i}'
if not hasattr(self.root, qualname):
break
i += 1
setattr(self.root, qualname, a)

return self.get_param(qualname)
return super().create_arg(a)


Expand Down

0 comments on commit 73f7d63

Please sign in to comment.