Skip to content

Commit

Permalink
switch over decorator usage of dynamo to torch.compile in tests and a…
Browse files Browse the repository at this point in the history
…dd inductor test

ghstack-source-id: 88864e7358c82e38be80b03cc550d960566f0500
Pull Request resolved: #298
  • Loading branch information
PaliC committed Jan 14, 2023
1 parent cef6cac commit 2e9a48a
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions multipy/runtime/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unittest

import torch

import torch._dynamo

class TestCompat(unittest.TestCase):
def test_torchvision(self):
Expand All @@ -22,32 +22,41 @@ def test_pytorch3d(self):
def test_hf_tokenizers(self):
import tokenizers # noqa: F401

@unittest.skip("torch.Library is not supported")
def test_torchdynamo_eager(self):
import torch._dynamo as torchdynamo

@torchdynamo.optimize("eager")
torch._dynamo.reset()

def fn(x, y):
a = torch.cos(x)
b = torch.sin(y)
return a + b

fn(torch.randn(10), torch.randn(10))
c_fn = torch.compile(fn, backend="eager")
c_fn(torch.randn(10), torch.randn(10))

@unittest.skip("torch.Library is not supported")
def test_torchdynamo_ofi(self):
import torch._dynamo as torchdynamo

torchdynamo.reset()
torch._dynamo.reset()

@torchdynamo.optimize("ofi")
def fn(x, y):
a = torch.cos(x)
b = torch.sin(y)
return a + b

fn(torch.randn(10), torch.randn(10))
c_fn = torch.compile(fn, backend="ofi")
c_fn(torch.randn(10), torch.randn(10))

def test_torchdynamo_inductor(self):

torch._dynamo.reset()

def fn(x, y):
a = torch.cos(x)
b = torch.sin(y)
return a + b

c_fn = torch.compile(fn)
c_fn(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
unittest.main()

0 comments on commit 2e9a48a

Please sign in to comment.