Skip to content

Commit

Permalink
Merge pull request #71 from nalinimsingh/main
Browse files Browse the repository at this point in the history
Building out IDEAL optimization part of package
  • Loading branch information
emarkley authored Feb 13, 2025
2 parents 6ac3b1a + 6c70809 commit 9c07a76
Show file tree
Hide file tree
Showing 41 changed files with 6,482 additions and 15 deletions.
563 changes: 563 additions & 0 deletions ideal/ideal_example.ipynb

Large diffs are not rendered by default.

196 changes: 196 additions & 0 deletions ideal/image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import equinox as eqx
import jax
from encoding_information import image_utils
import numpy as onp
import jax.numpy as jnp
from functools import partial
from encoding_information.models.model_base_class import MeasurementModel, make_dataset_generators
from encoding_information.models.gaussian_process import match_to_generator_data, generate_multivariate_gaussian_samples
import warnings

# Patching functions

@partial(jax.jit, static_argnames=('num_patches', 'patch_size'))
def _extract_random_patches(data, key, num_patches, patch_size):
keys = jax.random.split(key, 3)
image_indices = jax.random.randint(keys[0], shape=(num_patches,),
minval=0, maxval=data.shape[0])
x_indices = jax.random.randint(keys[1], shape=(num_patches,),
minval=0, maxval=data.shape[1] - patch_size + 1)
y_indices = jax.random.randint(keys[2], shape=(num_patches,),
minval=0, maxval=data.shape[2] - patch_size + 1)

def get_patch(i):
return jax.lax.dynamic_slice(
data[image_indices[i]],
(x_indices[i], y_indices[i]),
(patch_size, patch_size)
)

return jax.vmap(get_patch)(jnp.arange(num_patches))

@partial(jax.jit, static_argnames=('num_patches', 'patch_size'))
def _extract_tiled_patches(data, key, num_patches, patch_size):
num_tiles_x = data.shape[1] // patch_size
num_tiles_y = data.shape[2] // patch_size

keys = jax.random.split(key, 3)
image_indices = jax.random.randint(keys[0], shape=(num_patches,),
minval=0, maxval=data.shape[0])
x_tile_indices = jax.random.randint(keys[1], shape=(num_patches,),
minval=0, maxval=num_tiles_x)
y_tile_indices = jax.random.randint(keys[2], shape=(num_patches,),
minval=0, maxval=num_tiles_y)

def get_tile(i):
return jax.lax.dynamic_slice(
data[image_indices[i]],
(x_tile_indices[i] * patch_size, y_tile_indices[i] * patch_size),
(patch_size, patch_size)
)

return jax.vmap(get_tile)(jnp.arange(num_patches))

@partial(jax.jit, static_argnames=('num_patches', 'patch_size'))
def _extract_cropped_patches(data, key, num_patches, patch_size, crop_location):
if crop_location is not None:
y_index, x_index = crop_location
else:
key1, key2 = jax.random.split(key)
x_index = jax.random.randint(key1, shape=(),
minval=0, maxval=data.shape[1] - patch_size + 1)
y_index = jax.random.randint(key2, shape=(),
minval=0, maxval=data.shape[2] - patch_size + 1)

patches = jax.lax.dynamic_slice(data,
(0, x_index, y_index),
(min(data.shape[0], num_patches), patch_size, patch_size))

if num_patches > data.shape[0]:
key3 = jax.random.split(key)[0]
extra_indices = jax.random.randint(key3, shape=(num_patches - data.shape[0],),
minval=0, maxval=data.shape[0])
patches = jnp.concatenate([patches, patches[extra_indices]])
elif num_patches < data.shape[0]:
key3 = jax.random.split(key)[0]
indices = jax.random.permutation(key3, data.shape[0])[:num_patches]
patches = patches[indices]

return patches

@partial(jax.jit, static_argnames=('num_patches', 'patch_size', 'num_masked_pixels'))
def _extract_masked_patches(data, key, num_patches, patch_size, num_masked_pixels):
data_size = data[0].size
indices = jax.random.permutation(key, data_size)[:num_masked_pixels]

def get_masked_data(img):
return jnp.take(img.ravel(), indices)

patches = jax.vmap(get_masked_data)(data)

