diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py index 0ece3ffa9..e15b93616 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py @@ -40,6 +40,12 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): logits = logits * std + mean targets = targets * std + mean ssims = jax.vmap(structural_similarity)(logits, targets, volume_max) + + # map out-of-bounds ssims to 1 and -1, the theoretical + # maximum and minimum values of SSIM. + ssims = jnp.where(ssims > 1, jnp.ones_like(ssims), ssims) + ssims = jnp.where(ssims < -1, jnp.ones_like(ssims) * -1, ssims) + return ssims diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py index ebee661c8..eff6fb62f 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py @@ -47,6 +47,12 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): logits = logits * std + mean targets = targets * std + mean ssims = torch.vmap(structural_similarity)(logits, targets, volume_max) + + # map out-of-bounds ssims to 1 and -1, the theoretical + # maximum and minimum values of SSIM. + ssims = torch.where(ssims > 1, torch.ones_like(ssims), ssims) + ssims = torch.where(ssims < -1, torch.ones_like(ssims) * -1, ssims) + return ssims