forked from mlcommons/algorithmic-efficiency
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_ssim.py
81 lines (67 loc) · 3.06 KB
/
test_ssim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Test for the equality of the SSIM calculation in Jax and PyTorch."""
import os
from typing import Tuple
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np
import torch
from algorithmic_efficiency.pytorch_utils import pytorch_setup
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import \
_uniform_filter as _jax_uniform_filter
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import \
ssim as jax_ssim
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import \
_uniform_filter as _pytorch_uniform_filter
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import \
ssim as pytorch_ssim
# Make sure no GPU memory is preallocated to Jax.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
DEVICE = pytorch_setup()[2]
def _create_fake_im(height: int, width: int) -> Tuple[jnp.array, torch.Tensor]:
fake_im = np.random.randn(height, width)
jax_fake_im = jnp.asarray(fake_im)
pytorch_fake_im = torch.as_tensor(fake_im, device=DEVICE)
return jax_fake_im, pytorch_fake_im
def _create_fake_batch(
batch_size: int, height: int, width: int
) -> Tuple[Tuple[jnp.array, jnp.array], Tuple[torch.Tensor, torch.Tensor]]:
logits = np.random.randn(batch_size, height, width)
targets = np.random.randn(batch_size, height, width)
jax_logits = jnp.asarray(logits)
jax_targets = jnp.asarray(targets)
pytorch_logits = torch.as_tensor(logits, device=DEVICE)
pytorch_targets = torch.as_tensor(targets, device=DEVICE)
return (jax_logits, jax_targets), (pytorch_logits, pytorch_targets)
class SSIMTest(parameterized.TestCase):
"""Test for equivalence of SSIM and _uniform_filter implementations in Jax
and PyTorch."""
@parameterized.named_parameters(
dict(testcase_name='fastmri_im', height=320, width=320),
dict(testcase_name='uneven_even_im', height=31, width=16),
dict(testcase_name='even_uneven_im', height=42, width=53),
)
def test_uniform_filter(self, height: int, width: int) -> None:
jax_im, pytorch_im = _create_fake_im(height, width)
jax_result = np.asarray(_jax_uniform_filter(jax_im))
torch_result = _pytorch_uniform_filter(pytorch_im).cpu().numpy()
assert np.allclose(jax_result, torch_result, atol=1e-6)
@parameterized.named_parameters(
dict(
testcase_name='fastmri_batch', batch_size=256, height=320, width=320),
dict(
testcase_name='uneven_even_batch', batch_size=8, height=31, width=16),
dict(
testcase_name='even_uneven_batch', batch_size=8, height=42, width=53),
)
def test_ssim(self, batch_size: int, height: int, width: int) -> None:
jax_inputs, pytorch_inputs = _create_fake_batch(batch_size, height, width)
jax_ssim_result = jax_ssim(*jax_inputs)
pytorch_ssim_result = pytorch_ssim(*pytorch_inputs)
self.assertEqual(jax_ssim_result.shape, pytorch_ssim_result.shape)
assert np.allclose(
jax_ssim_result.sum().item(),
pytorch_ssim_result.sum().item(),
atol=1e-6)
if __name__ == '__main__':
absltest.main()