Support fw and bw with spmd
zpcore committed Jan 31, 2025
1 parent 8e6ca60 commit a8c8f47
Showing 4 changed files with 709 additions and 284 deletions.
188 changes: 165 additions & 23 deletions test/test_pallas.py
@@ -1,5 +1,7 @@
import logging
import sys
import unittest
from absl.testing import parameterized

import torch
from torch import nn as nn
@@ -19,7 +21,20 @@
from jax.experimental import pallas as pl


class PallasTest(unittest.TestCase):
def with_jax_high_precision(func):

def wrapper(*args, **kwargs):
jax.config.update('jax_default_matmul_precision', "highest")
try:
result = func(*args, **kwargs)
finally:
jax.config.update('jax_default_matmul_precision', "default")
return result

return wrapper


class PallasTest(parameterized.TestCase):

# This is to create a diagonal mask where only elements within the same segment
# can attend to each other. Since the mask is used to mask out the irrelevant parts,
@@ -33,12 +48,11 @@ def _make_attention_mask_from_segment_ids(self, q_segment_ids,

def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Mask out the irrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
if ab is not None:
attn_weight = attn_weight + ab
if attn_mask is not None:
attn_weight = attn_weight.masked_fill(attn_mask.bool(),
torch.finfo(attn_weight.dtype).min)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
@@ -216,8 +230,8 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -227,12 +241,11 @@ def test_flash_attention_wrapper(self):
o = flash_attention(q, k, v)
expected_o = self._attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_with_dynamo(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

def flash_attention_wrapper(q, k, v, causal=False):
@@ -253,12 +266,11 @@ def flash_attention_wrapper(q, k, v, causal=False):
# therefore it speeds up the compute but also changes the output.
self.assertFalse(
torch.allclose(o_with_causal.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_causal(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -270,7 +282,6 @@ def test_flash_attention_wrapper_causal(self):
o = flash_attention(q, k, v, causal=True)
expected_o = self._attention(q, k, v)
self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_multiple_returns(self):
@@ -450,8 +461,8 @@ def test__flash_attention_bwd_dkv(self):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_backward(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

torch.manual_seed(42)
@@ -486,7 +497,6 @@ def test_flash_attention_backward(self):

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@@ -1026,8 +1036,8 @@ def test_flash_attention_wrapper_segment_ids_1(self):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_segment_ids_2(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1093,12 +1103,11 @@ def test_flash_attention_backward_segment_ids(self):

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_sm_scale(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1109,12 +1118,11 @@ def test_flash_attention_wrapper_sm_scale(self):

expected_o = self._attention(q * sm_scale, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_sm_scale_backward(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

torch.manual_seed(42)
@@ -1151,12 +1159,11 @@ def test_flash_attention_sm_scale_backward(self):
# Hmm, the gradients are the same even though the autograd graphs seem different.
for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_ab(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
@@ -1208,12 +1215,11 @@ def test_flash_attention_ab_backward_1(self):

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_ab_backward_2(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention

torch.manual_seed(42)
@@ -1251,7 +1257,143 @@ def test_flash_attention_ab_backward_2(self):

for i in [(q, q_grad), (k, k_grad), (v, v_grad), (ab, ab_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

@parameterized.named_parameters(('off', False), ('on', True))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@with_jax_high_precision
def test_flash_attention_forward_aot_autograd_traceable_causal(self, causal):
from functorch.compile import aot_function, make_boxed_func
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.core.xla_model as xm

def compiler(gm, _):
return make_boxed_func(gm)

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
B, N, SEQ, H = q.size()
q_segment_ids = None
kv_segment_ids = None
sm_scale = 1.0

# def flash_attention_wrapper(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab):
#   return flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
# An AOT-compatible function only accepts the argument types listed at
# https://github.com/pytorch/pytorch/blob/82859f61857ef39898b34a5cdf0ae56ec25704d9/torch/_functorch/_aot_autograd/utils.py#L23-L34,
# so we serialize partition_spec and mesh into strings.
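# For illustration only (hypothetical, not part of this test): such a
# serialization could flatten a spec like ('fsdp', None, None, None) into a
# plain string and parse it back inside the kernel wrapper, e.g.
#   spec_str = ','.join('' if p is None else str(p) for p in partition_spec)
#   restored = tuple(None if s == '' else s for s in spec_str.split(','))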

compiled_flash_attention = aot_function(
flash_attention, fw_compiler=compiler)
o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids,
kv_segment_ids, sm_scale)
xm.mark_step()
if causal:
attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla")
else:
attention_mask = None

expected_output = self._attention(q, k, v, attn_mask=attention_mask)
xm.mark_step()
self.assertTrue(
torch.allclose(o_actual.cpu(), expected_output.cpu(), atol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@with_jax_high_precision
def test_flash_attention_forward_aot_autograd_traceable_ab(self):
from functorch.compile import aot_function, make_boxed_func
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.core.xla_model as xm

def compiler(gm, _):
return make_boxed_func(gm)

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8).to("xla")
k = torch.randn(4, 2, 128, 8).to("xla")
v = torch.randn(4, 2, 128, 8).to("xla")
B, N, SEQ, H = q.size()
causal = False
q_segment_ids = None
kv_segment_ids = None
sm_scale = 1.0
mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
ab = torch.ones(4, 2, 128, 128).to("xla")
ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min)

compiled_flash_attention = aot_function(
flash_attention, fw_compiler=compiler)
o_actual = compiled_flash_attention(
q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
xm.mark_step()

expected_output = self._attention(q, k, v, ab=ab)
xm.mark_step()
self.assertTrue(
torch.allclose(o_actual.cpu(), expected_output.cpu(), atol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@with_jax_high_precision
def test_flash_attention_backward_aot_autograd_traceable(self):
from functorch.compile import aot_function, make_boxed_func
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.core.xla_model as xm

def compiler(gm, _):
return make_boxed_func(gm)

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
B, N, SEQ, H = q.size()
mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
ab = torch.ones(4, 2, 128, 128).to("xla")
ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_()
ab.retain_grad()

causal = False
q_segment_ids = None
kv_segment_ids = None
sm_scale = 1.0
compiled_flash_attention = aot_function(
flash_attention, fw_compiler=compiler)
o_actual = compiled_flash_attention(
q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab)
loss = o_actual.sum()
loss.backward()
xm.mark_step()
q_grad = q.grad
k_grad = k.grad
v_grad = v.grad
ab_grad = ab.grad

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
ab = torch.ones(4, 2, 128, 128).to("xla")
ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_()
ab.retain_grad()

o = self._attention(q, k, v, ab=ab)
loss = o.sum()
loss.backward()
xm.mark_step()

for i in [(q, q_grad), (k, k_grad), (v, v_grad), (ab, ab_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-02))


if __name__ == '__main__':

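For reference, a rough sketch of how the SPMD path named in the commit title might be exercised outside the AOT tests above. It assumes flash_attention accepts partition_spec and mesh keyword arguments (as the serialization comment in the test implies) and uses the PyTorch/XLA SPMD mesh APIs in the usual way; names and shapes are illustrative only, not taken from this commit.

# Hypothetical usage sketch; exact flash_attention SPMD arguments may differ.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# One-dimensional 'fsdp' mesh; the remaining axes are kept at size 1.
mesh = xs.Mesh(
    np.arange(num_devices), (num_devices, 1, 1, 1), ('fsdp', 'b', 'n', 'h'))

q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla')
k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla')
v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla')

# Shard the batch dimension and run both forward and backward through the kernel.
o = flash_attention(
    q, k, v, causal=True, partition_spec=('fsdp', None, None, None), mesh=mesh)
o.sum().backward()
xm.mark_step()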