Skip to content

Commit

Permalink
Merge pull request #224 from punch-mission/add-fits
Browse files Browse the repository at this point in the history
add FITS saving and loading
  • Loading branch information
jmbhughes authored Nov 2, 2024
2 parents 15876d9 + 388d947 commit ca4ffb4
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 58 deletions.
107 changes: 74 additions & 33 deletions docs/source/example.ipynb

Large diffs are not rendered by default.

52 changes: 40 additions & 12 deletions regularizepsf/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import matplotlib as mpl
import numpy as np
import scipy.fft
from astropy.io import fits

from regularizepsf.exceptions import IncorrectShapeError, InvalidCoordinateError, InvalidFunctionError
from regularizepsf.util import IndexedCube
Expand Down Expand Up @@ -259,7 +260,7 @@ def fft_at(self, coord: tuple[int, int]) -> np.ndarray:
return self._fft_cube[coord]

def save(self, path: pathlib.Path) -> None:
"""Save the PSF model to a file.
"""Save the PSF model to a file. Supports h5 and FITS.
Parameters
----------
Expand All @@ -271,14 +272,24 @@ def save(self, path: pathlib.Path) -> None:
None
"""
with h5py.File(path, "w") as f:
f.create_dataset("coordinates", data=self.coordinates)
f.create_dataset("values", data=self.values)
f.create_dataset("fft_evaluations", data=self.fft_evaluations)
if path.suffix == ".h5":
with h5py.File(path, "w") as f:
f.create_dataset("coordinates", data=self.coordinates)
f.create_dataset("values", data=self.values)
f.create_dataset("fft_evaluations", data=self.fft_evaluations)
elif path.suffix == ".fits":
fits.HDUList([fits.PrimaryHDU(),
fits.CompImageHDU(np.array(self.coordinates), name="coordinates"),
fits.CompImageHDU(self.values, name="values"),
fits.CompImageHDU(self.fft_evaluations.real, name="fft_real", quantize_level=32),
fits.CompImageHDU(self.fft_evaluations.imag, name="fft_imag", quantize_level=32),
]).writeto(path)
else:
raise NotImplementedError(f"Unsupported file type {path.suffix}. Change to .h5 or .fits.")

@classmethod
def load(cls, path: pathlib.Path) -> ArrayPSF:
"""Load the PSF model from a file.
"""Load the PSF model from a file. Supports h5 and FITS.
Parameters
----------
Expand All @@ -291,12 +302,29 @@ def load(cls, path: pathlib.Path) -> ArrayPSF:
loaded model
"""
with h5py.File(path, "r") as f:
coordinates = [tuple(c) for c in f["coordinates"][:]]
values = f["values"][:]
fft_evaluations = f["fft_evaluations"][:]
values_cube = IndexedCube(coordinates, values)
fft_cube = IndexedCube(coordinates, fft_evaluations)
if path.suffix == ".h5":
with h5py.File(path, "r") as f:
coordinates = [tuple(c) for c in f["coordinates"][:]]
values = f["values"][:]
fft_evaluations = f["fft_evaluations"][:]
values_cube = IndexedCube(coordinates, values)
fft_cube = IndexedCube(coordinates, fft_evaluations)
elif path.suffix == ".fits":
with fits.open(path) as hdul:
coordinates_index = hdul.index_of("coordinates")
coordinates = [tuple(c) for c in hdul[coordinates_index].data]

values_index = hdul.index_of("values")
values = hdul[values_index].data
values_cube = IndexedCube(coordinates, values)

fft_real_index = hdul.index_of("fft_real")
fft_real = hdul[fft_real_index].data
fft_imag_index = hdul.index_of("fft_imag")
fft_imag = hdul[fft_imag_index].data
fft_cube = IndexedCube(coordinates, fft_real + fft_imag*1j)
else:
raise NotImplementedError(f"Unsupported file type {path.suffix}. Change to .h5 or .fits.")
return cls(values_cube, fft_cube)

