Add a lowering to linear instead of using torch decomp
qihqi committed Jan 28, 2025
1 parent 310d2d8 commit 7ddc586
Showing 3 changed files with 44 additions and 4 deletions.
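Background for the change: torch.nn.functional.linear(x, W, b) computes x @ W.T + b, where W has shape (out_features, in_features). Rather than relying on torch's decomposition, which typically expands linear into transpose/reshape/matmul aten ops, this commit registers a direct lowering that expresses the whole op as a single einsum (see jtorch.py below). A minimal sketch of the identity the lowering relies on, illustrative and not part of the diff:

import jax.numpy as jnp

x = jnp.ones((2, 100, 30))  # (batch, seq, in_features)
W = jnp.ones((1, 30))       # (out_features, in_features), torch's weight layout
b = jnp.zeros((1,))

y = jnp.einsum("...a,ba->...b", x, W) + b  # equivalent to x @ W.T + b
assert y.shape == (2, 100, 1)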
27 changes: 27 additions & 0 deletions torchax/test/test_functions.py
@@ -6,6 +6,19 @@
import torchax.tensor


class SeqModel(torch.nn.Module):
  """Architecture is LLM-generated."""
  def __init__(self):
    super().__init__()
    self.gru = torch.nn.GRU(20, 30, batch_first=True)
    self.linear = torch.nn.Linear(30, 1)

  def forward(self, x: torch.Tensor):
    output, _ = self.gru(x)  # output, final hidden state
    output = self.linear(output)
    return output


class TestTorchFunctions(parameterized.TestCase):

  def setUp(self):
@@ -47,6 +60,20 @@ def test_bernoulli_inplace(self):
    a = torch.randn((2, 3))
    a.bernoulli_(0.4)

  def test_rnn(self):
    model = SeqModel()
    x = torch.randn((2, 100, 20))
    res = model(x)
    self.env.config.debug_print_each_op = True
    with self.env:
      model.to('jax')
      x = x.to('jax')
      res2 = model(x)
    print(res.shape, res2.shape)

    self.assertEqual(res.shape, res2.shape)
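The test above also shows the intended usage pattern: inside the torchax environment, moving the model and inputs to the 'jax' device makes subsequent torch calls dispatch to these lowerings. A standalone sketch of the same flow; torchax.default_env() is an assumption about the public accessor (the test reaches the environment via self.env):

import torch
import torchax

env = torchax.default_env()   # assumed accessor; may differ across versions
model = SeqModel()
x = torch.randn((2, 100, 20))

with env:
  model.to('jax')             # parameters become jax-backed tensors
  res = model(x.to('jax'))    # F.linear now hits the einsum lowering below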
13 changes: 9 additions & 4 deletions torchax/torchax/ops/jaten.py
@@ -103,6 +103,9 @@ def inner(func):
    torch.ops.aten.reshape,
)
def _aten_unsafe_view(x, shape):
+  if shape == [200, 30]:
+    breakpoint()
+  print('view', x.shape, shape)
   return jnp.reshape(x, shape)


@@ -378,9 +381,7 @@ def _aten_t(x):
@op(torch.ops.aten.transpose)
@op(torch.ops.aten.transpose_copy)
def _aten_transpose(x, dim0, dim1):
-  shape = list(range(len(x.shape)))
-  shape[dim0], shape[dim1] = shape[dim1], shape[dim0]
-  return jnp.transpose(x, shape)
+  return jnp.swapaxes(x, dim0, dim1)
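The transpose rewrite above swaps the hand-built permutation for jnp.swapaxes, which is exactly the two-axis special case of jnp.transpose. A quick equivalence check, illustrative only:

import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

perm = list(range(x.ndim))
perm[1], perm[2] = perm[2], perm[1]
old = jnp.transpose(x, perm)  # what the removed lines computed
new = jnp.swapaxes(x, 1, 2)   # the one-line replacement

assert (old == new).all()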


@op(torch.ops.aten.triu)
@@ -792,7 +793,11 @@ def make_range(rank, dim, start, end):
@op(torch.ops.aten.permute)
@op(torch.ops.aten.permute_copy)
def permute(t, dims):
-  return jnp.transpose(t, dims)
+  print('before shape', t.shape)
+  print('dimsshape', dims)
+  res = jnp.transpose(t, dims)
+  print('after shape', res.shape)
+  return res


@op(torch.ops.aten.unsqueeze)
8 changes: 8 additions & 0 deletions torchax/torchax/ops/jtorch.py
@@ -414,3 +414,11 @@ def linalg_tensorsolve(A, b, dims=None):
  if A.shape[:b.ndim] != b.shape:
    b = jnp.reshape(b, A.shape[:b.ndim])
  return jnp.linalg.tensorsolve(A, b, axes=dims)


@register_function(torch.nn.functional.linear)
def functional_linear(self, weights, bias=None):
  res = jnp.einsum("...a,ba->...b", self, weights)
  if bias is not None:
    res += bias
  return res
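A quick parity check of the new lowering's einsum against eager PyTorch; this snippet is illustrative and not part of the commit:

import numpy as np
import torch
import jax.numpy as jnp

x, w, b = torch.randn(2, 5, 30), torch.randn(1, 30), torch.randn(1)
ref = torch.nn.functional.linear(x, w, b)
res = jnp.einsum("...a,ba->...b", jnp.asarray(x.numpy()),
                 jnp.asarray(w.numpy())) + jnp.asarray(b.numpy())
np.testing.assert_allclose(ref.numpy(), np.asarray(res), rtol=1e-4, atol=1e-5)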