if num_patches > data.shape[0]:
key2 = jax.random.split(key)[0]
extra_indices = jax.random.randint(key2, shape=(num_patches - data.shape[0],),
minval=0, maxval=data.shape[0])
patches = jnp.concatenate([patches, patches[extra_indices]])
elif num_patches < data.shape[0]:
key2 = jax.random.split(key)[0]
sample_indices = jax.random.permutation(key2, data.shape[0])[:num_patches]
patches = patches[sample_indices]

if patch_size * patch_size == num_masked_pixels:
patches = patches.reshape(num_patches, patch_size, patch_size)

return patches

def extract_patches(data, key, num_patches=1000, patch_size=16, strategy='random',
crop_location=None, num_masked_pixels=256, verbose=False) -> jnp.ndarray:
"""Extract patches from a dataset using various strategies, optimized for JAX."""
strategies = {
'random': lambda: _extract_random_patches(data, key, num_patches, patch_size),
'uniform_random': lambda: _extract_random_patches(
jnp.pad(data,
((0, 0), (patch_size, patch_size),
(patch_size, patch_size)),
mode='constant',
constant_values=jnp.mean(data)),
key, num_patches, patch_size
),
'tiled': lambda: _extract_tiled_patches(data, key, num_patches, patch_size),
'cropped': lambda: _extract_cropped_patches(data, key, num_patches, patch_size, crop_location),
'masked': lambda: _extract_masked_patches(data, key, num_patches, patch_size, num_masked_pixels)
}

return strategies[strategy]()

# Noise functions

@jax.jit
def _add_noise_single(image, key, gaussian_sigma, ensure_positive):
"""Helper function to add noise to a single image."""
if gaussian_sigma is not None:
noisy_image = image + gaussian_sigma * jax.random.normal(key, shape=image.shape)
else:
noisy_image = image + jax.random.normal(key, shape=image.shape) * jnp.sqrt(jnp.maximum(image,1e-8))


return jnp.where(ensure_positive, jnp.maximum(0, noisy_image), noisy_image)

def add_noise(images, ensure_positive=True, gaussian_sigma=None, key=None, seed=None, batch_size=None):
"""
Add Poisson or Gaussian noise to a stack of images using JAX optimization.
Parameters
----------
images : ndarray
A stack of images (NxHxW) or patches (Nx(num pixels)).
ensure_positive : bool, optional
Whether to ensure all resulting pixel values are non-negative.
gaussian_sigma : float, optional
Standard deviation for Gaussian noise. If None, Poisson noise is added.
key : jax.random.PRNGKey, optional
PRNGKey for generating noise. If None, a key is generated based on the seed.
seed : int, optional
Seed for generating noise, if no key is provided.
batch_size : int, optional
Deprecated. Included for backward compatibility but no longer needed.
Returns
-------
ndarray
Noisy images.
"""
if seed is None:
seed = onp.random.randint(0, 100000)
if key is None:
key = jax.random.PRNGKey(seed)

images = images.astype(jnp.float32)

# Create separate keys for each image
keys = jax.random.split(key, images.shape[0])

# Vectorize the noise addition operation across the batch dimension
vectorized_noise = jax.vmap(_add_noise_single, in_axes=(0, 0, None, None))

return vectorized_noise(images, keys, gaussian_sigma, ensure_positive)