def visualize_psfs(self,
Expand Down
42 changes: 32 additions & 10 deletions regularizepsf/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
import numpy as np
import scipy
from astropy.io import fits

from regularizepsf.exceptions import InvalidCoordinateError
from regularizepsf.util import IndexedCube
Expand Down Expand Up @@ -147,7 +148,7 @@ def visualize(self) -> None:
"""

def save(self, path: pathlib.Path) -> None:
"""Save a PSFTransform to a file.
"""Save a PSFTransform to a file. Supports h5 and FITS.
Parameters
----------
Expand All @@ -159,13 +160,22 @@ def save(self, path: pathlib.Path) -> None:
None
"""
with h5py.File(path, "w") as f:
f.create_dataset("coordinates", data=self.coordinates)
f.create_dataset("transfer_kernel", data=self._transfer_kernel.values)

if path.suffix == ".h5":
with h5py.File(path, "w") as f:
f.create_dataset("coordinates", data=self.coordinates)
f.create_dataset("transfer_kernel", data=self._transfer_kernel.values)
elif path.suffix == ".fits":
fits.HDUList([fits.PrimaryHDU(),
fits.CompImageHDU(np.array(self.coordinates), name="coordinates"),
fits.CompImageHDU(self._transfer_kernel.values.real,
name="transfer_real", quantize_level=32),
fits.CompImageHDU(self._transfer_kernel.values.imag,
name="transfer_imag", quantize_level=32)]).writeto(path)
else:
raise NotImplementedError(f"Unsupported file type {path.suffix}. Change to .h5 or .fits.")
@classmethod
def load(cls, path: pathlib.Path) -> ArrayPSFTransform:
"""Load a PSFTransform object.
"""Load a PSFTransform object. Supports h5 and FITS.
Parameters
----------
Expand All @@ -177,10 +187,22 @@ def load(cls, path: pathlib.Path) -> ArrayPSFTransform:
PSFTransform
"""
with h5py.File(path, "r") as f:
coordinates = [tuple(c) for c in f["coordinates"][:]]
transfer_kernel = f["transfer_kernel"][:]
kernel = IndexedCube(coordinates, transfer_kernel)
if path.suffix == ".h5":
with h5py.File(path, "r") as f:
coordinates = [tuple(c) for c in f["coordinates"][:]]
transfer_kernel = f["transfer_kernel"][:]
kernel = IndexedCube(coordinates, transfer_kernel)
elif path.suffix == ".fits":
with fits.open(path) as hdul:
coordinates_index = hdul.index_of("coordinates")
coordinates = [tuple(c) for c in hdul[coordinates_index].data]
transfer_real_index = hdul.index_of("transfer_real")
transfer_real = hdul[transfer_real_index].data
transfer_imag_index = hdul.index_of("transfer_imag")
transfer_imag = hdul[transfer_imag_index].data
kernel = IndexedCube(coordinates, transfer_real + transfer_imag*1j)
else:
raise NotImplementedError(f"Unsupported file type {path.suffix}. Change to .h5 or .fits.")
return cls(kernel)

def __eq__(self, other: ArrayPSFTransform) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion regularizepsf/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,5 @@ def __eq__(self, other: IndexedCube) -> bool:
return (
self.coordinates == other.coordinates
and self.sample_shape == other.sample_shape
and np.allclose(self.values, other.values)
and np.allclose(self.values, other.values, rtol=1e-04, atol=1e-06)
)
5 changes: 3 additions & 2 deletions tests/test_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
from tests.helper import make_gaussian


def test_arraypsf_saves_and_loads(tmp_path):
@pytest.mark.parametrize("extension", ["fits", "h5"])
def test_arraypsf_saves_and_loads(tmp_path, extension):
"""Can save and reload an ArrayPSF"""
coordinates = [(0, 0), (1, 1), (2, 2)]
gauss = make_gaussian(128, fwhm=3)
values = np.stack([gauss for _ in coordinates])

source = ArrayPSF(IndexedCube(coordinates, values))

path = tmp_path / "psf.h5"
path = tmp_path / f"psf.{extension}"

source.save(path)
reloaded = ArrayPSF.load(path)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@
from tests.helper import make_gaussian


@pytest.mark.parametrize("extension", ["fits", "h5"])
def test_transform_saves_and_loads(tmp_path, extension):
"""Can save and reload an ArrayPSF"""
coordinates = [(0, 0), (1, 1), (2, 2)]
gauss = make_gaussian(128, fwhm=3)
values = np.stack([gauss for _ in coordinates])

source = ArrayPSF(IndexedCube(coordinates, values))
target = ArrayPSF(IndexedCube(coordinates, values))
transform = ArrayPSFTransform.construct(source, target, 1.0, 0.1)

path = tmp_path / f"transform.{extension}"

transform.save(path)
reloaded = ArrayPSFTransform.load(path)

assert transform == reloaded

def test_transform_apply():
"""Test that applying an identity transform does not change the values."""
size = 256
Expand Down

0 comments on commit ca4ffb4

Please sign in to comment.