Skip to content

Commit

Permalink
[FX] Pickle serialization of GraphModule via forward source (#43674)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch/pytorch#43674

Test Plan: Imported from OSS

Reviewed By: dzhulgakov

Differential Revision: D23362396

Pulled By: jamesr66a

fbshipit-source-id: cb8181edff70643b7bbe548cc6b0957328d4eedd
  • Loading branch information
James Reed authored and facebook-github-bot committed Sep 1, 2020
1 parent 73f7d63 commit a1a2366
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
13 changes: 13 additions & 0 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
import operator
import numbers
import pickle
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, DefaultDelegate

from fx.quantization import Quantizer
Expand All @@ -17,6 +18,10 @@
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

class SimpleTest(torch.nn.Module):
def forward(self, x):
return torch.relu(x + 3.0)

class TestFX(JitTestCase):
def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
"""Check that an nn.Module's results match the GraphModule version
Expand Down Expand Up @@ -418,5 +423,13 @@ def forward(self, x):
traced = symbolic_trace(ct)
traced(torch.rand(4, 4))

def test_pickle_graphmodule(self):
st = SimpleTest()
traced = symbolic_trace(st)
pickled = pickle.dumps(traced)
loaded = pickle.loads(pickled)
x = torch.rand(3, 4)
self.assertEqual(loaded(x), traced(x))

if __name__ == '__main__':
run_tests()
22 changes: 22 additions & 0 deletions torch/fx/graph_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.overrides
import linecache
import copy
from typing import Type, Dict, List, Any
from .graph import Graph

Expand All @@ -26,6 +27,24 @@ def patched_getline(*args, **kwargs):
return _orig_getlines(*args, **kwargs)
linecache.getlines = patched_getline

def deserialize_graphmodule(root : torch.nn.Module, src : str) -> torch.nn.Module:
"""
Deserialize a GraphModule given the original `root` module and the generated
`forward()` source code (`src`). This will exec() the source of the forward
onto the root module to create a well-formed Module with code analogous
to the original code. Then it symbolically traces through it to get the
GraphModule
"""
root = copy.copy(root)
from .symbolic_trace import symbolic_trace
gbls: Dict[str, Any] = {
'torch': torch
}
exec_with_source(src, gbls)
cls = type(root)
for k, v in gbls.items():
setattr(root, k, v)
return symbolic_trace(root)

class GraphModule(torch.nn.Module):
def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
Expand Down Expand Up @@ -66,6 +85,9 @@ def forward(self, {', '.join(free_variables)}):
for k, v in gbls.items():
setattr(cls, k, v)

def __reduce__(self):
return (deserialize_graphmodule, (self.root, self.code))

# workarounds for issues in __torch_function__

# WAR for __torch_function__ not handling tensor lists,
Expand Down

0 comments on commit a1a2366

Please sign in to comment.