def jax_crop2D(target_shape, mat):
"""
Center crop the 2D or 3D input matrix to the target shape.
Args:
target_shape (tuple): Target shape for cropping.
mat (np.array): Input matrix.
Returns:
onp.array: Cropped matrix.
"""
y_margin = (mat.shape[-2] - target_shape[-2]) // 2
x_margin = (mat.shape[-1] - target_shape[-1]) // 2
if mat.ndim == 2:
return mat[y_margin : -y_margin or None, x_margin : -x_margin or None]
elif mat.ndim == 3:
return mat[:, y_margin : -y_margin or None, x_margin : -x_margin or None]
else:
raise ValueError("jax_crop2D only supports 2D and 3D arrays.")
177 changes: 177 additions & 0 deletions ideal/imaging_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""
Imaging System Base Class
"""
from typing import Protocol, runtime_checkable
import jax
import jax.numpy as jnp
from jax import random
import equinox as eqx
import matplotlib.pyplot as plt

@runtime_checkable
class ImagingSystemProtocol(Protocol):
"""Protocol defining the interface for an imaging system."""
# seed: int

def forward_model(self, objects: jnp.ndarray) -> jnp.ndarray:
"""Simulates the forward model of the imaging system."""
...

def reconstruct(self, measurements: jnp.ndarray) -> jnp.ndarray:
"""Reconstructs objects from measurements."""
...

def toy_images(self, batch_size: int, height: int, width: int, channels: int) -> jnp.ndarray:
"""Generates toy images for testing the system."""
...

def next_rng_key(self) -> jax.random.PRNGKey:
"""Generates the next random key."""
...

def display_measurement(self, measurement: jnp.ndarray) -> None:
"""Displays the measurement."""
return NotImplemented

def display_reconstruction(self, reconstruction: jnp.ndarray) -> None:
"""Displays the reconstruction."""
return NotImplemented

def display_object(self, object: jnp.ndarray) -> None:
"""Displays the object."""
return NotImplemented

def display_optics(self) -> None:
"""Displays learned the optics."""
return NotImplemented


class ImagingSystem(eqx.Module):
"""Abstract base class for an imaging system."""
seed: int = eqx.field(static=True)
rng_key: jax.random.PRNGKey = eqx.field(static=True)

def __init__(self, seed: int = 0):
"""
Initializes the imaging system with a given random seed.
Args:
seed: Seed for the random number generator.
"""
self.seed = seed
self.rng_key = random.PRNGKey(seed)

def forward_model(self, objects: jnp.ndarray) -> jnp.ndarray:
"""
Runs the forward model.
Args:
objects: Input objects of shape (H, W, C).
Returns:
measurements: Output measurements of shape (H, W, C).
"""
raise NotImplementedError("Subclasses must implement forward_model.")

def reconstruct(self, measurements: jnp.ndarray) -> jnp.ndarray:
"""
Reconstructs objects from measurements.
Args:
measurements: Input measurements of shape (H, W, C).
Returns:
reconstructions: Reconstructed objects of shape (H, W, C).
"""
raise NotImplementedError("Subclasses must implement reconstruct.")

def toy_images(self, batch_size: int, height: int, width: int, channels: int) -> jnp.ndarray:
"""
Generates toy images for testing the system.
Args:
batch_size: Number of images to generate.
height: Height of each image.
width: Width of each image.
channels: Number of channels in each image.
Returns:
Toy images of shape (batch_size, height, width, channels).
"""
key = self.next_rng_key()
return random.uniform(key, shape=(batch_size, height, width, channels), minval=0, maxval=1)

def next_rng_key(self) -> jax.random.PRNGKey:
"""
Generates the next random key and updates the RNG state.
Returns:
A new random key.
"""
rng_key, subkey = random.split(self.rng_key)
object.__setattr__(self, 'rng_key', rng_key)
return subkey

def display_measurement(self, measurement: jnp.ndarray) -> plt.Figure:
"""
Displays the measurement as a matplotlib figure.
Args:
measurement: Input measurement of shape (H, W, C).
Returns:
fig: Matplotlib figure showing the measurement.
"""
fig, ax = plt.subplots()
im = ax.imshow(measurement, cmap='inferno')
plt.colorbar(im)
ax.set_title('Measurement')
return fig

def display_reconstruction(self, reconstruction: jnp.ndarray) -> plt.Figure:
"""
Displays the reconstruction as a matplotlib figure.
Args:
reconstruction: Input reconstruction of shape (H, W, C).
Returns:
fig: Matplotlib figure showing the reconstruction.
"""
fig, ax = plt.subplots()
ax.text(0.5, 0.5, 'Reconstruction display not implemented for base class',
ha='center', va='center')
ax.set_xticks([])
ax.set_yticks([])
return fig

def display_object(self, object: jnp.ndarray) -> plt.Figure:
"""
Displays the object as a matplotlib figure.
Args:
object: Input object of shape (H, W, C).
Returns:
fig: Matplotlib figure showing the object.
"""
fig, ax = plt.subplots()
ax.text(0.5, 0.5, 'Object display not implemented for base class',
ha='center', va='center')
ax.set_xticks([])
ax.set_yticks([])
return fig

def display_optics(self) -> plt.Figure:
"""
Displays the optical system configuration as a matplotlib figure.
Returns:
fig: Matplotlib figure showing the optical system.
"""
fig, ax = plt.subplots()
ax.text(0.5, 0.5, 'Optics display not implemented for base class',
ha='center', va='center')
ax.set_xticks([])
ax.set_yticks([])
return fig
Loading

0 comments on commit 9c07a76

Please sign in to comment.