From feec5a377c3e6357b112081e3783bbebd3fcc6b9 Mon Sep 17 00:00:00 2001 From: Leyla Kabuli Date: Wed, 23 Oct 2024 16:14:02 -0700 Subject: [PATCH 1/5] working version of multichannel pixelcnn --- src/encoding_information/models/__init__.py | 1 + .../models/multichannel_pixel_cnn.py | 762 ++++++++++++++++++ 2 files changed, 763 insertions(+) create mode 100644 src/encoding_information/models/multichannel_pixel_cnn.py diff --git a/src/encoding_information/models/__init__.py b/src/encoding_information/models/__init__.py index b9d5bdb..ef91f66 100644 --- a/src/encoding_information/models/__init__.py +++ b/src/encoding_information/models/__init__.py @@ -1,4 +1,5 @@ from .pixel_cnn import PixelCNN +from .multichannel_pixel_cnn import MultiChannelPixelCNN from .gaussian_process import FullGaussianProcess, StationaryGaussianProcess from .conditional_entropy_models import AnalyticGaussianNoiseModel, PoissonNoiseModel, AnalyticComplexPixelGaussianNoiseModel \ No newline at end of file diff --git a/src/encoding_information/models/multichannel_pixel_cnn.py b/src/encoding_information/models/multichannel_pixel_cnn.py new file mode 100644 index 0000000..26826da --- /dev/null +++ b/src/encoding_information/models/multichannel_pixel_cnn.py @@ -0,0 +1,762 @@ +""" +MultichannelPixelCNN in Jax/Flax. Adapted from single channel PixelCNN implementation in Flax.: +https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial12/Autoregressive_Image_Modeling.html + +Univariate mixture density output adapted from: +https://github.com/hardmaru/mdn_jax_tutorial/blob/master/mixture_density_networks_jax.ipynb +""" + +## Standard libraries +import os +import numpy as onp +from typing import Any +from tqdm import tqdm +import warnings + +## JAX +import jax +import jax.numpy as np +from jax import random +from jax.scipy.special import logsumexp + +from flax import linen as nn +from flax.training.train_state import TrainState +import optax + + +from .model_base_class import MeasurementModel, MeasurementType, \ + train_model, _evaluate_nll, make_dataset_generators + + + +class PreprocessLayer(nn.Module): + """ + A layer that normalizes the input images using the provided mean and standard deviation. + + Attributes + ---------- + mean : np.ndarray + The mean to subtract from the input images. + std : np.ndarray + The standard deviation to divide the input images by. + """ + mean: np.ndarray + std: np.ndarray + + def __call__(self, x): + return (x - self.mean) / (self.std + 1e-5) + +class MaskedConvolution(nn.Module): + """ + A convolutional layer with a mask to ensure autoregressive behavior. + + This layer ensures that during the convolution, the current pixel does not + have access to any future pixels (either to the right or below in the image). + + Attributes + ---------- + c_out : int + The number of output channels. + mask : np.ndarray + The mask to apply to the convolution, determining which pixels are visible. + dilation : int, optional + The dilation factor for the convolution (default is 1). + """ + c_out : int + mask : np.ndarray + dilation : int = 1 + + @nn.compact + def __call__(self, x): + # Flax's convolution module already supports masking + # The mask must be the same size as kernel + # => extend over input and output feature channels + if len(self.mask.shape) == 2: + mask_ext = self.mask[...,None,None] + mask_ext = np.tile(mask_ext, (1, 1, x.shape[-1], self.c_out)) + else: + mask_ext = self.mask + # Convolution with masking + x = nn.Conv(features=self.c_out, + kernel_size=self.mask.shape[:2], + kernel_dilation=self.dilation, + mask=mask_ext)(x) + return x + + +class VerticalStackConvolution(nn.Module): + """ + A vertical convolutional layer that processes the pixels above the current pixel in an image. + + This layer creates a vertical stack by masking the convolution kernel, ensuring that the pixels + below the current pixel are not visible during the convolution. + + Attributes + ---------- + c_out : int + The number of output channels. + kernel_size : int + The size of the convolution kernel. + mask_center : bool, optional + Whether to mask out the center pixel in the kernel (default is False). + dilation : int, optional + The dilation factor for the convolution (default is 1). + """ + c_out : int + kernel_size : int + mask_center : bool = False + dilation : int = 1 + + def setup(self): + # Mask out all pixels below. For efficiency, we could also reduce the kernel + # size in height, but for simplicity, we stick with masking here. + mask = onp.ones((self.kernel_size, self.kernel_size), dtype=onp.float32) + mask[self.kernel_size//2+1:,:] = 0 + # For the very first convolution, we will also mask the center row + if self.mask_center: + mask[self.kernel_size//2,:] = 0 + # Our convolution module + self.conv = MaskedConvolution(c_out=self.c_out, + mask=mask, + dilation=self.dilation) + + def __call__(self, x): + return self.conv(x) + + +class HorizontalStackConvolution(nn.Module): + """ + A horizontal convolutional layer that processes the pixels to the left of the current pixel in an image. + + This layer creates a horizontal stack by masking the convolution kernel, ensuring that the pixels + to the right of the current pixel are not visible during the convolution. + + Attributes + ---------- + c_out : int + The number of output channels. + kernel_size : int + The size of the convolution kernel. + mask_center : bool, optional + Whether to mask out the center pixel in the kernel (default is False). + dilation : int, optional + The dilation factor for the convolution (default is 1). + """ + c_out : int + kernel_size : int + mask_center : bool = False + dilation : int = 1 + + def setup(self): + # Mask out all pixels on the left. Note that our kernel has a size of 1 + # in height because we only look at the pixel in the same row. + mask = onp.ones((1, self.kernel_size), dtype=onp.float32) + mask[0,self.kernel_size//2+1:] = 0 + # For the very first convolution, we will also mask the center pixel + if self.mask_center: + mask[0,self.kernel_size//2] = 0 + # Our convolution module + self.conv = MaskedConvolution(c_out=self.c_out, + mask=mask, + dilation=self.dilation) + + def __call__(self, x): + return self.conv(x) + + +class GatedMaskedConv(nn.Module): + """ + A gated masked convolution layer used in PixelCNN. This layer uses gated activation functions + to improve gradient flow during training. + + It combines information from a vertical stack and a horizontal stack, each being passed through + masked convolutions, and optionally conditioned on an external vector (such as class labels). + + Attributes + ---------- + dilation : int, optional + The dilation factor for the convolutions (default is 1). + id : int, optional + The layer ID, used for parameter naming. + condition_vector_size : int, optional + The size of the condition vector for conditional PixelCNN. + """ + dilation : int = 1 + id: int = None + condition_vector_size : int = None + + @nn.compact + def __call__(self, v_stack, h_stack, condition_vector=None): + c_in = v_stack.shape[-1] + + # Layers (depend on input shape) + conv_vert = VerticalStackConvolution(c_out=2*c_in, + kernel_size=3, + mask_center=False, + dilation=self.dilation) + conv_horiz = HorizontalStackConvolution(c_out=2*c_in, + kernel_size=3, + mask_center=False, + dilation=self.dilation) + conv_vert_to_horiz = nn.Conv(2*c_in, + kernel_size=(1, 1)) + conv_horiz_1x1 = nn.Conv(c_in, + kernel_size=(1, 1)) + + + + # Vertical stack (left) + v_stack_feat = conv_vert(v_stack) + v_val, v_gate = np.split(v_stack_feat, 2, axis=-1) + + if condition_vector is not None: + weights = self.param(f'conditioning_weights_vert_{self.id}', jax.nn.initializers.lecun_normal(), (1, self.condition_vector_size,)) + y = np.dot(weights, condition_vector.T).reshape(-1,1,1,1) + weights_gate = self.param(f'conditioning_weights_vert_gate{self.id}', jax.nn.initializers.lecun_normal(), (1, self.condition_vector_size,)) + y_gate = np.dot(weights_gate, condition_vector.T).reshape(-1,1,1,1) + v_stack_out = nn.tanh(v_val + y) * nn.sigmoid(v_gate + y_gate) + else: + v_stack_out = nn.tanh(v_val) * nn.sigmoid(v_gate) + + # Horizontal stack (right) + h_stack_feat = conv_horiz(h_stack) + h_stack_feat = h_stack_feat + conv_vert_to_horiz(v_stack_feat) + h_val, h_gate = np.split(h_stack_feat, 2, axis=-1) + if condition_vector is not None: + weights = self.param(f'conditioning_weights_horz{self.id}', jax.nn.initializers.lecun_normal(), (1, self.condition_vector_size,)) + y = np.dot(weights, condition_vector.T).reshape(-1,1,1,1) + weights_gate = self.param(f'conditioning_weights_horz_gate{self.id}', jax.nn.initializers.lecun_normal(), (1, self.condition_vector_size,)) + y_gate = np.dot(weights_gate, condition_vector.T).reshape(-1,1,1,1) + h_stack_feat = nn.tanh(h_val + y) * nn.sigmoid(h_gate + y_gate) + else: + h_stack_feat = nn.tanh(h_val) * nn.sigmoid(h_gate) + h_stack_out = conv_horiz_1x1(h_stack_feat) + h_stack_out = h_stack_out + h_stack + + return v_stack_out, h_stack_out + + +class _MultiChannelPixelCNNFlaxImpl(nn.Module): + """ + The core implementation of the PixelCNN model in Flax. + + This module defines the structure of the PixelCNN, including the vertical and horizontal + masked convolutions, gated activation functions, and a mixture density output layer. + + Attributes + ---------- + data_shape : tuple + The shape of the input data (height, width, channels). + num_hidden_channels : int, optional + The number of hidden channels in the model (default is 64). + num_mixture_components : int, optional + The number of components in the mixture density output (default is 40). + train_data_mean : float + The mean of the training data used for normalization. Multichannel considers a float for each channel. + train_data_std : float + The standard deviation of the training data used for normalization. Multichannel considers a float for each channel. + train_data_min : float + The minimum value of the training data. Multichannel considers a float for each channel. + train_data_max : float + The maximum value of the training data. Multichannel considers a float for each channel. + sigma_min : float, optional + The minimum standard deviation for the mixture density output (default is 1). + condition_vector_size : int, optional + The size of the condition vector for conditional PixelCNN. + use_positional_embedding : bool, optional + Whether to use learned positional embeddings for each pixel (default is False). + """ + data_shape : tuple + num_hidden_channels : int = 64 + num_mixture_components : int = 40 + train_data_mean : float = None + train_data_std : float = None + train_data_min : float = None + train_data_max : float = None + sigma_min : float = 1 + condition_vector_size : int = None + use_positional_embedding : bool = False + + def setup(self): + if None in [self.train_data_mean, self.train_data_std, self.train_data_min, self.train_data_max]: + raise Exception('Must pass in training data statistics constructor') + + if self.train_data_max.dtype != np.float32 or self.train_data_min.dtype != np.float32 or \ + self.train_data_mean.dtype != np.float32 or self.train_data_std.dtype != np.float32: + raise Exception('Must pass in training data statistics as float32') + + self.normalize = PreprocessLayer(mean=self.train_data_mean, std=self.train_data_std) + + if not isinstance(self.num_hidden_channels, int): + raise ValueError("num_hidden_channels must be an integer") + # Initial convolutions skipping the center pixel + self.conv_vstack = VerticalStackConvolution(self.num_hidden_channels, kernel_size=3, mask_center=True) + self.conv_hstack = HorizontalStackConvolution(self.num_hidden_channels, kernel_size=3, mask_center=True) + # Convolution block of PixelCNN. We use dilation instead of downscaling + self.conv_layers = [ + GatedMaskedConv(dilation=1, id=0, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=2, id=1, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=1, id=2, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=4, id=3, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=1, id=4, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=2, id=5, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=1, id=6, condition_vector_size=self.condition_vector_size), + ] + # Output classification convolution (1x1) + self.conv_out = nn.Conv(self.num_hidden_channels, kernel_size=(1, 1)) + + # parameters for mixture density + def my_bias_init(rng, shape, dtype): + return random.uniform(rng, shape, dtype=dtype, + minval=np.min(self.train_data_min), maxval=np.max(self.train_data_max)) # just initializing a learnable parameter so using absolute values across channels + + # Parameters for learned positional embedding + if self.use_positional_embedding: + self.positional_embedding = nn.Embed(num_embeddings=self.data_shape[0] * self.data_shape[1], features=self.num_hidden_channels) + # generate unique index for each pixel + self.position_indices = np.arange(self.data_shape[0] * self.data_shape[1]).reshape(*self.data_shape[:2]) + + self.mu_dense = nn.Dense(self.num_mixture_components * self.data_shape[2], bias_init=my_bias_init) # scale by number of channels + self.sigma_dense = nn.Dense(self.num_mixture_components * self.data_shape[2] * self.data_shape[2]) # scale by squared number of channels since matrix + self.mix_logit_dense = nn.Dense(self.num_mixture_components) # mixture components are scalars for each pixel + + def __call__(self, x, condition_vectors=None): + """ + Do forward pass output the parameters of the gaussian mixture output + """ + # add trailing channel dimension if necessary + if x.ndim == 3: + x = x[..., np.newaxis] + + return self.forward_pass(x, condition_vectors=condition_vectors) + + def compute_gaussian_nll(self, mu, sigma, mix_logit, x): + # numerically efficient implementation of mixture density, slightly modified + # see https://github.com/hardmaru/mdn_jax_tutorial/blob/master/mixture_density_networks_jax.ipynb + # compute per-pixel negative log-likelihood + + # one-by-one step version for debugging. + # lognormal = self.lognormal(x, mu, sigma) + # jax.debug.print("Number of nans in lognormal {test}", test=np.sum(np.isnan(lognormal))) + # logit_normalized = mix_logit - logsumexp(mix_logit, axis=-1, keepdims=True) + # jax.debug.print("Number of nans in logit_normalized {test}", test=np.sum(np.isnan(logit_normalized))) + # nll = - logsumexp(logit_normalized + lognormal, axis=-1) + + #all in one step + nll = - logsumexp(mix_logit - logsumexp(mix_logit, axis=-1, keepdims=True) + self.lognormal(x, mu, sigma), axis=-1) + return nll + + def compute_loss(self, mu, sigma, mix_logit, x): + """ + Compute average negative log likelihood per pixel averaged over batch and pixels + """ + return self.compute_gaussian_nll(mu, sigma, mix_logit, x).mean() + + + def lognormal(self, y, mean, sigma): + # expand the data in the n_components dimension and tile + y = np.expand_dims(y, axis=-2) + y = np.tile(y, (1, 1, 1, self.num_mixture_components, 1)) + logRootDTwoPI = np.log(2.0 * np.pi) * self.data_shape[2] / 2.0 # d / 2 log 2pi + covarianceDeterminant = np.linalg.det(sigma) + matrix_sum = np.einsum('...i, ...ij, ...j->...', y - mean, np.linalg.inv(sigma), y - mean) + # -d/2 log(2pi) - 1/2 log det covariance - 0.5 (x - mu)T covaraince^-1 (x - mu) + return -1.0 * logRootDTwoPI - 0.5 * np.log(covarianceDeterminant) - 0.5 * matrix_sum + #return -0.5 * ((y - mean) / sigma) ** 2 - np.log(sigma) - logSqrtTwoPI # previous version for 1D + + def forward_pass(self, x, condition_vectors=None): + """ + Forward pass of the MultiChannelPixelCNN model. + + The image is passed through the vertical and horizontal masked convolutions, followed by + gated convolutions, and finally a mixture density output layer. The model outputs the parameters + of the mixture density for each pixel (mean, standard deviation, and mixture logits). + + Parameters + ---------- + x : ndarray + The input image, with shape (batch_size, height, width, channels). + condition_vectors : ndarray, optional + A vector to condition the image generation process (e.g., class labels). + + Returns + ------- + mu : ndarray + The mean of the Gaussian components for each pixel. + sigma : ndarray + The standard deviation of the Gaussian components for each pixel. + mix_logit : ndarray + The logits for the mixture components. + """ + # check shape + if x.ndim != 4: + raise ValueError("Input image must have shape BxHxWxC") + + # rescale to 0-1ish + x = self.normalize(x) + # Initial convolutions + v_stack = self.conv_vstack(x) + h_stack = self.conv_hstack(x) + # Gated Convolutions + for layer in self.conv_layers: + v_stack, h_stack = layer(v_stack, h_stack, condition_vector=condition_vectors) + # 1x1 classification convolution + # Apply ELU before 1x1 convolution for non-linearity on residual connection + out = self.conv_out(nn.elu(h_stack)) + + if self.use_positional_embedding: + # add positional embedding + indices = self.position_indices + # apply positional embedding + out = out + self.positional_embedding(indices) + # must be positive and within data range + #mu = np.clip(self.mu_dense(out), self.train_data_min, self.train_data_max) # 1D version + # mu items need to be reshaped and clipped + mu_out = self.mu_dense(out) + mu_out = np.reshape(mu_out, (out.shape[0], out.shape[1], out.shape[2], self.num_mixture_components, self.data_shape[2])) # reshape from b x h x w x components*num_channels to b x h x w x components x num_channels + mu = np.clip(mu_out, self.train_data_min, self.train_data_max) + + #sigma = nn.activation.softplus(self.sigma_dense(out)) # 1D version + # avoid having tiny components that overly concentrate mass, and don't need components larger than data standard deviation + #sigma = np.clip(sigma, self.sigma_min, self.train_data_std) # previous version + + # sigma items need to be reshaped to be a covariance matrix, and clipped to be a valid cholesky decomposition + sigma_out = self.sigma_dense(out) + # reshape to covariance matrix dimensions + sigma_out = np.reshape(sigma_out, (out.shape[0], out.shape[1], out.shape[2], self.num_mixture_components, self.data_shape[2], self.data_shape[2])) # reshape from b x h x w x components*num_channels**2 to b x h x w x components x num_channels x num_channels + # make a lower triangular matrix L for L L^T + sigma_out = np.tril(sigma_out) + # manually loop through the channel components to clip the diagonals TODO could be more intelligently done maybe? + for channel_idx in range(self.data_shape[2]): + # apply softplus to this diagonal + sigma_out = sigma_out.at[..., channel_idx, channel_idx].set(nn.softplus(sigma_out[..., channel_idx, channel_idx])) + # then clip the components TODO need to change sigma_min to not be 1 in the future when it's not image data + sigma_out = sigma_out.at[..., channel_idx, channel_idx].set(np.clip(sigma_out[..., channel_idx, channel_idx], self.sigma_min, self.train_data_std[channel_idx])) # TODO think about if there needs to be an absolute train_data_std + # now turn this into a covariance matrix + # transpose, swap the last two dimensions + sigma_out_transpose = np.einsum('...ij->...ji', sigma_out) + # multiply cov = L L^T + sigma = np.einsum('...ij, ...jk->...ik', sigma_out, sigma_out_transpose) + # add a small amount to the diagonal to make sure it's positive definite, 1e-6 + sigma = sigma + 1e-6 * np.eye(self.data_shape[2]) + + mix_logit = self.mix_logit_dense(out) # stays as b x h x w x n_components. there isn't a channel dimension for this one + + return mu, sigma, mix_logit + + + +class MultiChannelPixelCNN(MeasurementModel): + """ + The PixelCNN model for autoregressive image modeling. + + This class handles the training and evaluation of the PixelCNN model and wraps the Flax implementation + in a higher-level interface that conforms to the MeasurementModel class. It provides methods for fitting + the model to data, computing the negative log-likelihood of images, and generating new images. + + Attributes + ---------- + num_hidden_channels : int + The number of hidden channels in the model. + num_mixture_components : int + The number of components in the mixture density output. + """ + + def __init__(self, num_hidden_channels=64, num_mixture_components=40): + """ + Initialize the PixelCNN model with image shape, number of hidden channels, and mixture components. + + Parameters + ---------- + num_hidden_channels : int + Number of hidden channels in the convolutional layers. + num_mixture_components : int + Number of mixture components for the output layer. + """ + + super().__init__([MeasurementType.HW, MeasurementType.HWC], measurement_dtype=float) + self.num_hidden_channels = num_hidden_channels + self.num_mixture_components = num_mixture_components + self._flax_model = None + + def fit(self, train_images, condition_vectors=None, learning_rate=1e-2, max_epochs=200, steps_per_epoch=100, patience=40, + sigma_min=1, batch_size=64, num_val_samples=None, percent_samples_for_validation=0.1, do_lr_decay=False, verbose=True, + add_gaussian_noise=False, add_uniform_noise=True, model_seed=None, data_seed=None, use_positional_embedding=False, + # deprecated + seed=None,): + """ + Train the PixelCNN model on a dataset of images. + + Parameters + ---------- + train_images : ndarray + The input dataset, with shape (N, H, W, C). + condition_vectors : ndarray, optional + Vectors to condition the image generation process (e.g., class labels). + learning_rate : float, optional + The learning rate for optimization (default is 1e-2). + max_epochs : int, optional + The maximum number of training epochs (default is 200). + steps_per_epoch : int, optional + The number of steps per epoch (default is 100). + patience : int, optional + The number of epochs to wait before early stopping (default is 40). + sigma_min : float, optional + The minimum standard deviation for the mixture density output (default is 1). + batch_size : int, optional + The batch size for training (default is 64). + num_val_samples : int, optional + The number of validation samples. If None, a percentage is used (default is None). + percent_samples_for_validation : float, optional + The percentage of samples to use for validation (default is 0.1). + do_lr_decay : bool, optional + Whether to apply learning rate decay during training (default is False). + verbose : bool, optional + Whether to print progress during training (default is True). + add_gaussian_noise : bool, optional + Whether to add Gaussian noise to the training images (default is False). + add_uniform_noise : bool, optional + Whether to add uniform noise to the training images (default is True). + model_seed : int, optional + Seed for model initialization. + data_seed : int, optional + Seed for data shuffling. + + Returns + ------- + val_loss_history : list + A list of validation loss values for each epoch. + """ + if seed is not None: + warnings.warn("seed argument is deprecated. Use model_seed and data_seed instead") + model_seed = seed + data_seed = seed + + if model_seed is not None: + onp.random.seed(model_seed) + model_key = jax.random.PRNGKey(onp.random.randint(0, 100000)) + + if condition_vectors is not None: + warnings.warn("For multi-channel PixelCNN condition vectors have not been implemented or double checked.") + + self._validate_data(train_images) + + train_images = train_images.astype(np.float32) + + # check that only one type of noise is added + if add_gaussian_noise and add_uniform_noise: + raise ValueError("Only one type of noise can be added to the training data") + + num_val_samples = int(train_images.shape[0] * percent_samples_for_validation) if num_val_samples is None else num_val_samples + + # add trailing channel dimension if necessary + if train_images.ndim == 3: + train_images = train_images[..., np.newaxis] + + self.image_shape = train_images.shape[1:4] # 3D to include the image channels + + # Use the make dataset generators function because training data may be modified here during training + # (i.e. adding small amounts of noise to account for discrete data and continuous model) + _, dataset_fn = make_dataset_generators(train_images, batch_size=400, num_val_samples=train_images.shape[0], + add_gaussian_noise=add_gaussian_noise, add_uniform_noise=add_uniform_noise, + seed=data_seed) + example_images = dataset_fn().next() # TODO can make this batch size bigger if needed just to get the settings for the values in the following model initialization, currently at 400 + + if self._flax_model is None: + self.add_gaussian_noise = add_gaussian_noise + self.add_uniform_noise = add_uniform_noise + self._flax_model = _MultiChannelPixelCNNFlaxImpl(num_hidden_channels=self.num_hidden_channels, num_mixture_components=self.num_mixture_components, + train_data_mean=np.mean(example_images, axis=(0, 1, 2)), train_data_std=np.std(example_images, axis=(0, 1, 2)), + train_data_min=np.min(example_images, axis=(0, 1, 2)), train_data_max=np.max(example_images, axis=(0, 1, 2)), sigma_min=sigma_min, + condition_vector_size=None if condition_vectors is None else condition_vectors.shape[-1], + data_shape=train_images.shape[1:], use_positional_embedding=use_positional_embedding) + + # pass in an intial batch + initial_params = self._flax_model.init(model_key, train_images[:3], + condition_vectors[:3] if condition_vectors is not None else None) + + if do_lr_decay: + lr_schedule = optax.exponential_decay(init_value=learning_rate, + transition_steps=steps_per_epoch, + decay_rate=0.99,) + + self._optimizer = optax.adam(lr_schedule) + else: + self._optimizer = optax.adam(learning_rate) + + def apply_fn(params, x, condition_vector=None): + output = self._flax_model.apply(params, x, condition_vector) + return self._flax_model.compute_loss(*output, x) + + self._state = TrainState.create(apply_fn=apply_fn, params=initial_params, tx=self._optimizer) + + if condition_vectors is None: + + def loss_fn(params, state, imgs): + return state.apply_fn(params, imgs) + grad_fn = jax.value_and_grad(loss_fn) + + @jax.jit + def train_step(state, imgs): + """ + A standard gradient descent training step + """ + loss, grads = grad_fn(state.params, state, imgs) + state = state.apply_gradients(grads=grads) + return state, loss + else: + + def loss_fn(params, state, imgs, condition_vecs): + return state.apply_fn(params, imgs, condition_vecs) + grad_fn = jax.value_and_grad(loss_fn) + + @jax.jit + def train_step(state, imgs, condition_vecs): + """ + A standard gradient descent training step + """ + loss, grads = grad_fn(state.params, state, imgs, condition_vecs) + state = state.apply_gradients(grads=grads) + return state, loss + + + best_params, val_loss_history = train_model(train_images=train_images, condition_vectors=condition_vectors, train_step=train_step, + state=self._state, batch_size=batch_size, num_val_samples=int(num_val_samples), + add_gaussian_noise=add_gaussian_noise, add_uniform_noise=add_uniform_noise, + steps_per_epoch=steps_per_epoch, num_epochs=max_epochs, patience=patience, seed=data_seed, + verbose=verbose) + self._state = self._state.replace(params=best_params) + self.val_loss_history = val_loss_history + return val_loss_history + + + + def compute_negative_log_likelihood(self, data, conditioning_vecs=None, data_seed=None, average=True, verbose=True, seed=None): + """ + Compute the negative log-likelihood (NLL) of images under the trained PixelCNN model. + + Parameters + ---------- + data : ndarray + The input images for which to compute the NLL. + conditioning_vecs : ndarray, optional + Vectors to condition the image generation process (e.g., class labels). + data_seed : int, optional + Seed for data shuffling. + average : bool, optional + If True, return the average NLL over all images (default is True). + verbose : bool, optional + Whether to print progress (default is True). + seed : int, optional + Deprecated. Use data_seed instead. + + Returns + ------- + nll : float + The negative log-likelihood of the input images. + """ + # See superclass for docstring + if seed is not None: + warnings.warn("seed argument is deprecated. Use data_seed instead") + data_seed = seed + + if data.ndim == 3: + # add a trailing channel dimension if necessary + data = data[..., np.newaxis] + elif data.ndim == 2: + # add trailing channel and batch dimensions + data = data[np.newaxis, ..., np.newaxis] + + # check if data shape is different than image shape + if data.shape[1:4] != self.image_shape: + raise ValueError("Data shape is different than image shape of trained model. This is not yet supported" + "Expected {}, got {}".format(self.image_shape, data.shape[1:4])) + + # get test data generator. Here all data is "validation", because the data passed into this should already be + # (in the typical case) a test set + _, dataset_fn = make_dataset_generators(data, batch_size=32 if average else 1, num_val_samples=data.shape[0], + add_gaussian_noise=self.add_gaussian_noise, add_uniform_noise=self.add_uniform_noise, + condition_vectors=conditioning_vecs, seed=data_seed) + @jax.jit + def conditional_eval_step(state, imgs, condition_vecs): + return state.apply_fn(state.params, imgs, condition_vecs) + + return _evaluate_nll(dataset_fn(), self._state, return_average=average, + eval_step=conditional_eval_step if conditioning_vecs is not None else None, verbose=verbose) + + + def generate_samples(self, num_samples, conditioning_vecs=None, sample_shape=None, ensure_nonnegative=True, seed=None, verbose=True): + """ + Generate new images from the trained PixelCNN model by sampling pixel by pixel. + + Parameters + ---------- + num_samples : int + Number of images to generate. + conditioning_vecs : jax.Array, optional + Optional conditioning vectors. If provided, the shape should match + (num_samples, condition_vector_size). Default is None. + sample_shape : tuple of int or int, optional + Shape of the images to generate. If None, the model's image_shape is used. + If a single int is provided, it will be treated as a square shape. Default is None. + ensure_nonnegative : bool, optional + If True, ensure that the generated pixel values are non-negative. Default is True. + seed : int, optional + Random seed for reproducibility. Default is 123 if not provided. + verbose : bool, optional + If True, display progress during the generation process. Default is True. + + Returns + ------- + jax.Array + Generated images with the specified shape. + """ + if seed is None: + seed = 123 + key = jax.random.PRNGKey(seed) + if sample_shape is None: + sample_shape = self.image_shape + if type(sample_shape) == int: + sample_shape = (sample_shape, sample_shape) + + if conditioning_vecs is not None: + assert conditioning_vecs.shape[0] == num_samples + assert conditioning_vecs.shape[1] == self._flax_model.condition_vector_size + + sampled_images = onp.zeros((num_samples, *sample_shape)) + for i in tqdm(onp.arange(sample_shape[0]), desc='Generating PixelCNN samples') if verbose else np.arange(sample_shape[0]): + for j in onp.arange(sample_shape[1]): + i_limits = max(0, i - self.image_shape[0] + 1), max(self.image_shape[0], i+1) + j_limits = max(0, j - self.image_shape[1] + 1), max(self.image_shape[1], j+1) + + conditioning_images = sampled_images[:, i_limits[0]:i_limits[1], j_limits[0]:j_limits[1]] + i_in_cropped_image = i - i_limits[0] + j_in_cropped_image = j - j_limits[0] + + assert conditioning_images.shape[1:] == self.image_shape + + key, key2 = jax.random.split(key) + if conditioning_vecs is None: + mu, sigma, mix_logit = self._flax_model.apply(self._state.params, conditioning_images) + else: + mu, sigma, mix_logit = self._flax_model.apply(self._state.params, conditioning_images, conditioning_vecs) + # only sampling one pixel at a time + # make onp arrays for range checking + mu = onp.array(mu)[:, i_in_cropped_image, j_in_cropped_image, :] + sigma = onp.array(sigma)[:, i_in_cropped_image, j_in_cropped_image, :] + mix_logit = onp.array(mix_logit)[:, i_in_cropped_image, j_in_cropped_image, :] + + # mix_probs = np.exp(mix_logit - logsumexp(mix_logit, axis=-1, keepdims=True)) # this was commented out in 1D pixelcnn as well + component_indices = jax.random.categorical(key, mix_logit, axis=-1) + # draw categorical sample + sample_mus = mu[np.arange(num_samples), component_indices] + sample_sigmas = sigma[np.arange(num_samples), component_indices] + #sample = jax.random.normal(key2, shape=sample_mus.shape) * sample_sigmas + sample_mus # 1D pixelcnn version + # switching to a multivariate normal distribution for the sigmas + sample = jax.random.multivariate_normal(key2, sample_mus, sample_sigmas) + sampled_images[:, i, j] = sample + + if ensure_nonnegative: + sampled_images = np.where(sampled_images < 0, 0, sampled_images) + return sampled_images + From 59dda0093c277fb7266f01d0071d161d6e74062d Mon Sep 17 00:00:00 2001 From: Leyla Kabuli Date: Mon, 18 Nov 2024 12:55:52 -0800 Subject: [PATCH 2/5] renaming lensless helper function file and updating files accordingly --- ...ssification_plots_updated_mi_cifar10.ipynb | 2 +- ..._17_2024_pixelcnn_cifar10_all_lenses.ipynb | 2 +- ...eep_deconvolution_procedure_per_lens.ipynb | 2 +- .../02_12_2024_make_lenses_3D_edges.ipynb | 2 +- ...pixelcnn_cifar10_extra_photon_counts.ipynb | 2 +- ...onvolution_per_lens_extra_photon_counts.py | 2 +- ...lots_cifar10_all_systems_log_photons.ipynb | 2 +- ...4_04_2024_make_lenses_3D_edges_IDEAL.ipynb | 2 +- .../11_14_2023_run_classifier_cifar10.py | 2 +- lensless_imager/leyla_fns.py | 581 ------------------ 10 files changed, 9 insertions(+), 590 deletions(-) delete mode 100644 lensless_imager/leyla_fns.py diff --git a/lensless_imager/01_09_2024_mi_vs_classification_plots_updated_mi_cifar10.ipynb b/lensless_imager/01_09_2024_mi_vs_classification_plots_updated_mi_cifar10.ipynb index 9596c5e..5ef078d 100644 --- a/lensless_imager/01_09_2024_mi_vs_classification_plots_updated_mi_cifar10.ipynb +++ b/lensless_imager/01_09_2024_mi_vs_classification_plots_updated_mi_cifar10.ipynb @@ -30,7 +30,7 @@ "import sys\n", "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments')\n", - "from leyla_fns import *\n", + "from lensless_helpers import *\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", diff --git a/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb b/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb index b78d352..fe1b6a9 100644 --- a/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb +++ b/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb @@ -34,7 +34,7 @@ "from jax.scipy.special import logsumexp\n", "import numpy as onp\n", "\n", - "from leyla_fns import *" + "from lensless_helpers import *" ] }, { diff --git a/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb b/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb index 15c463a..75b5199 100644 --- a/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb +++ b/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb @@ -39,7 +39,7 @@ "from jax.scipy.special import logsumexp\n", "import numpy as np\n", "\n", - "from leyla_fns import *" + "from lensless_helpers import *" ] }, { diff --git a/lensless_imager/02_12_2024_make_lenses_3D_edges.ipynb b/lensless_imager/02_12_2024_make_lenses_3D_edges.ipynb index 899af57..bfbd03a 100644 --- a/lensless_imager/02_12_2024_make_lenses_3D_edges.ipynb +++ b/lensless_imager/02_12_2024_make_lenses_3D_edges.ipynb @@ -17,7 +17,7 @@ "import numpy as np \n", "import matplotlib.pyplot as plt\n", "import plotly\n", - "from leyla_fns import *\n", + "from lensless_helpers import *\n", "from cleanplots import *" ] }, diff --git a/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb b/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb index 9e57d71..efa614b 100644 --- a/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb +++ b/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb @@ -34,7 +34,7 @@ "from jax.scipy.special import logsumexp\n", "import numpy as onp\n", "\n", - "from leyla_fns import *" + "from lensless_helpers import *" ] }, { diff --git a/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py b/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py index 3da899f..7c07a7f 100644 --- a/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py +++ b/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py @@ -40,7 +40,7 @@ from jax.scipy.special import logsumexp import numpy as np -from leyla_fns import * +from EncodingInformation.lensless_imager.lensless_helpers import * # %% from encoding_information.image_utils import add_noise, extract_patches diff --git a/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb b/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb index 166c75e..3a81d83 100644 --- a/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb +++ b/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb @@ -30,7 +30,7 @@ "import sys\n", "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments')\n", - "from leyla_fns import *\n", + "from lensless_helpers import *\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", diff --git a/lensless_imager/04_04_2024_make_lenses_3D_edges_IDEAL.ipynb b/lensless_imager/04_04_2024_make_lenses_3D_edges_IDEAL.ipynb index 37f9cf8..64e76e3 100644 --- a/lensless_imager/04_04_2024_make_lenses_3D_edges_IDEAL.ipynb +++ b/lensless_imager/04_04_2024_make_lenses_3D_edges_IDEAL.ipynb @@ -18,7 +18,7 @@ "import plotly\n", "import sys\n", "sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments')\n", - "from leyla_fns import *\n", + "from lensless_helpers import *\n", "from cleanplots import *\n" ] }, diff --git a/lensless_imager/11_14_2023_run_classifier_cifar10.py b/lensless_imager/11_14_2023_run_classifier_cifar10.py index 2f1d528..5839bc3 100644 --- a/lensless_imager/11_14_2023_run_classifier_cifar10.py +++ b/lensless_imager/11_14_2023_run_classifier_cifar10.py @@ -33,7 +33,7 @@ from jax.scipy.special import logsumexp import numpy as onp -from leyla_fns import * +from EncodingInformation.lensless_imager.lensless_helpers import * from encoding_information.image_utils import add_noise # %% diff --git a/lensless_imager/leyla_fns.py b/lensless_imager/leyla_fns.py deleted file mode 100644 index 131b4fe..0000000 --- a/lensless_imager/leyla_fns.py +++ /dev/null @@ -1,581 +0,0 @@ -import numpy as np # use regular numpy for now, simpler -import scipy -from tqdm import tqdm -# import tensorflow as tf -# import tensorflow.keras as tfk -import gc -import warnings - -import skimage -import skimage.io -from skimage.transform import resize - -# from tensorflow.keras.optimizers import SGD - -def tile_9_images(data_set): - # takes 9 images and forms a tiled image - assert len(data_set) == 9 - return np.block([[data_set[0], data_set[1], data_set[2]],[data_set[3], data_set[4], data_set[5]],[data_set[6], data_set[7], data_set[8]]]) - -def generate_random_tiled_data(x_set, y_set, seed_value=-1): - # takes a set of images and labels and returns a set of tiled images and corresponding labels - # the size of the output should be 3x the size of the input - vert_shape = x_set.shape[1] * 3 - horiz_shape = x_set.shape[2] * 3 - random_data = np.zeros((x_set.shape[0], vert_shape, horiz_shape)) # for mnist this was 84 x 84 - random_labels = np.zeros((y_set.shape[0], 1)) - if seed_value==-1: - np.random.seed() - else: - np.random.seed(seed_value) - for i in range(x_set.shape[0]): - img_items = np.random.choice(x_set.shape[0], size=9, replace=True) - data_set = x_set[img_items] - random_labels[i] = y_set[img_items[4]] - random_data[i] = tile_9_images(data_set) - return random_data, random_labels - -def generate_repeated_tiled_data(x_set, y_set): - # takes set of images and labels and returns a set of repeated tiled images and corresponding labels, no randomness - # the size of the output is 3x the size of the input, this essentially is a wrapper for np.tile - repeated_data = np.tile(x_set, (1, 3, 3)) - repeated_labels = y_set # the labels are just what they were - return repeated_data, repeated_labels - -def convolved_dataset(psf, random_tiled_data): - # takes a psf and a set of tiled images and returns a set of convolved images, convolved image size is 2n + 1? same size as the random data when it's cropped - # tile size is two images worth plus one extra index value - vert_shape = psf.shape[0] * 2 + 1 - horiz_shape = psf.shape[1] * 2 + 1 - psf_dataset = np.zeros((random_tiled_data.shape[0], vert_shape, horiz_shape)) # 57 x 57 for the case of mnist 28x28 images, 65 x 65 for the cifar 32 x 32 images - for i in tqdm(range(random_tiled_data.shape[0])): - psf_dataset[i] = scipy.signal.fftconvolve(psf, random_tiled_data[i], mode='valid') - return psf_dataset - -def compute_entropy(eigenvalues): - sum_log_evs = np.sum(np.log2(eigenvalues)) - D = eigenvalues.shape[0] - gaussian_entropy = 0.5 * (sum_log_evs + D * np.log2(2 * np.pi * np.e)) - return gaussian_entropy - -def add_shot_noise(photon_scaled_images, photon_fraction=None, photons_per_pixel=None, assume_noiseless=True, seed_value=-1): - #adapted from henry, also uses a seed though - if seed_value==-1: - np.random.seed() - else: - np.random.seed(seed_value) - - # check all pixels greater than 0 - if np.any(photon_scaled_images < 0): - #warning about negative - warnings.warn(f"Negative pixel values detected. Clipping to 0.") - photon_scaled_images[photon_scaled_images < 0] = 0 - if photons_per_pixel is not None: - if photons_per_pixel > np.mean(photon_scaled_images): - warnings.warn(f"photons_per_pixel is greater than actual photon count ({photons_per_pixel}). Clipping to {np.mean(photon_scaled_images)}") - photons_per_pixel = np.mean(photon_scaled_images) - photon_fraction = photons_per_pixel / np.mean(photon_scaled_images) - - if photon_fraction > 1: - warnings.warn(f"photon_fraction is greater than 1 ({photon_fraction}). Clipping to 1.") - photon_fraction = 1 - - if assume_noiseless: - additional_sd = np.sqrt(photon_fraction * photon_scaled_images) - if np.any(np.isnan(additional_sd)): - warnings.warn('There are nans here') - additional_sd[np.isnan(additional_sd)] = 0 - # something here goes weird for RML - # - #else: - # additional_sd = np.sqrt(photon_fraction * photon_scaled_images) - photon_fraction * np.sqrt(photon_scaled_images) - simulated_images = photon_scaled_images * photon_fraction + additional_sd * np.random.randn(*photon_scaled_images.shape) - positive = np.array(simulated_images) - positive[positive < 0] = 0 # cant have negative counts - return np.array(positive) - -def tf_cast(data): - # normalizes data, loads it to a tensorflow array of type float32 - return tf.cast(data / np.max(data), tf.float32) -def tf_labels(labels): - # loads labels to a tensorflow array of type int64 - return tf.cast(labels, tf.int64) - - - -def run_model_simple(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1): - if seed_value == -1: - seed_val = np.random.randint(10, 1000) - tfk.utils.set_random_seed(seed_val) - else: - tfk.utils.set_random_seed(seed_value) - - model = tfk.models.Sequential() - model.add(tfk.layers.Flatten()) - model.add(tfk.layers.Dense(256, activation='relu')) - model.add(tfk.layers.Dense(256, activation='relu')) - model.add(tfk.layers.Dense(10, activation='softmax')) - - model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) - - early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option - mode="min", patience=5, - restore_best_weights=True, verbose=1) - history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), batch_size=32, epochs=50, callbacks=[early_stop]) - test_loss, test_acc = model.evaluate(test_data, test_labels) - return history, model, test_loss, test_acc - -def run_model_cnn(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1): - # structure from https://www.kaggle.com/code/cdeotte/how-to-choose-cnn-architecture-mnist - if seed_value == -1: - seed_val = np.random.randint(10, 1000) - tfk.utils.set_random_seed(seed_val) - else: - tfk.utils.set_random_seed(seed_value) - - model = tfk.models.Sequential() - model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu', input_shape=(57, 57, 1))) #64 and 128 works very slightly better - model.add(tfk.layers.MaxPool2D()) - model.add(tfk.layers.Conv2D(128, kernel_size=5, padding='same', activation='relu')) - model.add(tfk.layers.MaxPool2D()) - #model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu')) - #model.add(tfk.layers.MaxPool2D(padding='same')) - model.add(tfk.layers.Flatten()) - - #model.add(tfk.layers.Dense(256, activation='relu')) - model.add(tfk.layers.Dense(128, activation='relu')) - model.add(tfk.layers.Dense(10, activation='softmax')) - - model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) - - early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option - mode="min", patience=5, - restore_best_weights=True, verbose=1) - history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=50, batch_size=32, callbacks=[early_stop]) #validation data is not test data - test_loss, test_acc = model.evaluate(test_data, test_labels) - return history, model, test_loss, test_acc - -def seeded_permutation(seed_value, n): - # given fixed seed returns permutation order - np.random.seed(seed_value) - permutation_order = np.random.permutation(n) - return permutation_order - -def segmented_indices(permutation_order, n, training_fraction, test_fraction): - #given permutation order returns indices for each of the three sets - training_indices = permutation_order[:int(training_fraction*n)] - test_indices = permutation_order[int(training_fraction*n):int((training_fraction+test_fraction)*n)] - validation_indices = permutation_order[int((training_fraction+test_fraction)*n):] - return training_indices, test_indices, validation_indices - -def permute_data(data, labels, seed_value, training_fraction=0.8, test_fraction=0.1): - #validation fraction is implicit, if including a validation set, expect to use the remaining fraction of the data - permutation_order = seeded_permutation(seed_value, data.shape[0]) - training_indices, test_indices, validation_indices = segmented_indices(permutation_order, data.shape[0], training_fraction, test_fraction) - - training_data = data[training_indices] - training_labels = labels[training_indices] - testing_data = data[test_indices] - testing_labels = labels[test_indices] - validation_data = data[validation_indices] - validation_labels = labels[validation_indices] - - return (training_data, training_labels), (testing_data, testing_labels), (validation_data, validation_labels) - -def add_gaussian_noise(data, noise_level, seed_value=-1): - if seed_value==-1: - np.random.seed() - else: - np.random.seed(seed_value) - return data + noise_level * np.random.randn(*data.shape) - -def confidence_bars(data_array, noise_length, confidence_interval=0.95): - # can also use confidence interval 0.9 or 0.99 if want slightly different bounds - error_lo = np.percentile(data_array, 100 * (1 - confidence_interval) / 2, axis=1) - error_hi = np.percentile(data_array, 100 * (1 - (1 - confidence_interval) / 2), axis=1) - mean = np.mean(data_array, axis=1) - assert len(error_lo) == len(mean) == len(error_hi) == noise_length - return error_lo, error_hi, mean - - -######### This function is very outdated, don't use it!! used to be called test_system use the ones below instead -######### -def test_system_old(noise_level, psf_name, model_name, seed_values, data, labels, training_fraction, testing_fraction, diffuser_region, phlat_region, psf, noise_type, rml_region): - # runs the model for the number of seeds given, returns the test accuracy for each seed - test_accuracy_list = [] - for seed_value in seed_values: - seed_value = int(seed_value) - tfk.backend.clear_session() - gc.collect() - tfk.utils.set_random_seed(seed_value) # set random seed out here too? - training, testing, validation = permute_data(data, labels, seed_value, training_fraction, testing_fraction) - x_train, y_train = training - x_test, y_test = testing - x_validation, y_validation = validation - - random_test_data, random_test_labels = generate_random_tiled_data(x_test, y_test, seed_value) - random_train_data, random_train_labels = generate_random_tiled_data(x_train, y_train, seed_value) - random_valid_data, random_valid_labels = generate_random_tiled_data(x_validation, y_validation, seed_value) - - if psf_name == 'uc': - test_data = random_test_data[:, 14:-13, 14:-13] - train_data = random_train_data[:, 14:-13, 14:-13] - valid_data = random_valid_data[:, 14:-13, 14:-13] - if psf_name == 'psf_4': - test_data = convolved_dataset(psf, random_test_data) - train_data = convolved_dataset(psf, random_train_data) - valid_data = convolved_dataset(psf, random_valid_data) - if psf_name == 'diffuser': - test_data = convolved_dataset(diffuser_region, random_test_data) - train_data = convolved_dataset(diffuser_region, random_train_data) - valid_data = convolved_dataset(diffuser_region, random_valid_data) - if psf_name == 'phlat': - test_data = convolved_dataset(phlat_region, random_test_data) - train_data = convolved_dataset(phlat_region, random_train_data) - valid_data = convolved_dataset(phlat_region, random_valid_data) - # 6/19/23 added RML option - if psf_name == 'rml': - test_data = convolved_dataset(rml_region, random_test_data) - train_data = convolved_dataset(rml_region, random_train_data) - valid_data = convolved_dataset(rml_region, random_valid_data) - - # address any tiny floating point negative values, which only occur in RML data - if np.any(test_data < 0): - #print('negative values in test data for {} psf'.format(psf_name)) - test_data[test_data < 0] = 0 - if np.any(train_data < 0): - #print('negative values in train data for {} psf'.format(psf_name)) - train_data[train_data < 0] = 0 - if np.any(valid_data < 0): - #print('negative values in valid data for {} psf'.format(psf_name)) - valid_data[valid_data < 0] = 0 - - - # additive gaussian noise, add noise after convolving, fixed 5/15/2023 - if noise_type == 'gaussian': - test_data = add_gaussian_noise(test_data, noise_level, seed_value) - train_data = add_gaussian_noise(train_data, noise_level, seed_value) - valid_data = add_gaussian_noise(valid_data, noise_level, seed_value) - if noise_type == 'poisson': - test_data = add_shot_noise(test_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) - train_data = add_shot_noise(train_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) - valid_data = add_shot_noise(valid_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) - - train_data, test_data, valid_data = tf_cast(train_data), tf_cast(test_data), tf_cast(valid_data) - random_train_labels, random_test_labels, random_valid_labels = tf_labels(random_train_labels), tf_labels(random_test_labels), tf_labels(random_valid_labels) - - if model_name == 'simple': - history, model, test_loss, test_acc = run_model_simple(train_data, random_train_labels, test_data, random_test_labels, valid_data, random_valid_labels, seed_value) - if model_name == 'cnn': - history, model, test_loss, test_acc = run_model_cnn(train_data, random_train_labels, test_data, random_test_labels, valid_data, random_valid_labels, seed_value) - test_accuracy_list.append(test_acc) - np.save('classification_results_rml_psf_619/test_accuracy_{}_noise_{}_{}_psf_{}_model.npy'.format(noise_level, noise_type, psf_name, model_name), test_accuracy_list) - - ###### CNN for 32x32 CIFAR10 images - # Originally written 11/14/2023, but then lost in a merge, recopied 1/14/2024 -def run_model_cnn_cifar(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=5): - # structure from https://www.kaggle.com/code/cdeotte/how-to-choose-cnn-architecture-mnist - # default architecture is 50 epochs and patience 5, but recently some need longer patience - if seed_value == -1: - seed_val = np.random.randint(10, 1000) - tfk.utils.set_random_seed(seed_val) - else: - tfk.utils.set_random_seed(seed_value) - model = tfk.models.Sequential() - model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu', input_shape=(65, 65, 1))) - model.add(tfk.layers.MaxPool2D()) - model.add(tfk.layers.Conv2D(128, kernel_size=5, padding='same', activation='relu')) - model.add(tfk.layers.MaxPool2D()) - #model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu')) - #model.add(tfk.layers.MaxPool2D(padding='same')) - model.add(tfk.layers.Flatten()) - - #model.add(tfk.layers.Dense(256, activation='relu')) - model.add(tfk.layers.Dense(128, activation='relu')) - model.add(tfk.layers.Dense(10, activation='softmax')) - - model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) - - early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option - mode="min", patience=patience, - restore_best_weights=True, verbose=1) - history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data - test_loss, test_acc = model.evaluate(test_data, test_labels) - return history, model, test_loss, test_acc - -def make_ttv_sets(data, labels, seed_value, training_fraction, testing_fraction): - training, testing, validation = permute_data(data, labels, seed_value, training_fraction, testing_fraction) - training_data, training_labels = training - testing_data, testing_labels = testing - validation_data, validation_labels = validation - training_data, testing_data, validation_data = tf_cast(training_data), tf_cast(testing_data), tf_cast(validation_data) - training_labels, testing_labels, validation_labels = tf_labels(training_labels), tf_labels(testing_labels), tf_labels(validation_labels) - return (training_data, training_labels), (testing_data, testing_labels), (validation_data, validation_labels) - -def run_network_cifar(data, labels, seed_value, training_fraction, testing_fraction, mode='cnn', max_epochs=50, patience=5): - # small modification to be able to run 32x32 image data - training, testing, validation = make_ttv_sets(data, labels, seed_value, training_fraction, testing_fraction) - if mode == 'cnn': - history, model, test_loss, test_acc = run_model_cnn_cifar(training[0], training[1], - testing[0], testing[1], - validation[0], validation[1], seed_value, max_epochs, patience) - elif mode == 'simple': - history, model, test_loss, test_acc = run_model_simple(training[0], training[1], - testing[0], testing[1], - validation[0], validation[1], seed_value) - elif mode == 'new_cnn': - history, model, test_loss, test_acc = current_testing_model(training[0], training[1], - testing[0], testing[1], - validation[0], validation[1], seed_value, max_epochs, patience) - elif mode == 'mom_cnn': - history, model, test_loss, test_acc = momentum_testing_model(training[0], training[1], - testing[0], testing[1], - validation[0], validation[1], seed_value, max_epochs, patience) - return history, model, test_loss, test_acc - - -def load_diffuser_psf(): - diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png') - diffuser_psf = diffuser_psf[:,:,1] - diffuser_resize = diffuser_psf[200:500, 250:550] - diffuser_resize = resize(diffuser_resize, (100, 100), anti_aliasing=True) #resize(diffuser_psf, (28, 28)) - diffuser_region = diffuser_resize[:28, :28] - diffuser_region /= np.sum(diffuser_region) - return diffuser_region - -def load_phlat_psf(): - phlat_psf = skimage.io.imread('psfs/phlat_psf.png') - phlat_psf = phlat_psf[900:2900, 1500:3500, 1] - phlat_psf = resize(phlat_psf, (200, 200), anti_aliasing=True) - phlat_region = phlat_psf[10:38, 20:48] - phlat_region /= np.sum(phlat_region) - return phlat_region - -def load_4_psf(): - psf = np.zeros((28, 28)) - psf[20,20] = 1 - psf[15, 10] = 1 - psf[5, 13] = 1 - psf[23, 6] = 1 - psf = scipy.ndimage.gaussian_filter(psf, sigma=1) - psf /= np.sum(psf) - return psf - -# 6/9/23 added rml option -def load_rml_psf(): - rml_psf = skimage.io.imread('psfs/psf_8holes.png') - rml_psf = rml_psf[1000:3000, 1500:3500] - rml_psf_resize = resize(rml_psf, (100, 100), anti_aliasing=True) - rml_psf_region = rml_psf_resize[40:100, :60] - rml_psf_region = resize(rml_psf_region, (28, 28), anti_aliasing=True) - rml_psf_region /= np.sum(rml_psf_region) - return rml_psf_region - -def load_rml_new_psf(): - rml_psf = skimage.io.imread('psfs/psf_8holes.png') - rml_psf = rml_psf[1000:3000, 1500:3500] - rml_psf_small = resize(rml_psf, (85, 85), anti_aliasing=True) - rml_psf_region = rml_psf_small[52:80, 10:38] - rml_psf_region /= np.sum(rml_psf_region) - return rml_psf_region - -def load_single_lens(): - one_lens = np.zeros((28, 28)) - one_lens[14, 14] = 1 - one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) - one_lens /= np.sum(one_lens) - return one_lens - -def load_two_lens(): - two_lens = np.zeros((28, 28)) - two_lens[10, 10] = 1 - two_lens[20, 20] = 1 - two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) - two_lens /= np.sum(two_lens) - return two_lens - -def load_three_lens(): - three_lens = np.zeros((28, 28)) - three_lens[8, 12] = 1 - three_lens[16, 20] = 1 - three_lens[20, 7] = 1 - three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) - three_lens /= np.sum(three_lens) - return three_lens - - -def load_single_lens_32(): - one_lens = np.zeros((32, 32)) - one_lens[16, 16] = 1 - one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) - one_lens /= np.sum(one_lens) - return one_lens - -def load_two_lens_32(): - two_lens = np.zeros((32, 32)) - two_lens[10, 10] = 1 - two_lens[21, 21] = 1 - two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) - two_lens /= np.sum(two_lens) - return two_lens - -def load_three_lens_32(): - three_lens = np.zeros((32, 32)) - three_lens[9, 12] = 1 - three_lens[17, 22] = 1 - three_lens[24, 8] = 1 - three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) - three_lens /= np.sum(three_lens) - return three_lens - -def load_four_lens_32(): - psf = np.zeros((32, 32)) - psf[22, 22] = 1 - psf[15, 10] = 1 - psf[5, 12] = 1 - psf[28, 8] = 1 - psf = scipy.ndimage.gaussian_filter(psf, sigma=1) # note that this one is sigma 1, for mnist it's sigma 0.8 - psf /= np.sum(psf) - return psf - -def load_diffuser_32(): - diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png') - diffuser_psf = diffuser_psf[:,:,1] - diffuser_resize = diffuser_psf[200:500, 250:550] - diffuser_resize = resize(diffuser_resize, (100, 100), anti_aliasing=True) #resize(diffuser_psf, (28, 28)) - diffuser_region = diffuser_resize[:32, :32] - diffuser_region /= np.sum(diffuser_region) - return diffuser_region - - - -### 10/15/2023: Make new versions of the model functions that train with Datasets - first attempt failed - -# lenses with centralized positions for use in task-specific estimations -def load_single_lens_uniform(size=32): - one_lens = np.zeros((size, size)) - one_lens[16, 16] = 1 - one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) - one_lens /= np.sum(one_lens) - return one_lens - -def load_two_lens_uniform(size=32): - two_lens = np.zeros((size, size)) - two_lens[16, 16] = 1 - two_lens[7, 9] = 1 - two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) - two_lens /= np.sum(two_lens) - return two_lens - -def load_three_lens_uniform(size=32): - three_lens = np.zeros((size, size)) - three_lens[16, 16] = 1 - three_lens[7, 9] = 1 - three_lens[23, 21] = 1 - three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) - three_lens /= np.sum(three_lens) - return three_lens - -def load_four_lens_uniform(size=32): - four_lens = np.zeros((size, size)) - four_lens[16, 16] = 1 - four_lens[7, 9] = 1 - four_lens[23, 21] = 1 - four_lens[8, 24] = 1 - four_lens = scipy.ndimage.gaussian_filter(four_lens, sigma=0.8) - four_lens /= np.sum(four_lens) - return four_lens -def load_five_lens_uniform(size=32): - five_lens = np.zeros((size, size)) - five_lens[16, 16] = 1 - five_lens[7, 9] = 1 - five_lens[23, 21] = 1 - five_lens[8, 24] = 1 - five_lens[21, 5] = 1 - five_lens = scipy.ndimage.gaussian_filter(five_lens, sigma=0.8) - five_lens /= np.sum(five_lens) - return five_lens - - - -## 01/24/2024 new CNN that's slightly deeper -def current_testing_model(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=20): - # structure from https://www.kaggle.com/code/amyjang/tensorflow-cifar10-cnn-tutorial - - if seed_value == -1: - seed_val = np.random.randint(10, 1000) - tfk.utils.set_random_seed(seed_val) - else: - tfk.utils.set_random_seed(seed_value) - - model = tf.keras.models.Sequential([ - tf.keras.layers.Conv2D(32, kernel_size=5, padding='same', input_shape=(65, 65, 1), activation='relu'), - tf.keras.layers.Conv2D(32, kernel_size=5, activation='relu'), - tf.keras.layers.MaxPooling2D(), - tf.keras.layers.Dropout(0.25), - - tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'), - tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'), - tf.keras.layers.MaxPooling2D(), - tf.keras.layers.Dropout(0.25), - - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(512, activation='relu'), - tf.keras.layers.Dropout(0.5), - tf.keras.layers.Dense(128, activation='relu'), - tf.keras.layers.Dense(10, activation='softmax'), - ]) - - model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) - - early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option - mode="min", patience=patience, - restore_best_weights=True, verbose=1) - print(model.optimizer.get_config()) - - history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data - test_loss, test_acc = model.evaluate(test_data, test_labels) - return history, model, test_loss, test_acc - - - - -## 01/24/2024 new CNN that's slightly deeper -def momentum_testing_model(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=20): - # structure from https://www.kaggle.com/code/amyjang/tensorflow-cifar10-cnn-tutorial - # includes nesterov momentum feature, rather than regular momentum - if seed_value == -1: - seed_val = np.random.randint(10, 1000) - tfk.utils.set_random_seed(seed_val) - else: - tfk.utils.set_random_seed(seed_value) - - model = tf.keras.models.Sequential([ - tf.keras.layers.Conv2D(32, kernel_size=5, padding='same', input_shape=(65, 65, 1), activation='relu'), - tf.keras.layers.Conv2D(32, kernel_size=5, activation='relu'), - tf.keras.layers.MaxPooling2D(), - tf.keras.layers.Dropout(0.25), - - tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'), - tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'), - tf.keras.layers.MaxPooling2D(), - tf.keras.layers.Dropout(0.25), - - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(512, activation='relu'), - tf.keras.layers.Dropout(0.5), - tf.keras.layers.Dense(128, activation='relu'), - tf.keras.layers.Dense(10, activation='softmax'), - ]) - - model.compile(optimizer=SGD(momentum=0.9, nesterov=True), loss='sparse_categorical_crossentropy', metrics=['accuracy']) - - early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option - mode="min", patience=patience, - restore_best_weights=True, verbose=1) - - print(model.optimizer.get_config()) - - history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data - test_loss, test_acc = model.evaluate(test_data, test_labels) - return history, model, test_loss, test_acc \ No newline at end of file From a952b060182e38d0d19029bc1aa89c7a938dcaaa Mon Sep 17 00:00:00 2001 From: Leyla Kabuli Date: Mon, 18 Nov 2024 13:06:28 -0800 Subject: [PATCH 3/5] fixed incorrect import on lensless imager helper function --- ...4_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py | 2 +- lensless_imager/11_14_2023_run_classifier_cifar10.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py b/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py index 7c07a7f..04ded8d 100644 --- a/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py +++ b/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py @@ -40,7 +40,7 @@ from jax.scipy.special import logsumexp import numpy as np -from EncodingInformation.lensless_imager.lensless_helpers import * +from lensless_helpers import * # %% from encoding_information.image_utils import add_noise, extract_patches diff --git a/lensless_imager/11_14_2023_run_classifier_cifar10.py b/lensless_imager/11_14_2023_run_classifier_cifar10.py index 5839bc3..7893d27 100644 --- a/lensless_imager/11_14_2023_run_classifier_cifar10.py +++ b/lensless_imager/11_14_2023_run_classifier_cifar10.py @@ -33,7 +33,7 @@ from jax.scipy.special import logsumexp import numpy as onp -from EncodingInformation.lensless_imager.lensless_helpers import * +from lensless_helpers import * from encoding_information.image_utils import add_noise # %% From dfdd16eb36d093b577af975794ff93ef2cc5b2fe Mon Sep 17 00:00:00 2001 From: Leyla Kabuli Date: Mon, 18 Nov 2024 13:14:57 -0800 Subject: [PATCH 4/5] removing old api lensless mi and deconvolution runs --- ..._17_2024_pixelcnn_cifar10_all_lenses.ipynb | 175 ----- ...eep_deconvolution_procedure_per_lens.ipynb | 335 --------- ...pixelcnn_cifar10_extra_photon_counts.ipynb | 185 ----- ...onvolution_per_lens_extra_photon_counts.py | 195 ----- ...lots_cifar10_all_systems_log_photons.ipynb | 708 ------------------ 5 files changed, 1598 deletions(-) delete mode 100644 lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb delete mode 100644 lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb delete mode 100644 lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb delete mode 100644 lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py delete mode 100644 lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb diff --git a/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb b/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb deleted file mode 100644 index fe1b6a9..0000000 --- a/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb +++ /dev/null @@ -1,175 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "2f0168a5", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import os\n", - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "import sys\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", - "sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments/')\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '2'\n", - "from encoding_information.gpu_utils import limit_gpu_memory_growth\n", - "limit_gpu_memory_growth()\n", - "\n", - "# import tensorflow_datasets as tfds # TFDS for MNIST #TODO INSTALL AGAIN LATER\n", - "#import tensorflow as tf # TensorFlow operations\n", - "\n", - "\n", - "\n", - "# from image_distribution_models import PixelCNN\n", - "\n", - "from cleanplots import *\n", - "import jax.numpy as np\n", - "from jax.scipy.special import logsumexp\n", - "import numpy as onp\n", - "\n", - "from lensless_helpers import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34552381", - "metadata": {}, - "outputs": [], - "source": [ - "from encoding_information.image_utils import add_noise, extract_patches\n", - "from encoding_information.models.gaussian_process import StationaryGaussianProcess\n", - "from encoding_information.models.pixel_cnn import PixelCNN\n", - "from encoding_information.information_estimation import estimate_mutual_information" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Sweep Photon Count and Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# load the PSFs\n", - "\n", - "one_psf = load_single_lens_uniform(32)\n", - "two_psf = load_two_lens_uniform(32)\n", - "three_psf = load_three_lens_uniform(32)\n", - "four_psf = load_four_lens_uniform(32)\n", - "five_psf = load_five_lens_uniform(32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# set seed values for reproducibility\n", - "seed_values_full = np.arange(1, 5)\n", - "\n", - "# set photon properties \n", - "bias = 10 # in photons\n", - "mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300]\n", - "\n", - "# set eligible psfs\n", - "\n", - "# psf_patterns = [None, one_psf, four_psf, diffuser_psf]\n", - "# psf_names = ['uc', 'one', 'four', 'diffuser']\n", - "psf_patterns = [one_psf, two_psf, three_psf, four_psf, five_psf]\n", - "psf_names = ['one', 'two', 'three', 'four', 'five']\n", - "# MI estimator parameters \n", - "patch_size = 32\n", - "num_patches = 10000\n", - "bs = 500\n", - "max_epochs = 50" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for photon_count in mean_photon_count_list:\n", - " for index, psf_pattern in enumerate(psf_patterns):\n", - " gaussian_mi_estimates = []\n", - " pixelcnn_mi_estimates = []\n", - " print('Mean photon count: {}, PSF: {}'.format(photon_count, psf_names[index]))\n", - " for seed_value in seed_values_full:\n", - " # load dataset\n", - " (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()\n", - " data = onp.concatenate((x_train, x_test), axis=0) # make one big glob of data\n", - " data = data.astype(np.float32)\n", - " data /= onp.mean(data)\n", - " data *= photon_count # convert to photons with mean photon_count\n", - " labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. \n", - " # for CIFAR 100, need to convert images to grayscale\n", - " if len(data.shape) == 4:\n", - " data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale\n", - " data = data.squeeze()\n", - " # make tiled data\n", - " random_data, random_labels = generate_random_tiled_data(data, labels, seed_value)\n", - " \n", - " if psf_pattern is None:\n", - " start_idx = data.shape[-1] // 2\n", - " end_idx = data.shape[-1] // 2 - 1 \n", - " psf_data = random_data[:, start_idx:-end_idx, start_idx:-end_idx]\n", - " else:\n", - " psf_data = convolved_dataset(psf_pattern, random_data)\n", - " # add bias to data \n", - " psf_data += bias\n", - " # make patches and add noise\n", - " psf_data_patch = extract_patches(psf_data, patch_size=patch_size, num_patches=num_patches, seed=seed_value)\n", - " psf_data_shot_patch = add_noise(psf_data_patch, seed=seed_value, batch_size=bs)\n", - " # compute gaussian MI estimate, use comparison clean images\n", - " mi_gaussian_psf = estimate_mutual_information(psf_data_shot_patch, clean_images=psf_data_patch, entropy_model='gaussian',\n", - " max_epochs=max_epochs, verbose=True)\n", - " # compute PixelCNN MI estimate, use comparison clean images\n", - " mi_pixelcnn_psf = estimate_mutual_information(psf_data_shot_patch, clean_images=psf_data_patch, entropy_model='pixelcnn', num_val_samples=1000,\n", - " max_epochs=max_epochs, do_lr_decay=True, verbose=True)\n", - " gaussian_mi_estimates.append(mi_gaussian_psf)\n", - " pixelcnn_mi_estimates.append(mi_pixelcnn_psf)\n", - " #np.save('cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(pixelcnn_mi_estimates))\n", - " #np.save('cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(gaussian_mi_estimates))\n", - " # save the results once the seeds are done, file includes photon count and psf name\n", - " #np.save('cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(pixelcnn_mi_estimates))\n", - " #np.save('cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(gaussian_mi_estimates))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "phenotypes", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb b/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb deleted file mode 100644 index 75b5199..0000000 --- a/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb +++ /dev/null @@ -1,335 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Sweeping Wiener Deconvolution, 01/24/2024\n", - "\n", - "When you randomly tile, you can make the problem much harder for deconvolution. Info is getting pushed out of the FOV and info is getting pulled into the FOV without knowing where it came from. Cropped convolution ends up being a compressive sensing problem. Instead, doing the reconstruction on the padded FOV including the center 32x32 region with a black border. \n", - "\n", - "There is no bias in this system. However, poisson noise is being added at each photon count." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import os\n", - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "import sys\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", - "sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments/')\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", - "from encoding_information.gpu_utils import limit_gpu_memory_growth\n", - "limit_gpu_memory_growth()\n", - "\n", - "# from image_distribution_models import PixelCNN\n", - "\n", - "from cleanplots import *\n", - "#import jax.numpy as np\n", - "from jax.scipy.special import logsumexp\n", - "import numpy as np\n", - "\n", - "from lensless_helpers import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from encoding_information.image_utils import add_noise, extract_patches\n", - "from encoding_information.models.gaussian_process import StationaryGaussianProcess\n", - "from encoding_information.models.pixel_cnn import PixelCNN\n", - "from encoding_information.information_estimation import estimate_mutual_information" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from skimage.restoration import wiener, unsupervised_wiener, richardson_lucy\n", - "import skimage.metrics as skm\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# load the PSFs\n", - "\n", - "diffuser_psf = load_diffuser_32()\n", - "one_psf = load_single_lens_uniform(32)\n", - "two_psf = load_two_lens_uniform(32)\n", - "three_psf = load_three_lens_uniform(32)\n", - "four_psf = load_four_lens_uniform(32)\n", - "five_psf = load_five_lens_uniform(32)\n", - "aperture_psf = np.copy(diffuser_psf)\n", - "aperture_psf[:5] = 0\n", - "aperture_psf[-5:] = 0\n", - "aperture_psf[:,:5] = 0\n", - "aperture_psf[:,-5:] = 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def compute_skm_metrics(gt, recon):\n", - " # takes in already normalized gt\n", - " mse = skm.mean_squared_error(gt, recon)\n", - " psnr = skm.peak_signal_noise_ratio(gt, recon)\n", - " nmse = skm.normalized_root_mse(gt, recon)\n", - " ssim = skm.structural_similarity(gt, recon, data_range=1)\n", - " return mse, psnr, nmse, ssim" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# set seed values for reproducibility\n", - "seed_values_full = np.arange(1, 4)\n", - "\n", - "# set photon properties \n", - "mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300]\n", - "\n", - "# set eligible psfs\n", - "\n", - "psf_patterns = [None, one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf]\n", - "psf_names = ['uc', 'one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture']\n", - "\n", - "# MI estimator parameters \n", - "patch_size = 32\n", - "num_patches = 10000\n", - "bs = 500\n", - "max_epochs = 50" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "psf_patterns_use = [one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf]\n", - "psf_names_use = ['one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture']\n", - "\n", - "mean_photon_count_list = [300, 250, 200, 150, 100, 80, 60, 40, 20]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for photon_count in mean_photon_count_list:\n", - " for psf_idx, psf_use in enumerate(psf_patterns_use):\n", - " print('PSF: {}, Photon Count: {}'.format(psf_names_use[psf_idx], photon_count))\n", - " seed_value = 1\n", - " # make the data and scale by the photon count \n", - " (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()\n", - " data = np.concatenate((x_train, x_test), axis=0) # make one big glob of data\n", - " data = data.astype(np.float64)\n", - " data /= np.mean(data)\n", - " data *= photon_count # convert to photons with mean value photon_count\n", - " max_val = np.max(data)\n", - " labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. \n", - " # for CIFAR 100, need to convert images to grayscale\n", - " if len(data.shape) == 4:\n", - " data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale\n", - " data = data.squeeze()\n", - " # zero pad data to be 96 x 96\n", - " data_padded = np.zeros((data.shape[0], 96, 96))\n", - " data_padded[:, 32:64, 32:64] = data\n", - "\n", - " convolved_data = convolved_dataset(psf_use, data_padded)\n", - " convolved_data_noise = add_noise(convolved_data)\n", - " # output of this noisy data is a jax array of float32, correct to regular numpy and float64\n", - " convolved_data_noise = np.array(convolved_data_noise).astype(np.float64)\n", - "\n", - " mse_psf = []\n", - " psnr_psf = []\n", - " for i in range(convolved_data_noise.shape[0]):\n", - " recon, _ = unsupervised_wiener(convolved_data_noise[i] / max_val, psf_use)\n", - " recon = recon[17:49, 17:49] #this is the crop window to look at\n", - " mse = skm.mean_squared_error(data[i] / max_val, recon)\n", - " psnr = skm.peak_signal_noise_ratio(data[i] / max_val, recon)\n", - " mse_psf.append(mse)\n", - " psnr_psf.append(psnr)\n", - " print('PSF: {}, Mean MSE: {}, Mean PSNR: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf)))\n", - " #np.save('unsupervised_wiener_deconvolution/recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf])\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Repeating Wiener Deconvolution including fixed seed=10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "psf_patterns_use = [one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf]\n", - "psf_names_use = ['one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture']\n", - "\n", - "mean_photon_count_list = [300, 250, 200, 150, 100, 80, 60, 40, 20]\n", - "\n", - "seed_value = 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for photon_count in mean_photon_count_list:\n", - " for psf_idx, psf_use in enumerate(psf_patterns_use):\n", - " print('PSF: {}, Photon Count: {}'.format(psf_names_use[psf_idx], photon_count))\n", - " # make the data and scale by the photon count \n", - " (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()\n", - " data = np.concatenate((x_train, x_test), axis=0) # make one big glob of data\n", - " data = data.astype(np.float64)\n", - " data /= np.mean(data)\n", - " data *= photon_count # convert to photons with mean value photon_count\n", - " max_val = np.max(data)\n", - " labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. \n", - " # for CIFAR 100, need to convert images to grayscale\n", - " if len(data.shape) == 4:\n", - " data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale\n", - " data = data.squeeze()\n", - " # zero pad data to be 96 x 96\n", - " data_padded = np.zeros((data.shape[0], 96, 96))\n", - " data_padded[:, 32:64, 32:64] = data\n", - "\n", - " convolved_data = convolved_dataset(psf_use, data_padded)\n", - " convolved_data_noise = add_noise(convolved_data, seed=seed_value)\n", - " # output of this noisy data is a jax array of float32, correct to regular numpy and float64\n", - " convolved_data_noise = np.array(convolved_data_noise).astype(np.float64)\n", - "\n", - " mse_psf = []\n", - " psnr_psf = []\n", - " for i in range(convolved_data_noise.shape[0]):\n", - " recon, _ = unsupervised_wiener(convolved_data_noise[i] / max_val, psf_use)\n", - " recon = recon[17:49, 17:49] #this is the crop window to look at\n", - " mse = skm.mean_squared_error(data[i] / max_val, recon)\n", - " psnr = skm.peak_signal_noise_ratio(data[i] / max_val, recon)\n", - " mse_psf.append(mse)\n", - " psnr_psf.append(psnr)\n", - " print('PSF: {}, Mean MSE: {}, Mean PSNR: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf)))\n", - " #np.save('unsupervised_wiener_deconvolution_fixed_seed/recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Archive: Detour to figure out jax types" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "type(convolved_data_noise)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "convolved_data = convolved_dataset(psf_use, data_padded)\n", - "convolved_data_noise = add_noise(convolved_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(type(convolved_data), convolved_data.dtype)\n", - "print(type(convolved_data_noise), convolved_data_noise.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "np" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "convolved_data_noise_test = np.array(convolved_data_noise).astype(np.float64)\n", - "print(type(convolved_data_noise_test))\n", - "recon, _ = unsupervised_wiener(convolved_data_noise_test[0] / max_val, psf_use) #TODO change to convolved_data_noise\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "convolved_data_noise_test = convolved_data_noise.astype(np.float64)\n", - "recon, _ = unsupervised_wiener(convolved_data_noise_test[0] / max_val, psf_use)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "info_jax_flax_23", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb b/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb deleted file mode 100644 index efa614b..0000000 --- a/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb +++ /dev/null @@ -1,185 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "2f0168a5", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import os\n", - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "import sys\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", - "sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments/')\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '2'\n", - "from encoding_information.gpu_utils import limit_gpu_memory_growth\n", - "limit_gpu_memory_growth()\n", - "\n", - "# import tensorflow_datasets as tfds # TFDS for MNIST #TODO INSTALL AGAIN LATER\n", - "#import tensorflow as tf # TensorFlow operations\n", - "\n", - "\n", - "\n", - "# from image_distribution_models import PixelCNN\n", - "\n", - "from cleanplots import *\n", - "import jax.numpy as np\n", - "from jax.scipy.special import logsumexp\n", - "import numpy as onp\n", - "\n", - "from lensless_helpers import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34552381", - "metadata": {}, - "outputs": [], - "source": [ - "from encoding_information.image_utils import add_noise, extract_patches\n", - "from encoding_information.models.gaussian_process import StationaryGaussianProcess\n", - "from encoding_information.models.pixel_cnn import PixelCNN\n", - "from encoding_information.information_estimation import estimate_mutual_information" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Sweep Photon Count and Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "48df0226", - "metadata": {}, - "outputs": [], - "source": [ - "diffuser_psf = load_diffuser_32()\n", - "one_psf = load_single_lens_uniform(32)\n", - "two_psf = load_two_lens_uniform(32)\n", - "three_psf = load_three_lens_uniform(32)\n", - "four_psf = load_four_lens_uniform(32)\n", - "five_psf = load_five_lens_uniform(32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# set seed values for reproducibility\n", - "seed_values_full = np.arange(1, 5)\n", - "\n", - "# set photon properties \n", - "bias = 10 # in photons\n", - "#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300]\n", - "mean_photon_count_list = [160, 320]\n", - "\n", - "# set eligible psfs\n", - "\n", - "# psf_patterns = [None, one_psf, four_psf, diffuser_psf]\n", - "# psf_names = ['uc', 'one', 'four', 'diffuser']\n", - "psf_patterns = [one_psf, four_psf, diffuser_psf]\n", - "psf_names = ['one', 'four', 'diffuser']\n", - "\n", - "# MI estimator parameters \n", - "patch_size = 32\n", - "num_patches = 10000\n", - "bs = 500\n", - "max_epochs = 50" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for photon_count in mean_photon_count_list:\n", - " for index, psf_pattern in enumerate(psf_patterns):\n", - " gaussian_mi_estimates = []\n", - " pixelcnn_mi_estimates = []\n", - " print('Mean photon count: {}, PSF: {}'.format(photon_count, psf_names[index]))\n", - " for seed_value in seed_values_full:\n", - " # load dataset\n", - " (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()\n", - " data = onp.concatenate((x_train, x_test), axis=0) # make one big glob of data\n", - " data = data.astype(np.float32)\n", - " data /= onp.mean(data)\n", - " data *= photon_count # convert to photons with mean value of photon_count\n", - " labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. \n", - " # for CIFAR 100, need to convert images to grayscale\n", - " if len(data.shape) == 4:\n", - " data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale\n", - " data = data.squeeze()\n", - " # make tiled data\n", - " random_data, random_labels = generate_random_tiled_data(data, labels, seed_value)\n", - " \n", - " if psf_pattern is None:\n", - " start_idx = data.shape[-1] // 2\n", - " end_idx = data.shape[-1] // 2 - 1 \n", - " psf_data = random_data[:, start_idx:-end_idx, start_idx:-end_idx]\n", - " else:\n", - " psf_data = convolved_dataset(psf_pattern, random_data)\n", - " # add small bias to data \n", - " psf_data += bias\n", - " # make patches and add noise\n", - " psf_data_patch = extract_patches(psf_data, patch_size=patch_size, num_patches=num_patches, seed=seed_value)\n", - " psf_data_shot_patch = add_noise(psf_data_patch, seed=seed_value, batch_size=bs)\n", - " # compute gaussian MI estimate, use comparison clean images\n", - " mi_gaussian_psf = estimate_mutual_information(psf_data_shot_patch, clean_images=psf_data_patch, entropy_model='gaussian',\n", - " max_epochs=max_epochs, verbose=True)\n", - " # compute PixelCNN MI estimate, use comparison clean images\n", - " mi_pixelcnn_psf = estimate_mutual_information(psf_data_shot_patch, clean_images=psf_data_patch, entropy_model='pixelcnn', num_val_samples=1000,\n", - " max_epochs=max_epochs, do_lr_decay=True, verbose=True)\n", - " gaussian_mi_estimates.append(mi_gaussian_psf)\n", - " pixelcnn_mi_estimates.append(mi_pixelcnn_psf)\n", - " #np.save('cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(pixelcnn_mi_estimates))\n", - " #np.save('cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(gaussian_mi_estimates))\n", - " # save the results once the seeds are done, file includes photon count and psf name\n", - " #np.save('cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(pixelcnn_mi_estimates))\n", - " #np.save('cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(gaussian_mi_estimates))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f667a120", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "phenotypes", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py b/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py deleted file mode 100644 index 04ded8d..0000000 --- a/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py +++ /dev/null @@ -1,195 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.15.2 -# kernelspec: -# display_name: info_jax_flax_23 -# language: python -# name: python3 -# --- - -# %% [markdown] -# ## Sweeping non-unsupervised Wiener Deconvolution with hand-tuned parameter, 01/29/2024 -# -# Using a fixed seed (10) for consistency. - -# %% -# %load_ext autoreload -# %autoreload 2 - -import os -from jax import config -config.update("jax_enable_x64", True) -import sys -sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/') -sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments/') - -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["CUDA_VISIBLE_DEVICES"] = '3' -from encoding_information.gpu_utils import limit_gpu_memory_growth -limit_gpu_memory_growth() - -# from image_distribution_models import PixelCNN - -from cleanplots import * -#import jax.numpy as np -from jax.scipy.special import logsumexp -import numpy as np - -from lensless_helpers import * - -# %% -from encoding_information.image_utils import add_noise, extract_patches -from encoding_information.models.gaussian_process import StationaryGaussianProcess -from encoding_information.models.pixel_cnn import PixelCNN -from encoding_information.information_estimation import estimate_mutual_information - -# %% -from skimage.restoration import wiener, unsupervised_wiener, richardson_lucy -import skimage.metrics as skm - - -# %% -# load the PSFs - -diffuser_psf = load_diffuser_32() -one_psf = load_single_lens_uniform(32) -two_psf = load_two_lens_uniform(32) -three_psf = load_three_lens_uniform(32) -four_psf = load_four_lens_uniform(32) -five_psf = load_five_lens_uniform(32) -aperture_psf = np.copy(diffuser_psf) -aperture_psf[:5] = 0 -aperture_psf[-5:] = 0 -aperture_psf[:,:5] = 0 -aperture_psf[:,-5:] = 0 - - -# %% -def compute_skm_metrics(gt, recon): - # takes in already normalized gt - mse = skm.mean_squared_error(gt, recon) - psnr = skm.peak_signal_noise_ratio(gt, recon) - nmse = skm.normalized_root_mse(gt, recon) - ssim = skm.structural_similarity(gt, recon, data_range=1) - return mse, psnr, nmse, ssim - - -# %% -# set seed values for reproducibility -seed_values_full = np.arange(1, 4) - -# set photon properties -#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300] -mean_photon_count_list = [160, 320] - -# set eligible psfs - -psf_patterns = [None, one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf] -psf_names = ['uc', 'one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture'] - -# MI estimator parameters -patch_size = 32 -num_patches = 10000 -bs = 500 -max_epochs = 50 - -# %% -reg_value_best = 10**-2 -print(reg_value_best) - -# %% [markdown] -# ## Regular Wiener Deconvolution including fixed seed 10 - -# %% -psf_patterns_use = [one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf] -psf_names_use = ['one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture'] - -#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300] -mean_photon_count_list = [160, 320] - -seed_value = 10 - - -for photon_count in mean_photon_count_list: - for psf_idx, psf_use in enumerate(psf_patterns_use): - print('PSF: {}, Photon Count: {}'.format(psf_names_use[psf_idx], photon_count)) - # make the data and scale by the photon count - (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() - data = np.concatenate((x_train, x_test), axis=0) # make one big glob of data - data = data.astype(np.float64) - data /= np.mean(data) - data *= photon_count # convert to photons with mean value of photon_count - max_val = np.max(data) - labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. - # for CIFAR 100, need to convert images to grayscale - if len(data.shape) == 4: - data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale - data = data.squeeze() - # zero pad data to be 96 x 96 - data_padded = np.zeros((data.shape[0], 96, 96)) - data_padded[:, 32:64, 32:64] = data - - convolved_data = convolved_dataset(psf_use, data_padded) - convolved_data_noise = add_noise(convolved_data, seed=seed_value) - # output of this noisy data is a jax array of float32, correct to regular numpy and float64 - convolved_data_noise = np.array(convolved_data_noise).astype(np.float64) - - mse_psf = [] - psnr_psf = [] - for i in range(convolved_data_noise.shape[0]): - recon, _ = unsupervised_wiener(convolved_data_noise[i] / max_val, psf_use) - recon = recon[17:49, 17:49] #this is the crop window to look at - mse = skm.mean_squared_error(data[i] / max_val, recon) - psnr = skm.peak_signal_noise_ratio(data[i] / max_val, recon) - mse_psf.append(mse) - psnr_psf.append(psnr) - print('PSF: {}, Mean MSE: {}, Mean PSNR: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf))) - #np.save('unsupervised_wiener_deconvolution_fixed_seed/recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf]) - - - -# %% -for photon_count in mean_photon_count_list: - for psf_idx, psf_use in enumerate(psf_patterns_use): - print('PSF: {}, Photon Count: {}'.format(psf_names_use[psf_idx], photon_count)) - # make the data and scale by the photon count - (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() - data = np.concatenate((x_train, x_test), axis=0) # make one big glob of data - data = data.astype(np.float64) - data /= np.mean(data) - data *= photon_count # convert to photons with mean value of photon_count - max_val = np.max(data) - labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. - # for CIFAR 100, need to convert images to grayscale - if len(data.shape) == 4: - data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale - data = data.squeeze() - # zero pad data to be 96 x 96 - data_padded = np.zeros((data.shape[0], 96, 96)) - data_padded[:, 32:64, 32:64] = data - - convolved_data = convolved_dataset(psf_use, data_padded) - convolved_data_noise = add_noise(convolved_data, seed=seed_value) - # output of this noisy data is a jax array of float32, correct to regular numpy and float64 - convolved_data_noise = np.array(convolved_data_noise).astype(np.float64) - - mse_psf = [] - psnr_psf = [] - for i in range(convolved_data_noise.shape[0]): - recon = wiener(convolved_data_noise[i] / max_val, psf_use, reg_value_best) - recon = recon[17:49, 17:49] #this is the crop window to look at - mse = skm.mean_squared_error(data[i] / max_val, recon) - psnr = skm.peak_signal_noise_ratio(data[i] / max_val, recon) - mse_psf.append(mse) - psnr_psf.append(psnr) - print('PSF: {}, Mean MSE: {}, Mean PSNR: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf))) - #np.save('regular_wiener_deconvolution_fixed_seed/recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf]) - -# %% - - diff --git a/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb b/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb deleted file mode 100644 index 3a81d83..0000000 --- a/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb +++ /dev/null @@ -1,708 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Make the plot for MI and deconvolution relationship, 01/29/2024" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload \n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "import numpy as np\n", - "\n", - "import sys\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments')\n", - "from lensless_helpers import *\n", - "import os\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", - "print(os.environ.get('PYTHONPATH'))\n", - "from cleanplots import * " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "seed_value = 10\n", - "\n", - "# set photon properties \n", - "bias = 10 # in photons\n", - "mean_photon_count_list = [20, 40, 60, 80, 100, 150, 160, 200, 250, 300, 320]\n", - "max_photon_count = mean_photon_count_list[-1]\n", - "\n", - "# set eligible psfs\n", - "\n", - "psf_names = ['one', 'four', 'diffuser'] # later make it all of them, but haven't gotten diffuser and aperture yet\n", - "\n", - "# MI estimator parameters \n", - "patch_size = 32\n", - "num_patches = 10000\n", - "bs = 500" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load MI data and make plots of it\n", - "Using updated MI data from 01/17/2024 which is run for the uniform data\n", - "\n", - "The plot has essentially invisible error bars. No more outlier issues" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from cleanplots import *\n", - "get_color_cycle()[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mi_folder = ''" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Minimum plot with no error bars" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gaussian_mi_estimates_across_psfs = [] # only keeps the minimum values, no outliers\n", - "pixelcnn_mi_estimates_across_psfs = [] # only keeps the minimum values, no outliers\n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", - "for psf_name in psf_names:\n", - " gaussian_across_photons = [] \n", - " pixelcnn_across_photons = []\n", - " for photon_count in mean_photon_count_list:\n", - " #gaussian_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " pixelcnn_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " #gaussian_across_photons.append(gaussian_mi_estimate)\n", - " pixelcnn_across_photons.append(pixelcnn_mi_estimate)\n", - " assert pixelcnn_mi_estimate.shape[0] == 4\n", - " #gaussian_mins = np.min(gaussian_across_photons, axis=1)\n", - " pixelcnn_mins = np.min(pixelcnn_across_photons, axis=1)\n", - " ax.plot(mean_photon_count_list, gaussian_mins, '-', label='Gaussian {}'.format(psf_name))\n", - " ax.plot(mean_photon_count_list, pixelcnn_mins, '-', label='PixelCNN {}'.format(psf_name))\n", - " gaussian_mi_estimates_across_psfs.append(gaussian_mins) # only keep mean dataset for use\n", - " pixelcnn_mi_estimates_across_psfs.append(pixelcnn_mins) # only keep mean datas\n", - "plt.legend()\n", - "plt.title(\"Gaussian vs. PixelCNN MI Estimates Across Photon Count, CIFAR10, 4 Seeds, Minimums\")\n", - "plt.ylabel('Estimated Mutual Information')\n", - "plt.xlabel('Mean Photon Count')\n", - "\n", - "gaussian_mi_estimates_across_psfs = np.array(gaussian_mi_estimates_across_psfs)\n", - "pixelcnn_mi_estimates_across_psfs = np.array(pixelcnn_mi_estimates_across_psfs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "psf_names_verbose = ['One Lens', 'Two Lens', 'Three Lens', 'Four Lens', 'Five Lens', 'Diffuser', 'Aperture']\n", - "plt.figure(figsize=(6, 5))\n", - "ax = plt.axes()\n", - "for i, modality in enumerate(psf_names_verbose):\n", - " #plt.plot(mean_photon_count_list, gaussian_mi_estimates_across_psfs[i], label = '{} Gaussian'.format(modality), color = get_color_cycle()[i], linestyle='--')\n", - " plt.plot(mean_photon_count_list, pixelcnn_mi_estimates_across_psfs[i], label = '{}'.format(modality), color = get_color_cycle()[i-1]) # manual color correct\n", - "plt.legend()\n", - "plt.xlabel('Mean Photon Count')\n", - "plt.ylabel(\"Mutual Information (bits per pixel)\")\n", - "#plt.title('Estimated Mutual Information vs. Mean Photon Count, CIFAR10')\n", - "clear_spines(ax)\n", - "#plt.savefig('mi_vs_photon_count.pdf', bbox_inches='tight', transparent=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Mean plot with error bars included" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", - "for psf_name in psf_names:\n", - " gaussian_across_photons = [] \n", - " pixelcnn_across_photons = []\n", - " for photon_count in mean_photon_count_list:\n", - " #gaussian_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " pixelcnn_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " if np.max(pixelcnn_mi_estimate) / np.min(pixelcnn_mi_estimate) > 2:\n", - " pixelcnn_mi_estimate[pixelcnn_mi_estimate > 2 * np.min(pixelcnn_mi_estimate)] = np.min(pixelcnn_mi_estimate)\n", - " #gaussian_across_photons.append(gaussian_mi_estimate)\n", - " pixelcnn_across_photons.append(pixelcnn_mi_estimate)\n", - " #error_lo, error_hi, mean = confidence_bars(gaussian_across_photons, 9)\n", - " error_lo_2, error_hi_2, mean_2 = confidence_bars(pixelcnn_across_photons, 11)\n", - " #ax.plot(mean_photon_count_list, mean, '-', label='Gaussian {}'.format(psf_name))\n", - " ax.plot(mean_photon_count_list, mean_2, '-', label='PixelCNN {}'.format(psf_name))\n", - " #ax.fill_between(mean_photon_count_list, error_lo, error_hi, alpha=0.4)\n", - " ax.fill_between(mean_photon_count_list, error_lo_2, error_hi_2, alpha=0.4)\n", - "plt.legend()\n", - "plt.title(\"Gaussian vs. PixelCNN MI Estimates Across Photon Count, CIFAR10, 4 Seeds, Means, Outliers Removed\")\n", - "plt.ylabel('Estimated Mutual Information')\n", - "plt.xlabel('Mean Photon Count')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load deconvolution data and make plots of it\n", - "Use means" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "deconvolution_folder = 'unsupervised_wiener_deconvolution_fixed_seed/'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mse_across_psfs = [] #5 x 9 x 1 array, 5 psfs, 9 photon counts, one value on each \n", - "psnr_across_psfs = [] #5 x 9 x 1 array, 5 psfs, 9 photon counts, one value on each\n", - "mse_lists_across_psfs = []\n", - "psnr_lists_across_psfs = []\n", - "for psf_name in psf_names:\n", - " mse_across_photons = []\n", - " psnr_across_photons = []\n", - " mse_lists_across_photons = []\n", - " psnr_lists_across_photons = []\n", - " for photon_count in mean_photon_count_list:\n", - " mse_list, psnr_list = np.load(deconvolution_folder + 'recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " mse_list = np.array(mse_list)\n", - " psnr_list = np.array(psnr_list)\n", - " mean_mse = np.mean(mse_list)\n", - " mean_psnr = np.mean(psnr_list)\n", - " mse_across_photons.append(mean_mse)\n", - " psnr_across_photons.append(mean_psnr)\n", - " mse_lists_across_photons.append(mse_list)\n", - " psnr_lists_across_photons.append(psnr_list)\n", - " mse_across_psfs.append(mse_across_photons)\n", - " psnr_across_psfs.append(psnr_across_photons)\n", - " mse_lists_across_psfs.append(mse_lists_across_photons)\n", - " psnr_lists_across_psfs.append(psnr_lists_across_photons)\n", - "mse_across_psfs = np.array(mse_across_psfs)\n", - "psnr_across_psfs = np.array(psnr_across_psfs)\n", - "mse_lists_across_psfs = np.array(mse_lists_across_psfs)\n", - "psnr_lists_across_psfs = np.array(psnr_lists_across_psfs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for idx, psf_name in enumerate(psf_names):\n", - " plt.plot(mean_photon_count_list, mse_across_psfs[idx], label='{}'.format(psf_name))\n", - "plt.legend()\n", - "plt.title(\"Deconvolution MSE vs. Mean Photon Count, CIFAR10\")\n", - "plt.ylabel('Mean Squared Error')\n", - "plt.xlabel('Mean Photon Count')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Make figures, include classifier error bars" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def marker_for_psf(psf_name):\n", - " if psf_name =='one':\n", - " marker = 'o'\n", - " elif psf_name == 'four':\n", - " marker = 's' \n", - " elif psf_name == 'diffuser':\n", - " marker = '*'\n", - " elif psf_name == 'uc':\n", - " marker = 'x'\n", - " elif psf_name =='two':\n", - " marker = 'd'\n", - " elif psf_name == 'three':\n", - " marker = 'v'\n", - " elif psf_name == 'five':\n", - " marker = 'p'\n", - " elif psf_name == 'aperture':\n", - " marker = 'P'\n", - " return marker" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Choose a base colormap\n", - "base_colormap = plt.cm.get_cmap('inferno')\n", - "# Define the start and end points--used so that high values aren't too light against white background\n", - "start, end = 0, 0.88 # making end point 0.8\n", - "from matplotlib.colors import LinearSegmentedColormap\n", - "# Create a new colormap from the portion of the original colormap\n", - "colormap = LinearSegmentedColormap.from_list(\n", - " 'trunc({n},{a:.2f},{b:.2f})'.format(n=base_colormap.name, a=start, b=end),\n", - " base_colormap(np.linspace(start, end, 256))\n", - ")\n", - "\n", - "min_photons_per_pixel = min(mean_photon_count_list)\n", - "max_photons_per_pixel = max(mean_photon_count_list)\n", - "\n", - "min_log_photons = np.log(min_photons_per_pixel)\n", - "max_log_photons = np.log(max_photons_per_pixel)\n", - "\n", - "def color_for_photon_level(photons_per_pixel):\n", - " log_photons = np.log(photons_per_pixel)\n", - " return colormap((log_photons - min_log_photons) / (max_log_photons - min_log_photons) )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Update parameters in below block to display the things you want to display, then run the block after to make the figure" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "estimator_type = 1 # 0 for gaussian, 1 for pixelcnn\n", - "metric_type = 1\n", - "valid_psfs = [0, 1, 2]\n", - "valid_photon_counts = [20, 40, 80, 150, 160, 300, 320]\n", - "print(psf_names)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "confidence_level = 0.9 \n", - "# using min-valued MI estimates \n", - "mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]\n", - "deconv_estimate_lists = [mse_across_psfs, psnr_across_psfs]\n", - "metric_names = ['MSE', 'PSNR']\n", - "# classifier array is classifier_all_trials_across_psfs, 4x9x10 array. 4 psfs, 9 photon counts, 10 trials on each one \n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n", - "\n", - "mi_list_use = mi_estimate_lists[estimator_type] # use pixelcnn or gaussian, choose pixelcnn \n", - "deconv_list_use = deconv_estimate_lists[metric_type] # use mse or psnr, choose psnr\n", - "metric_name = metric_names[metric_type]\n", - "\n", - "for psf_idx, psf_name in enumerate(psf_names):\n", - " if psf_idx in valid_psfs:\n", - " mi_means_across_photons = [] # track mean MI values to make trendline \n", - " deconv_means_across_photons = [] # track mean MI values to make trendline\n", - " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", - " if photon_count in valid_photon_counts:\n", - " # load mean values and colors to plot \n", - " color = color_for_photon_level(photon_count)\n", - " mi_value = mi_list_use[psf_idx][photon_idx] # only use an MI value if the psf is valid, correctly indexed \n", - " deconv_value = deconv_list_use[psf_idx][photon_idx]\n", - " ax.scatter(mi_value, deconv_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", - " # add to lists to track later \n", - " mi_means_across_photons.append(mi_value)\n", - " deconv_means_across_photons.append(deconv_value)\n", - " mi_means_across_photons = np.array(mi_means_across_photons)\n", - " deconv_means_across_photons = np.array(deconv_means_across_photons)\n", - " ax.plot(mi_means_across_photons, deconv_means_across_photons, '--', color='grey', alpha=1, linewidth=2)\n", - "\n", - "ax.set_xlabel('Mutual Information (bits per pixel)')\n", - "ax.set_ylabel('Deconvolution Performance, {}'.format(metric_name))\n", - "clear_spines(ax)\n", - "\n", - "\n", - "# legend\n", - "# ax.scatter([], [], color='k', marker='x', label='No PSF')\n", - "ax.scatter([], [], color='k', marker='o', label='One Lens')\n", - "# ax.scatter([], [], color='k', marker='d', label='Two Lens')\n", - "# ax.scatter([], [], color='k', marker='v', label='Three Lens')\n", - "ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", - "# ax.scatter([], [], color='k', marker='p', label='Five Lens')\n", - "ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", - "# ax.scatter([], [], color='k', marker='P', label='Aperture')\n", - "\n", - "ax.legend(loc='lower right', frameon=True)\n", - "ax.set_xlim([0, None])\n", - "\n", - "\n", - "\n", - "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", - "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", - "sm.set_array([])\n", - "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", - "# set tick labels\n", - "cbar.ax.set_yticklabels(valid_photon_counts)\n", - "\n", - "\n", - "cbar.set_label('Mean Photon Count')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "estimator_type = 1 # 0 for gaussian, 1 for pixelcnn\n", - "metric_type = 0 # 0 for MSE, 1 for PSNR\n", - "valid_psfs = [0, 1, 2, 3, 4, 5, 6]\n", - "valid_photon_counts = [20, 40, 80, 150, 160, 300, 320]\n", - "print(psf_names)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "confidence_level = 0.9 \n", - "# using min-valued MI estimates \n", - "mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]\n", - "deconv_estimate_lists = [mse_across_psfs, psnr_across_psfs]\n", - "metric_names = ['MSE', 'PSNR']\n", - "# classifier array is classifier_all_trials_across_psfs, 4x9x10 array. 4 psfs, 9 photon counts, 10 trials on each one \n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n", - "\n", - "mi_list_use = mi_estimate_lists[estimator_type] # use pixelcnn or gaussian, choose pixelcnn \n", - "deconv_list_use = deconv_estimate_lists[metric_type] # use mse or psnr, choose psnr\n", - "metric_name = metric_names[metric_type]\n", - "\n", - "for psf_idx, psf_name in enumerate(psf_names):\n", - " if psf_idx in valid_psfs:\n", - " mi_means_across_photons = [] # track mean MI values to make trendline \n", - " deconv_means_across_photons = [] # track mean MI values to make trendline\n", - " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", - " if photon_count in valid_photon_counts:\n", - " # load mean values and colors to plot \n", - " color = color_for_photon_level(photon_count)\n", - " mi_value = mi_list_use[psf_idx][photon_idx] # only use an MI value if the psf is valid, correctly indexed \n", - " deconv_value = deconv_list_use[psf_idx][photon_idx]\n", - " ax.scatter(mi_value, deconv_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", - " # add to lists to track later \n", - " mi_means_across_photons.append(mi_value)\n", - " deconv_means_across_photons.append(deconv_value)\n", - " mi_means_across_photons = np.array(mi_means_across_photons)\n", - " deconv_means_across_photons = np.array(deconv_means_across_photons)\n", - " ax.plot(mi_means_across_photons, deconv_means_across_photons, '--', color='grey', alpha=1, linewidth=2)\n", - "\n", - "ax.set_xlabel('Mutual Information (bits per pixel)')\n", - "ax.set_ylabel('Deconvolution Performance, {}'.format(metric_name))\n", - "clear_spines(ax)\n", - "\n", - "\n", - "# legend\n", - "# ax.scatter([], [], color='k', marker='x', label='No PSF')\n", - "ax.scatter([], [], color='k', marker='o', label='One Lens')\n", - "ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", - "ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", - "\n", - "ax.legend(loc='upper right', frameon=True)\n", - "ax.set_xlim([0, None])\n", - "\n", - "\n", - "\n", - "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", - "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", - "sm.set_array([])\n", - "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", - "# set tick labels\n", - "cbar.ax.set_yticklabels(valid_photon_counts)\n", - "\n", - "\n", - "cbar.set_label('Mean Photon Count')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Repeat same thing for just one lens, four lens and diffuser, include error bars - final figure" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "estimator_type = 1 # 0 for gaussian, 1 for pixelcnn\n", - "metric_type = 1\n", - "valid_psfs = [0, 1, 2]\n", - "valid_photon_counts = [20, 40, 80, 160, 320]\n", - "print([psf_names[i] for i in valid_psfs])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "confidence_level = 0.9\n", - "# using min-valued MI estimates \n", - "mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]\n", - "#deconv_estimate_lists = [mse_across_psfs, psnr_across_psfs]\n", - "deconv_estimate_lists = [mse_lists_across_psfs, psnr_lists_across_psfs] # use full list versions\n", - "metric_names = ['MSE', 'PSNR']\n", - "# classifier array is classifier_all_trials_across_psfs, 4x9x10 array. 4 psfs, 9 photon counts, 10 trials on each one \n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n", - "\n", - "mi_list_use = mi_estimate_lists[estimator_type] # use pixelcnn or gaussian, choose pixelcnn \n", - "deconv_list_use = deconv_estimate_lists[metric_type] # use mse or psnr, choose psnr\n", - "metric_name = metric_names[metric_type]\n", - "\n", - "for psf_idx, psf_name in enumerate(psf_names):\n", - " if psf_idx in valid_psfs:\n", - " mi_means_across_photons = [] # track mean MI values to make trendline \n", - " deconv_means_across_photons = [] # track mean deconvolution values to make trendline\n", - " deconv_lower_across_photons = [] # track lower bounds\n", - " deconv_upper_across_photons = [] # track upper bounds\n", - " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", - " if photon_count in valid_photon_counts:\n", - " # load mean values and colors to plot \n", - " color = color_for_photon_level(photon_count)\n", - " mi_value = mi_list_use[psf_idx][photon_idx] # only use an MI value if the psf is valid, correctly indexed \n", - " deconv_value = np.mean(deconv_list_use[psf_idx][photon_idx])\n", - " ax.scatter(mi_value, deconv_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", - " # add to lists to track later \n", - " mi_means_across_photons.append(mi_value)\n", - " deconv_means_across_photons.append(deconv_value)\n", - " # calculate error bars \n", - " deconv_lower_across_photons.append(np.percentile(deconv_list_use[psf_idx][photon_idx], 100 - 100 * (1 + confidence_level) / 2))\n", - " deconv_upper_across_photons.append(np.percentile(deconv_list_use[psf_idx][photon_idx], 100 * (1 + confidence_level) / 2))\n", - " mi_means_across_photons = np.array(mi_means_across_photons)\n", - " deconv_means_across_photons = np.array(deconv_means_across_photons)\n", - " ax.plot(mi_means_across_photons, deconv_means_across_photons, '--', color='grey', alpha=1, linewidth=2)\n", - " ax.fill_between(mi_means_across_photons, deconv_lower_across_photons, deconv_upper_across_photons, color='grey', alpha=0.3, linewidth=0, zorder=-100)\n", - "\n", - "ax.set_xlabel('Mutual Information (bits per pixel)')\n", - "ax.set_ylabel('Deconvolution Performance, {}'.format(metric_name))\n", - "clear_spines(ax)\n", - "\n", - "\n", - "# legend\n", - "# ax.scatter([], [], color='k', marker='x', label='No PSF')\n", - "ax.scatter([], [], color='k', marker='o', label='One Lens')\n", - "#ax.scatter([], [], color='k', marker='d', label='Two Lens')\n", - "#ax.scatter([], [], color='k', marker='v', label='Three Lens')\n", - "ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", - "#ax.scatter([], [], color='k', marker='p', label='Five Lens')\n", - "ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", - "#ax.scatter([], [], color='k', marker='P', label='Aperture')\n", - "\n", - "ax.legend(loc='lower right', frameon=True)\n", - "ax.set_xlim([0, None])\n", - "\n", - "\n", - "\n", - "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", - "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", - "sm.set_array([])\n", - "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", - "# set tick labels\n", - "cbar.ax.set_yticklabels(valid_photon_counts)\n", - "\n", - "\n", - "cbar.set_label('Mean Photon Count')\n", - "\n", - "\n", - "#plt.savefig('{}_vs_MI_with_confidence_intervals_log_photons.pdf'.format(metric_name), bbox_inches='tight', transparent=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "estimator_type = 1 # 0 for gaussian, 1 for pixelcnn\n", - "metric_type = 0 # 0 for MSE, 1 for PSNR\n", - "valid_psfs = [0, 1, 2]\n", - "valid_photon_counts = [20, 40, 80, 160, 320]\n", - "print([psf_names[i] for i in valid_psfs])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "confidence_level = 0.9\n", - "# using min-valued MI estimates \n", - "mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]\n", - "#deconv_estimate_lists = [mse_across_psfs, psnr_across_psfs]\n", - "deconv_estimate_lists = [mse_lists_across_psfs, psnr_lists_across_psfs] # use full list versions\n", - "metric_names = ['MSE', 'PSNR']\n", - "# classifier array is classifier_all_trials_across_psfs, 4x9x10 array. 4 psfs, 9 photon counts, 10 trials on each one \n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n", - "\n", - "mi_list_use = mi_estimate_lists[estimator_type] # use pixelcnn or gaussian, choose pixelcnn \n", - "deconv_list_use = deconv_estimate_lists[metric_type] # use mse or psnr, choose psnr\n", - "metric_name = metric_names[metric_type]\n", - "\n", - "for psf_idx, psf_name in enumerate(psf_names):\n", - " if psf_idx in valid_psfs:\n", - " mi_means_across_photons = [] # track mean MI values to make trendline \n", - " deconv_means_across_photons = [] # track mean deconvolution values to make trendline\n", - " deconv_lower_across_photons = [] # track lower bounds\n", - " deconv_upper_across_photons = [] # track upper bounds\n", - " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", - " if photon_count in valid_photon_counts:\n", - " # load mean values and colors to plot \n", - " color = color_for_photon_level(photon_count)\n", - " mi_value = mi_list_use[psf_idx][photon_idx] # only use an MI value if the psf is valid, correctly indexed \n", - " deconv_value = np.mean(deconv_list_use[psf_idx][photon_idx])\n", - " ax.scatter(mi_value, deconv_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", - " # add to lists to track later \n", - " mi_means_across_photons.append(mi_value)\n", - " deconv_means_across_photons.append(deconv_value)\n", - " # calculate error bars \n", - " deconv_lower_across_photons.append(np.percentile(deconv_list_use[psf_idx][photon_idx], 100 - 100 * (1 + confidence_level) / 2))\n", - " deconv_upper_across_photons.append(np.percentile(deconv_list_use[psf_idx][photon_idx], 100 * (1 + confidence_level) / 2))\n", - " mi_means_across_photons = np.array(mi_means_across_photons)\n", - " deconv_means_across_photons = np.array(deconv_means_across_photons)\n", - " ax.plot(mi_means_across_photons, deconv_means_across_photons, '--', color='grey', alpha=1, linewidth=2)\n", - " ax.fill_between(mi_means_across_photons, deconv_lower_across_photons, deconv_upper_across_photons, color='grey', alpha=0.3, linewidth=0, zorder=-100)\n", - "\n", - "ax.set_xlabel('Mutual Information (bits per pixel)')\n", - "ax.set_ylabel('Deconvolution Performance, {}'.format(metric_name))\n", - "clear_spines(ax)\n", - "\n", - "\n", - "# legend\n", - "# ax.scatter([], [], color='k', marker='x', label='No PSF')\n", - "ax.scatter([], [], color='k', marker='o', label='One Lens')\n", - "#ax.scatter([], [], color='k', marker='d', label='Two Lens')\n", - "#ax.scatter([], [], color='k', marker='v', label='Three Lens')\n", - "ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", - "#ax.scatter([], [], color='k', marker='p', label='Five Lens')\n", - "ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", - "#ax.scatter([], [], color='k', marker='P', label='Aperture')\n", - "\n", - "ax.legend(loc='upper right', frameon=True)\n", - "ax.set_xlim([0, None])\n", - "\n", - "\n", - "\n", - "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", - "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", - "sm.set_array([])\n", - "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", - "# set tick labels\n", - "cbar.ax.set_yticklabels(valid_photon_counts)\n", - "\n", - "\n", - "cbar.set_label('Mean Photon Count')\n", - "\n", - "\n", - "#plt.savefig('{}_vs_MI_with_confidence_intervals_log_photons.pdf'.format(metric_name), bbox_inches='tight', transparent=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "info_jax", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 13e3e73076c976a6306be41e865d7da484396f24 Mon Sep 17 00:00:00 2001 From: Leyla Kabuli Date: Mon, 18 Nov 2024 13:24:05 -0800 Subject: [PATCH 5/5] lensless mi experiments with updated api --- ...upervised_wiener_deconvolution_per_lens.py | 176 +++++ ...s_deconvolution_plots_cifar10_figure.ipynb | 591 +++++++++++++++++ ...n_cifar10_updated_api_reruns_smaller_lr.py | 154 +++++ lensless_imager/lensless_helpers.py | 612 ++++++++++++++++++ 4 files changed, 1533 insertions(+) create mode 100644 lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py create mode 100644 lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb create mode 100644 lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py create mode 100644 lensless_imager/lensless_helpers.py diff --git a/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py b/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py new file mode 100644 index 0000000..15fc594 --- /dev/null +++ b/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py @@ -0,0 +1,176 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: infotransformer +# language: python +# name: python3 +# --- + +# %% [markdown] +# ## Sweeping both unsupervised Wiener Deconvolution and non-unsupervised Wiener Deconvolution with hand-tuned paramete +# +# Using a fixed seed (10) for consistency. + +# %% +# %load_ext autoreload +# %autoreload 2 + +import os +from jax import config +config.update("jax_enable_x64", True) +import sys +sys.path.append('/home/lakabuli/workspace/EncodingInformation/src') + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = '1' +from encoding_information.gpu_utils import limit_gpu_memory_growth +limit_gpu_memory_growth() + + +from cleanplots import * +import numpy as np +import tensorflow as tf +import tensorflow.keras as tfk + +from lensless_helpers import * +from tqdm import tqdm + +# %% +from encoding_information.image_utils import add_noise +from skimage.restoration import wiener, unsupervised_wiener, richardson_lucy +import skimage.metrics as skm + +# %% +# load the PSFs + +diffuser_psf = load_diffuser_32() +one_psf = load_single_lens_uniform(32) +two_psf = load_two_lens_uniform(32) +three_psf = load_three_lens_uniform(32) +four_psf = load_four_lens_uniform(32) +five_psf = load_five_lens_uniform(32) +aperture_psf = np.copy(diffuser_psf) +aperture_psf[:5] = 0 +aperture_psf[-5:] = 0 +aperture_psf[:,:5] = 0 +aperture_psf[:,-5:] = 0 + + +# %% +def compute_skm_metrics(gt, recon): + # takes in already normalized gt + mse = skm.mean_squared_error(gt, recon) + psnr = skm.peak_signal_noise_ratio(gt, recon) + nmse = skm.normalized_root_mse(gt, recon) + ssim = skm.structural_similarity(gt, recon, data_range=1) + return mse, psnr, nmse, ssim + + +# %% +# set seed values for reproducibility +seed_values_full = np.arange(1, 4) + +# set photon properties +#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300] +mean_photon_count_list = [20, 40, 80, 160, 320] + +# set eligible psfs + +psf_patterns_use = [one_psf, four_psf, diffuser_psf] +psf_names_use = ['one', 'four', 'diffuser'] + +save_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/deconvolutions/' + + +# MI estimator parameters +patch_size = 32 +num_patches = 10000 +test_set_size = 1500 +bs = 500 +max_epochs = 50 + +seed_value = 10 + +reg_value_best = 10**-2 + +# %% +# data generation process + +for photon_count in mean_photon_count_list: + for psf_idx, psf_pattern in enumerate(psf_patterns_use): + # load dataset + (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() + data = np.concatenate((x_train, x_test), axis=0) + data = data.astype(np.float64) + labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. + # convert data to grayscale before converting to photons + if len(data.shape) == 4: + data = tf.image.rgb_to_grayscale(data).numpy() + data = data.squeeze() + # convert to photons with mean value of photon_count + data /= np.mean(data) + data *= photon_count + # get maximum value in this data + max_val = np.max(data) + # make tiled data + random_data, random_labels = generate_random_tiled_data(data, labels, seed_value) + # only keep the middle part of the data + data_padded = np.zeros((data.shape[0], 96, 96)) + data_padded[:, 32:64, 32:64] = random_data[:, 32:64, 32:64] + # save the middle part of the data as the gt for metric computation, include only the test set portion. + gt_data = data_padded[:, 32:64, 32:64] + gt_data = gt_data[-test_set_size:] + # extract the test set before doing convolution + test_data = data_padded[-test_set_size:] + # convolve the data + convolved_data = convolved_dataset(psf_pattern, test_data) + convolved_data_noisy = add_noise(convolved_data, seed=seed_value) + # output of add_noise is a jax array that's float32, convert to regular numpy array and float64. + convolved_data_noisy = np.array(convolved_data_noisy).astype(np.float64) + + # compute metrics using unsupervised wiener deconvolution + mse_psf = [] + psnr_psf = [] + ssim_psf = [] + for i in tqdm(range(convolved_data_noisy.shape[0])): + recon, _ = unsupervised_wiener(convolved_data_noisy[i] / max_val, psf_pattern) + recon = recon[17:49, 17:49] #this is the crop window to look at + mse = skm.mean_squared_error(gt_data[i] / max_val, recon) + psnr = skm.peak_signal_noise_ratio(gt_data[i] / max_val, recon) + ssim = skm.structural_similarity(gt_data[i] / max_val, recon, data_range=1) + mse_psf.append(mse) + psnr_psf.append(psnr) + ssim_psf.append(ssim) + + print('PSF: {}, Mean MSE: {}, Mean PSNR: {}, Mean SSIM: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf), np.mean(ssim_psf))) + np.save(save_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf, ssim_psf]) + + # repeat with regular deconvolution + mse_psf = [] + psnr_psf = [] + ssim_psf = [] + for i in tqdm(range(convolved_data_noisy.shape[0])): + recon = wiener(convolved_data_noisy[i] / max_val, psf_pattern, reg_value_best) + recon = recon[17:49, 17:49] #this is the crop window to look at + mse = skm.mean_squared_error(gt_data[i] / max_val, recon) + psnr = skm.peak_signal_noise_ratio(gt_data[i] / max_val, recon) + ssim = skm.structural_similarity(gt_data[i] / max_val, recon, data_range=1) + mse_psf.append(mse) + psnr_psf.append(psnr) + ssim_psf.append(ssim) + print('PSF: {}, Mean MSE: {}, Mean PSNR: {}, Mean SSIM: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf), np.mean(ssim_psf))) + np.save(save_dir + 'regular_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf, ssim_psf]) + + + + + +# %% + + diff --git a/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb b/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb new file mode 100644 index 0000000..97c9575 --- /dev/null +++ b/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb @@ -0,0 +1,591 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Make the plot for MI and deconvolution relationship for paper figure, 2024/10/23" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload \n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import config\n", + "config.update(\"jax_enable_x64\", True)\n", + "import numpy as np\n", + "\n", + "import sys \n", + "sys.path.append('/home/lakabuli/workspace/EncodingInformation/src')\n", + "from lensless_helpers import *\n", + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", + "print(os.environ.get('PYTHONPATH'))\n", + "from cleanplots import * " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seed_value = 10\n", + "\n", + "# set photon properties \n", + "bias = 10 # in photons\n", + "mean_photon_count_list = [20, 40, 80, 160, 320]\n", + "max_photon_count = mean_photon_count_list[-1]\n", + "\n", + "# set eligible psfs\n", + "\n", + "psf_names = ['one', 'four', 'diffuser']\n", + "\n", + "# MI estimator parameters \n", + "patch_size = 32\n", + "num_patches = 10000\n", + "val_set_size = 1000\n", + "test_set_size = 1500\n", + "\n", + "mi_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/mi_estimates_smaller_lr/'\n", + "recon_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/deconvolutions/'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load MI data and make plots of it\n", + "\n", + "The plot has essentially invisible error bars. No outlier issues" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from cleanplots import *\n", + "get_color_cycle()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", + "mis_across_psfs = []\n", + "lowers_across_psfs = []\n", + "uppers_across_psfs = []\n", + "for psf_name in psf_names:\n", + " mis = []\n", + " lowers = []\n", + " uppers = []\n", + " for photon_count in mean_photon_count_list:\n", + " mi_estimates = np.load(mi_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", + " mi_values = mi_estimates[0]\n", + " print(np.max(mi_values) - np.min(mi_values))\n", + " lower_bounds = mi_estimates[1]\n", + " upper_bounds = mi_estimates[2]\n", + " # get index that has smallest mi value across the different model runs.\n", + " min_mi_index = np.argmin(mi_values)\n", + " mis.append(mi_values[min_mi_index])\n", + " lowers.append(lower_bounds[min_mi_index])\n", + " uppers.append(upper_bounds[min_mi_index])\n", + " ax.plot(mean_photon_count_list, mis, label=psf_name) \n", + " ax.fill_between(mean_photon_count_list, lowers, uppers, alpha=0.3)\n", + " mis_across_psfs.append(mis)\n", + " lowers_across_psfs.append(lowers)\n", + " uppers_across_psfs.append(uppers)\n", + "plt.legend()\n", + "plt.title(\"PixelCNN MI estimates across Photon Count, CIFAR10\")\n", + "plt.xlabel(\"Mean Photon Count\")\n", + "plt.ylabel(\"Estimated Mutual Information\")\n", + "mis_across_psfs = np.array(mis_across_psfs)\n", + "lowers_across_psfs = np.array(lowers_across_psfs)\n", + "uppers_across_psfs = np.array(uppers_across_psfs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load recon data and make plots of it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mses_across_psfs = []\n", + "mse_lowers_across_psfs = []\n", + "mse_uppers_across_psfs = []\n", + "psnrs_across_psfs = []\n", + "psnr_lowers_across_psfs = []\n", + "psnr_uppers_across_psfs = []\n", + "ssims_across_psfs = []\n", + "ssim_lowers_across_psfs = []\n", + "ssim_uppers_across_psfs = []\n", + "\n", + "for psf_name in psf_names: \n", + " mse_vals = []\n", + " mse_lowers = []\n", + " mse_uppers = []\n", + " psnr_vals = []\n", + " psnr_lowers = []\n", + " psnr_uppers = []\n", + " ssim_vals = []\n", + " ssim_lowers = []\n", + " ssim_uppers = []\n", + " for photon_count in mean_photon_count_list:\n", + " metrics = np.load(recon_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", + " mse = metrics[0]\n", + " psnr = metrics[1] \n", + " ssim = metrics[2]\n", + " bootstrap_mse, bootstrap_psnr, bootstrap_ssim = compute_bootstraps(mse, psnr, ssim, test_set_size)\n", + " mean_mse, lower_bound_mse, upper_bound_mse = compute_confidence_interval(bootstrap_mse, confidence_interval=0.95)\n", + " mean_psnr, lower_bound_psnr, upper_bound_psnr = compute_confidence_interval(bootstrap_psnr, confidence_interval=0.95)\n", + " mean_ssim, lower_bound_ssim, upper_bound_ssim = compute_confidence_interval(bootstrap_ssim, confidence_interval=0.95)\n", + " mse_vals.append(mean_mse)\n", + " mse_lowers.append(lower_bound_mse)\n", + " mse_uppers.append(upper_bound_mse)\n", + " psnr_vals.append(mean_psnr)\n", + " psnr_lowers.append(lower_bound_psnr)\n", + " psnr_uppers.append(upper_bound_psnr)\n", + " ssim_vals.append(mean_ssim)\n", + " ssim_lowers.append(lower_bound_ssim)\n", + " ssim_uppers.append(upper_bound_ssim)\n", + " mses_across_psfs.append(mse_vals)\n", + " mse_lowers_across_psfs.append(mse_lowers)\n", + " mse_uppers_across_psfs.append(mse_uppers)\n", + " psnrs_across_psfs.append(psnr_vals)\n", + " psnr_lowers_across_psfs.append(psnr_lowers)\n", + " psnr_uppers_across_psfs.append(psnr_uppers)\n", + " ssims_across_psfs.append(ssim_vals)\n", + " ssim_lowers_across_psfs.append(ssim_lowers)\n", + " ssim_uppers_across_psfs.append(ssim_uppers)\n", + "mses_across_psfs = np.array(mses_across_psfs)\n", + "mse_lowers_across_psfs = np.array(mse_lowers_across_psfs)\n", + "mse_uppers_across_psfs = np.array(mse_uppers_across_psfs)\n", + "psnrs_across_psfs = np.array(psnrs_across_psfs)\n", + "psnr_lowers_across_psfs = np.array(psnr_lowers_across_psfs)\n", + "psnr_uppers_across_psfs = np.array(psnr_uppers_across_psfs)\n", + "ssims_across_psfs = np.array(ssims_across_psfs)\n", + "ssim_lowers_across_psfs = np.array(ssim_lowers_across_psfs)\n", + "ssim_uppers_across_psfs = np.array(ssim_uppers_across_psfs)\n", + "plt.figure(figsize=(20, 5))\n", + "plt.subplot(1, 3, 1)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, mses_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, mse_lowers_across_psfs[i], mse_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"MSE\")\n", + "plt.legend()\n", + "plt.subplot(1, 3, 2)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, psnrs_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, psnr_lowers_across_psfs[i], psnr_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"PSNR\")\n", + "plt.subplot(1, 3, 3)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, ssims_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, ssim_lowers_across_psfs[i], ssim_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"SSIM\")\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make figures, omitting error bars since smaller than marker size and reverting to circular markers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def marker_for_psf(psf_name):\n", + " if psf_name =='one':\n", + " marker = 'o'\n", + " elif psf_name == 'four':\n", + " marker = 'o'\n", + " #marker = 's' \n", + " elif psf_name == 'diffuser':\n", + " #marker = '*'\n", + " marker = 'o'\n", + " elif psf_name == 'uc':\n", + " marker = 'x'\n", + " elif psf_name =='two':\n", + " marker = 'd'\n", + " elif psf_name == 'three':\n", + " marker = 'v'\n", + " elif psf_name == 'five':\n", + " marker = 'p'\n", + " elif psf_name == 'aperture':\n", + " marker = 'P'\n", + " return marker" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Choose a base colormap\n", + "base_colormap = plt.get_cmap('inferno')\n", + "# Define the start and end points--used so that high values aren't too light against white background\n", + "start, end = 0, 0.88 # making end point 0.8\n", + "from matplotlib.colors import LinearSegmentedColormap\n", + "# Create a new colormap from the portion of the original colormap\n", + "colormap = LinearSegmentedColormap.from_list(\n", + " 'trunc({n},{a:.2f},{b:.2f})'.format(n=base_colormap.name, a=start, b=end),\n", + " base_colormap(np.linspace(start, end, 256))\n", + ")\n", + "\n", + "min_photons_per_pixel = min(mean_photon_count_list)\n", + "max_photons_per_pixel = max(mean_photon_count_list)\n", + "\n", + "min_log_photons = np.log(min_photons_per_pixel)\n", + "max_log_photons = np.log(max_photons_per_pixel)\n", + "\n", + "def color_for_photon_level(photons_per_pixel):\n", + " log_photons = np.log(photons_per_pixel)\n", + " return colormap((log_photons - min_log_photons) / (max_log_photons - min_log_photons) )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# old format for selecting target indices, now not used much\n", + "metric_type = 1 # 0 for MSE, 1 for PSNR \n", + "valid_psfs = [0, 1, 2]\n", + "valid_photon_counts = [20, 40, 80, 160, 320]\n", + "psf_names = [psf_names[i] for i in valid_psfs]\n", + "print(psf_names)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mse_error_lower = np.abs(mses_across_psfs - mse_lowers_across_psfs)\n", + "mse_error_upper = np.abs(mse_uppers_across_psfs - mses_across_psfs)\n", + "psnr_error_lower = np.abs(psnrs_across_psfs - psnr_lowers_across_psfs)\n", + "psnr_error_upper = np.abs(psnr_uppers_across_psfs - psnrs_across_psfs)\n", + "mi_error_lower = np.abs(mis_across_psfs - lowers_across_psfs)\n", + "mi_error_upper = np.abs(uppers_across_psfs - mis_across_psfs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = mses_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], mses_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[mse_error_lower[psf_idx], mse_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Mean Squared Error\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='upper right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + "cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('mse_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = psnrs_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], psnrs_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[psnr_error_lower[psf_idx], psnr_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Peak Signal-to-Noise Ratio (dB)\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='lower right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + "cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('psnr_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = ssims_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], ssims_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[ssim_error_lower[psf_idx], ssim_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Structural Similarity Index Measure (SSIM)\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='lower right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + "cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('ssim_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Put all 3 into one figure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from cleanplots import *\n", + "from matplotlib.ticker import ScalarFormatter\n", + "\n", + "figs, axs = plt.subplots(1, 3, figsize=(12, 4), sharex=True)\n", + "\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = mses_across_psfs[psf_idx][photon_idx] \n", + " axs[0].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], mses_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[mse_error_lower[psf_idx], mse_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[0].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "#axs[0].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[0].set_title(\"Mean Squared Error\")\n", + "clear_spines(axs[0])\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = ssims_across_psfs[psf_idx][photon_idx] \n", + " axs[1].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], ssims_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[ssim_error_lower[psf_idx], ssim_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[1].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "axs[1].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[1].set_title(\"Structural Similarity Index Measure (SSIM)\")\n", + "clear_spines(axs[1])\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = psnrs_across_psfs[psf_idx][photon_idx] \n", + " axs[2].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], psnrs_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[psnr_error_lower[psf_idx], psnr_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[2].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "#axs[2].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[2].set_title(\"Peak Signal-to-Noise Ratio (dB)\")\n", + "clear_spines(axs[2])\n", + "\n", + "# norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "# sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "# sm.set_array([])\n", + "# cbar = plt.colorbar(sm, ax=axs[2], ticks=(np.log(valid_photon_counts)))\n", + "# # set tick labels\n", + "# cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "# cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig(\"metrics_vs_MI_with_confidence_intervals_log_photons.pdf\", bbox_inches='tight', transparent=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "infotransformer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py b/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py new file mode 100644 index 0000000..bd9148a --- /dev/null +++ b/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py @@ -0,0 +1,154 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: infotransformer +# language: python +# name: python3 +# --- + +# %% +# %load_ext autoreload +# %autoreload 2 + +# Final MI estimation script for lensless imager, used in paper. + +import os +from jax import config +config.update("jax_enable_x64", True) +import sys +sys.path.append('/home/lakabuli/workspace/EncodingInformation/src') + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = '0' +from encoding_information.gpu_utils import limit_gpu_memory_growth +limit_gpu_memory_growth() + +from cleanplots import * +import jax.numpy as np +import numpy as onp +import tensorflow as tf +import tensorflow.keras as tfk + + +from lensless_helpers import * + +# %% +from encoding_information import extract_patches +from encoding_information.models import PixelCNN +from encoding_information.plot_utils import plot_samples +from encoding_information.models import PoissonNoiseModel +from encoding_information.image_utils import add_noise +from encoding_information import estimate_information + +# %% [markdown] +# ### Sweep Photon Count and Diffusers + +# %% +diffuser_psf = load_diffuser_32() +one_psf = load_single_lens_uniform(32) +two_psf = load_two_lens_uniform(32) +three_psf = load_three_lens_uniform(32) +four_psf = load_four_lens_uniform(32) +five_psf = load_five_lens_uniform(32) + +# %% +# set seed values for reproducibility +seed_values_full = np.arange(1, 5) + +# set photon properties +bias = 10 # in photons +mean_photon_count_list = [20, 40, 80, 160, 320] + +# set eligible psfs + +psf_patterns = [diffuser_psf, four_psf, one_psf] +psf_names = ['diffuser', 'four', 'one'] + +# MI estimator parameters +patch_size = 32 +num_patches = 10000 +val_set_size = 1000 +test_set_size = 1500 +num_samples = 8 +learning_rate = 1e-3 # using 5x iterations per epoch, using smaller lr, and using less patience since it should be a smoother curve. +num_iters_per_epoch = 500 +patience_val = 20 + + +save_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/mi_estimates_smaller_lr/' + + +# %% +for photon_count in mean_photon_count_list: + for index, psf_pattern in enumerate(psf_patterns): + val_loss_log = [] + mi_estimates = [] + lower_bounds = [] + upper_bounds = [] + for seed_value in seed_values_full: + # load dataset + (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() + data = onp.concatenate((x_train, x_test), axis=0) + labels = np.concatenate((y_train, y_test), axis=0) + data = data.astype(np.float32) + # convert data to grayscale before converting to photons + if len(data.shape) == 4: + data = tf.image.rgb_to_grayscale(data).numpy() + data = data.squeeze() + # convert to photons with mean value of photon_count + data /= onp.mean(data) + data *= photon_count + # make tiled data + random_data, random_labels = generate_random_tiled_data(data, labels, seed_value) + + if psf_pattern is None: + start_idx = data.shape[-1] // 2 + end_idx = data.shape[-1] // 2 - 1 + psf_data = random_data[:, start_idx:-end_idx, start_idx:-end_idx] + else: + psf_data = convolved_dataset(psf_pattern, random_data) + # add small bias to data + psf_data += bias + # make patches for training and testing splits, random patching + patches = extract_patches(psf_data[:-test_set_size], patch_size=patch_size, num_patches=num_patches, seed=seed_value, verbose=True) + test_patches = extract_patches(psf_data[-test_set_size:], patch_size=patch_size, num_patches=test_set_size, seed=seed_value, verbose=True) + # put all the clean patches together for use in MI estimatino function later + full_clean_patches = onp.concatenate([patches, test_patches]) + # add noise to both sets + patches_noisy = add_noise(patches, seed=seed_value) + test_patches_noisy = add_noise(test_patches, seed=seed_value) + + # initialize pixelcnn + pixel_cnn = PixelCNN() + # fit pixelcnn to noisy patches. defaults to 10% val samples which will be 1k as desired. + # using smaller lr this time and adding seeding, letting it go for full training time. + val_loss_history = pixel_cnn.fit(patches_noisy, seed=seed_value, learning_rate=learning_rate, do_lr_decay=False, steps_per_epoch=num_iters_per_epoch, patience=patience_val) + # generate samples, not necessary for MI sweeps + # pixel_cnn_samples = pixel_cnn.generate_samples(num_samples=num_samples) + # # visualize samples + # plot_samples([pixel_cnn_samples], test_patches, model_names=['PixelCNN']) + + # instantiate noise model + noise_model = PoissonNoiseModel() + # estimate information using the fit pixelcnn and noise model, with clean data + pixel_cnn_info, pixel_cnn_lower_bound, pixel_cnn_upper_bound = estimate_information(pixel_cnn, noise_model, patches_noisy, + test_patches_noisy, clean_data=full_clean_patches, + confidence_interval=0.95) + print("PixelCNN estimated information: ", pixel_cnn_info) + print("PixelCNN lower bound: ", pixel_cnn_lower_bound) + print("PixelCNN upper bound: ", pixel_cnn_upper_bound) + # append results to lists + val_loss_log.append(val_loss_history) + mi_estimates.append(pixel_cnn_info) + lower_bounds.append(pixel_cnn_lower_bound) + upper_bounds.append(pixel_cnn_upper_bound) + np.save(save_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array([mi_estimates, lower_bounds, upper_bounds])) + np.save(save_dir + 'pixelcnn_val_loss_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(val_loss_log, dtype=object)) + np.save(save_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array([mi_estimates, lower_bounds, upper_bounds])) + np.save(save_dir + 'pixelcnn_val_loss_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(val_loss_log, dtype=object)) diff --git a/lensless_imager/lensless_helpers.py b/lensless_imager/lensless_helpers.py new file mode 100644 index 0000000..0693b28 --- /dev/null +++ b/lensless_imager/lensless_helpers.py @@ -0,0 +1,612 @@ +import numpy as np # use regular numpy for now, simpler +import scipy +from tqdm import tqdm +# import tensorflow as tf +# import tensorflow.keras as tfk +import gc +import warnings + +import skimage +import skimage.io +from skimage.transform import resize + +# from tensorflow.keras.optimizers import SGD + +def tile_9_images(data_set): + # takes 9 images and forms a tiled image + assert len(data_set) == 9 + return np.block([[data_set[0], data_set[1], data_set[2]],[data_set[3], data_set[4], data_set[5]],[data_set[6], data_set[7], data_set[8]]]) + +def generate_random_tiled_data(x_set, y_set, seed_value=-1): + # takes a set of images and labels and returns a set of tiled images and corresponding labels + # the size of the output should be 3x the size of the input + vert_shape = x_set.shape[1] * 3 + horiz_shape = x_set.shape[2] * 3 + random_data = np.zeros((x_set.shape[0], vert_shape, horiz_shape)) # for mnist this was 84 x 84 + random_labels = np.zeros((y_set.shape[0], 1)) + if seed_value==-1: + np.random.seed() + else: + np.random.seed(seed_value) + for i in range(x_set.shape[0]): + img_items = np.random.choice(x_set.shape[0], size=9, replace=True) + data_set = x_set[img_items] + random_labels[i] = y_set[img_items[4]] + random_data[i] = tile_9_images(data_set) + return random_data, random_labels + +def generate_repeated_tiled_data(x_set, y_set): + # takes set of images and labels and returns a set of repeated tiled images and corresponding labels, no randomness + # the size of the output is 3x the size of the input, this essentially is a wrapper for np.tile + repeated_data = np.tile(x_set, (1, 3, 3)) + repeated_labels = y_set # the labels are just what they were + return repeated_data, repeated_labels + +def convolved_dataset(psf, random_tiled_data): + # takes a psf and a set of tiled images and returns a set of convolved images, convolved image size is 2n + 1? same size as the random data when it's cropped + # tile size is two images worth plus one extra index value + vert_shape = psf.shape[0] * 2 + 1 + horiz_shape = psf.shape[1] * 2 + 1 + psf_dataset = np.zeros((random_tiled_data.shape[0], vert_shape, horiz_shape)) # 57 x 57 for the case of mnist 28x28 images, 65 x 65 for the cifar 32 x 32 images + for i in tqdm(range(random_tiled_data.shape[0])): + psf_dataset[i] = scipy.signal.fftconvolve(psf, random_tiled_data[i], mode='valid') + return psf_dataset + +def compute_entropy(eigenvalues): + sum_log_evs = np.sum(np.log2(eigenvalues)) + D = eigenvalues.shape[0] + gaussian_entropy = 0.5 * (sum_log_evs + D * np.log2(2 * np.pi * np.e)) + return gaussian_entropy + +def add_shot_noise(photon_scaled_images, photon_fraction=None, photons_per_pixel=None, assume_noiseless=True, seed_value=-1): + #adapted from henry, also uses a seed though + if seed_value==-1: + np.random.seed() + else: + np.random.seed(seed_value) + + # check all pixels greater than 0 + if np.any(photon_scaled_images < 0): + #warning about negative + warnings.warn(f"Negative pixel values detected. Clipping to 0.") + photon_scaled_images[photon_scaled_images < 0] = 0 + if photons_per_pixel is not None: + if photons_per_pixel > np.mean(photon_scaled_images): + warnings.warn(f"photons_per_pixel is greater than actual photon count ({photons_per_pixel}). Clipping to {np.mean(photon_scaled_images)}") + photons_per_pixel = np.mean(photon_scaled_images) + photon_fraction = photons_per_pixel / np.mean(photon_scaled_images) + + if photon_fraction > 1: + warnings.warn(f"photon_fraction is greater than 1 ({photon_fraction}). Clipping to 1.") + photon_fraction = 1 + + if assume_noiseless: + additional_sd = np.sqrt(photon_fraction * photon_scaled_images) + if np.any(np.isnan(additional_sd)): + warnings.warn('There are nans here') + additional_sd[np.isnan(additional_sd)] = 0 + # something here goes weird for RML + # + #else: + # additional_sd = np.sqrt(photon_fraction * photon_scaled_images) - photon_fraction * np.sqrt(photon_scaled_images) + simulated_images = photon_scaled_images * photon_fraction + additional_sd * np.random.randn(*photon_scaled_images.shape) + positive = np.array(simulated_images) + positive[positive < 0] = 0 # cant have negative counts + return np.array(positive) + +def tf_cast(data): + # normalizes data, loads it to a tensorflow array of type float32 + return tf.cast(data / np.max(data), tf.float32) +def tf_labels(labels): + # loads labels to a tensorflow array of type int64 + return tf.cast(labels, tf.int64) + + + +def run_model_simple(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1): + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + + model = tfk.models.Sequential() + model.add(tfk.layers.Flatten()) + model.add(tfk.layers.Dense(256, activation='relu')) + model.add(tfk.layers.Dense(256, activation='relu')) + model.add(tfk.layers.Dense(10, activation='softmax')) + + model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=5, + restore_best_weights=True, verbose=1) + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), batch_size=32, epochs=50, callbacks=[early_stop]) + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + +def run_model_cnn(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1): + # structure from https://www.kaggle.com/code/cdeotte/how-to-choose-cnn-architecture-mnist + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + + model = tfk.models.Sequential() + model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu', input_shape=(57, 57, 1))) #64 and 128 works very slightly better + model.add(tfk.layers.MaxPool2D()) + model.add(tfk.layers.Conv2D(128, kernel_size=5, padding='same', activation='relu')) + model.add(tfk.layers.MaxPool2D()) + #model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu')) + #model.add(tfk.layers.MaxPool2D(padding='same')) + model.add(tfk.layers.Flatten()) + + #model.add(tfk.layers.Dense(256, activation='relu')) + model.add(tfk.layers.Dense(128, activation='relu')) + model.add(tfk.layers.Dense(10, activation='softmax')) + + model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=5, + restore_best_weights=True, verbose=1) + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=50, batch_size=32, callbacks=[early_stop]) #validation data is not test data + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + +def seeded_permutation(seed_value, n): + # given fixed seed returns permutation order + np.random.seed(seed_value) + permutation_order = np.random.permutation(n) + return permutation_order + +def segmented_indices(permutation_order, n, training_fraction, test_fraction): + #given permutation order returns indices for each of the three sets + training_indices = permutation_order[:int(training_fraction*n)] + test_indices = permutation_order[int(training_fraction*n):int((training_fraction+test_fraction)*n)] + validation_indices = permutation_order[int((training_fraction+test_fraction)*n):] + return training_indices, test_indices, validation_indices + +def permute_data(data, labels, seed_value, training_fraction=0.8, test_fraction=0.1): + #validation fraction is implicit, if including a validation set, expect to use the remaining fraction of the data + permutation_order = seeded_permutation(seed_value, data.shape[0]) + training_indices, test_indices, validation_indices = segmented_indices(permutation_order, data.shape[0], training_fraction, test_fraction) + + training_data = data[training_indices] + training_labels = labels[training_indices] + testing_data = data[test_indices] + testing_labels = labels[test_indices] + validation_data = data[validation_indices] + validation_labels = labels[validation_indices] + + return (training_data, training_labels), (testing_data, testing_labels), (validation_data, validation_labels) + +def add_gaussian_noise(data, noise_level, seed_value=-1): + if seed_value==-1: + np.random.seed() + else: + np.random.seed(seed_value) + return data + noise_level * np.random.randn(*data.shape) + +def confidence_bars(data_array, noise_length, confidence_interval=0.95): + # can also use confidence interval 0.9 or 0.99 if want slightly different bounds + error_lo = np.percentile(data_array, 100 * (1 - confidence_interval) / 2, axis=1) + error_hi = np.percentile(data_array, 100 * (1 - (1 - confidence_interval) / 2), axis=1) + mean = np.mean(data_array, axis=1) + assert len(error_lo) == len(mean) == len(error_hi) == noise_length + return error_lo, error_hi, mean + + +######### This function is very outdated, don't use it!! used to be called test_system use the ones below instead +######### +def test_system_old(noise_level, psf_name, model_name, seed_values, data, labels, training_fraction, testing_fraction, diffuser_region, phlat_region, psf, noise_type, rml_region): + # runs the model for the number of seeds given, returns the test accuracy for each seed + test_accuracy_list = [] + for seed_value in seed_values: + seed_value = int(seed_value) + tfk.backend.clear_session() + gc.collect() + tfk.utils.set_random_seed(seed_value) # set random seed out here too? + training, testing, validation = permute_data(data, labels, seed_value, training_fraction, testing_fraction) + x_train, y_train = training + x_test, y_test = testing + x_validation, y_validation = validation + + random_test_data, random_test_labels = generate_random_tiled_data(x_test, y_test, seed_value) + random_train_data, random_train_labels = generate_random_tiled_data(x_train, y_train, seed_value) + random_valid_data, random_valid_labels = generate_random_tiled_data(x_validation, y_validation, seed_value) + + if psf_name == 'uc': + test_data = random_test_data[:, 14:-13, 14:-13] + train_data = random_train_data[:, 14:-13, 14:-13] + valid_data = random_valid_data[:, 14:-13, 14:-13] + if psf_name == 'psf_4': + test_data = convolved_dataset(psf, random_test_data) + train_data = convolved_dataset(psf, random_train_data) + valid_data = convolved_dataset(psf, random_valid_data) + if psf_name == 'diffuser': + test_data = convolved_dataset(diffuser_region, random_test_data) + train_data = convolved_dataset(diffuser_region, random_train_data) + valid_data = convolved_dataset(diffuser_region, random_valid_data) + if psf_name == 'phlat': + test_data = convolved_dataset(phlat_region, random_test_data) + train_data = convolved_dataset(phlat_region, random_train_data) + valid_data = convolved_dataset(phlat_region, random_valid_data) + # 6/19/23 added RML option + if psf_name == 'rml': + test_data = convolved_dataset(rml_region, random_test_data) + train_data = convolved_dataset(rml_region, random_train_data) + valid_data = convolved_dataset(rml_region, random_valid_data) + + # address any tiny floating point negative values, which only occur in RML data + if np.any(test_data < 0): + #print('negative values in test data for {} psf'.format(psf_name)) + test_data[test_data < 0] = 0 + if np.any(train_data < 0): + #print('negative values in train data for {} psf'.format(psf_name)) + train_data[train_data < 0] = 0 + if np.any(valid_data < 0): + #print('negative values in valid data for {} psf'.format(psf_name)) + valid_data[valid_data < 0] = 0 + + + # additive gaussian noise, add noise after convolving, fixed 5/15/2023 + if noise_type == 'gaussian': + test_data = add_gaussian_noise(test_data, noise_level, seed_value) + train_data = add_gaussian_noise(train_data, noise_level, seed_value) + valid_data = add_gaussian_noise(valid_data, noise_level, seed_value) + if noise_type == 'poisson': + test_data = add_shot_noise(test_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) + train_data = add_shot_noise(train_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) + valid_data = add_shot_noise(valid_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) + + train_data, test_data, valid_data = tf_cast(train_data), tf_cast(test_data), tf_cast(valid_data) + random_train_labels, random_test_labels, random_valid_labels = tf_labels(random_train_labels), tf_labels(random_test_labels), tf_labels(random_valid_labels) + + if model_name == 'simple': + history, model, test_loss, test_acc = run_model_simple(train_data, random_train_labels, test_data, random_test_labels, valid_data, random_valid_labels, seed_value) + if model_name == 'cnn': + history, model, test_loss, test_acc = run_model_cnn(train_data, random_train_labels, test_data, random_test_labels, valid_data, random_valid_labels, seed_value) + test_accuracy_list.append(test_acc) + np.save('classification_results_rml_psf_619/test_accuracy_{}_noise_{}_{}_psf_{}_model.npy'.format(noise_level, noise_type, psf_name, model_name), test_accuracy_list) + + ###### CNN for 32x32 CIFAR10 images + # Originally written 11/14/2023, but then lost in a merge, recopied 1/14/2024 +def run_model_cnn_cifar(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=5): + # structure from https://www.kaggle.com/code/cdeotte/how-to-choose-cnn-architecture-mnist + # default architecture is 50 epochs and patience 5, but recently some need longer patience + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + model = tfk.models.Sequential() + model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu', input_shape=(65, 65, 1))) + model.add(tfk.layers.MaxPool2D()) + model.add(tfk.layers.Conv2D(128, kernel_size=5, padding='same', activation='relu')) + model.add(tfk.layers.MaxPool2D()) + #model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu')) + #model.add(tfk.layers.MaxPool2D(padding='same')) + model.add(tfk.layers.Flatten()) + + #model.add(tfk.layers.Dense(256, activation='relu')) + model.add(tfk.layers.Dense(128, activation='relu')) + model.add(tfk.layers.Dense(10, activation='softmax')) + + model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=patience, + restore_best_weights=True, verbose=1) + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + +def make_ttv_sets(data, labels, seed_value, training_fraction, testing_fraction): + training, testing, validation = permute_data(data, labels, seed_value, training_fraction, testing_fraction) + training_data, training_labels = training + testing_data, testing_labels = testing + validation_data, validation_labels = validation + training_data, testing_data, validation_data = tf_cast(training_data), tf_cast(testing_data), tf_cast(validation_data) + training_labels, testing_labels, validation_labels = tf_labels(training_labels), tf_labels(testing_labels), tf_labels(validation_labels) + return (training_data, training_labels), (testing_data, testing_labels), (validation_data, validation_labels) + +def run_network_cifar(data, labels, seed_value, training_fraction, testing_fraction, mode='cnn', max_epochs=50, patience=5): + # small modification to be able to run 32x32 image data + training, testing, validation = make_ttv_sets(data, labels, seed_value, training_fraction, testing_fraction) + if mode == 'cnn': + history, model, test_loss, test_acc = run_model_cnn_cifar(training[0], training[1], + testing[0], testing[1], + validation[0], validation[1], seed_value, max_epochs, patience) + elif mode == 'simple': + history, model, test_loss, test_acc = run_model_simple(training[0], training[1], + testing[0], testing[1], + validation[0], validation[1], seed_value) + elif mode == 'new_cnn': + history, model, test_loss, test_acc = current_testing_model(training[0], training[1], + testing[0], testing[1], + validation[0], validation[1], seed_value, max_epochs, patience) + elif mode == 'mom_cnn': + history, model, test_loss, test_acc = momentum_testing_model(training[0], training[1], + testing[0], testing[1], + validation[0], validation[1], seed_value, max_epochs, patience) + return history, model, test_loss, test_acc + + +def load_diffuser_psf(): + diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png') + diffuser_psf = diffuser_psf[:,:,1] + diffuser_resize = diffuser_psf[200:500, 250:550] + diffuser_resize = resize(diffuser_resize, (100, 100), anti_aliasing=True) #resize(diffuser_psf, (28, 28)) + diffuser_region = diffuser_resize[:28, :28] + diffuser_region /= np.sum(diffuser_region) + return diffuser_region + +def load_phlat_psf(): + phlat_psf = skimage.io.imread('psfs/phlat_psf.png') + phlat_psf = phlat_psf[900:2900, 1500:3500, 1] + phlat_psf = resize(phlat_psf, (200, 200), anti_aliasing=True) + phlat_region = phlat_psf[10:38, 20:48] + phlat_region /= np.sum(phlat_region) + return phlat_region + +def load_4_psf(): + psf = np.zeros((28, 28)) + psf[20,20] = 1 + psf[15, 10] = 1 + psf[5, 13] = 1 + psf[23, 6] = 1 + psf = scipy.ndimage.gaussian_filter(psf, sigma=1) + psf /= np.sum(psf) + return psf + +# 6/9/23 added rml option +def load_rml_psf(): + rml_psf = skimage.io.imread('psfs/psf_8holes.png') + rml_psf = rml_psf[1000:3000, 1500:3500] + rml_psf_resize = resize(rml_psf, (100, 100), anti_aliasing=True) + rml_psf_region = rml_psf_resize[40:100, :60] + rml_psf_region = resize(rml_psf_region, (28, 28), anti_aliasing=True) + rml_psf_region /= np.sum(rml_psf_region) + return rml_psf_region + +def load_rml_new_psf(): + rml_psf = skimage.io.imread('psfs/psf_8holes.png') + rml_psf = rml_psf[1000:3000, 1500:3500] + rml_psf_small = resize(rml_psf, (85, 85), anti_aliasing=True) + rml_psf_region = rml_psf_small[52:80, 10:38] + rml_psf_region /= np.sum(rml_psf_region) + return rml_psf_region + +def load_single_lens(): + one_lens = np.zeros((28, 28)) + one_lens[14, 14] = 1 + one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) + one_lens /= np.sum(one_lens) + return one_lens + +def load_two_lens(): + two_lens = np.zeros((28, 28)) + two_lens[10, 10] = 1 + two_lens[20, 20] = 1 + two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) + two_lens /= np.sum(two_lens) + return two_lens + +def load_three_lens(): + three_lens = np.zeros((28, 28)) + three_lens[8, 12] = 1 + three_lens[16, 20] = 1 + three_lens[20, 7] = 1 + three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) + three_lens /= np.sum(three_lens) + return three_lens + + +def load_single_lens_32(): + one_lens = np.zeros((32, 32)) + one_lens[16, 16] = 1 + one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) + one_lens /= np.sum(one_lens) + return one_lens + +def load_two_lens_32(): + two_lens = np.zeros((32, 32)) + two_lens[10, 10] = 1 + two_lens[21, 21] = 1 + two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) + two_lens /= np.sum(two_lens) + return two_lens + +def load_three_lens_32(): + three_lens = np.zeros((32, 32)) + three_lens[9, 12] = 1 + three_lens[17, 22] = 1 + three_lens[24, 8] = 1 + three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) + three_lens /= np.sum(three_lens) + return three_lens + +def load_four_lens_32(): + psf = np.zeros((32, 32)) + psf[22, 22] = 1 + psf[15, 10] = 1 + psf[5, 12] = 1 + psf[28, 8] = 1 + psf = scipy.ndimage.gaussian_filter(psf, sigma=1) # note that this one is sigma 1, for mnist it's sigma 0.8 + psf /= np.sum(psf) + return psf + +def load_diffuser_32(): + diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png') + diffuser_psf = diffuser_psf[:,:,1] + diffuser_resize = diffuser_psf[200:500, 250:550] + diffuser_resize = resize(diffuser_resize, (100, 100), anti_aliasing=True) #resize(diffuser_psf, (28, 28)) + diffuser_region = diffuser_resize[:32, :32] + diffuser_region /= np.sum(diffuser_region) + return diffuser_region + + + +### 10/15/2023: Make new versions of the model functions that train with Datasets - first attempt failed + +# lenses with centralized positions for use in task-specific estimations +def load_single_lens_uniform(size=32): + one_lens = np.zeros((size, size)) + one_lens[16, 16] = 1 + one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) + one_lens /= np.sum(one_lens) + return one_lens + +def load_two_lens_uniform(size=32): + two_lens = np.zeros((size, size)) + two_lens[16, 16] = 1 + two_lens[7, 9] = 1 + two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) + two_lens /= np.sum(two_lens) + return two_lens + +def load_three_lens_uniform(size=32): + three_lens = np.zeros((size, size)) + three_lens[16, 16] = 1 + three_lens[7, 9] = 1 + three_lens[23, 21] = 1 + three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) + three_lens /= np.sum(three_lens) + return three_lens + +def load_four_lens_uniform(size=32): + four_lens = np.zeros((size, size)) + four_lens[16, 16] = 1 + four_lens[7, 9] = 1 + four_lens[23, 21] = 1 + four_lens[8, 24] = 1 + four_lens = scipy.ndimage.gaussian_filter(four_lens, sigma=0.8) + four_lens /= np.sum(four_lens) + return four_lens +def load_five_lens_uniform(size=32): + five_lens = np.zeros((size, size)) + five_lens[16, 16] = 1 + five_lens[7, 9] = 1 + five_lens[23, 21] = 1 + five_lens[8, 24] = 1 + five_lens[21, 5] = 1 + five_lens = scipy.ndimage.gaussian_filter(five_lens, sigma=0.8) + five_lens /= np.sum(five_lens) + return five_lens + + + +## 01/24/2024 new CNN that's slightly deeper +def current_testing_model(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=20): + # structure from https://www.kaggle.com/code/amyjang/tensorflow-cifar10-cnn-tutorial + + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + + model = tf.keras.models.Sequential([ + tf.keras.layers.Conv2D(32, kernel_size=5, padding='same', input_shape=(65, 65, 1), activation='relu'), + tf.keras.layers.Conv2D(32, kernel_size=5, activation='relu'), + tf.keras.layers.MaxPooling2D(), + tf.keras.layers.Dropout(0.25), + + tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'), + tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'), + tf.keras.layers.MaxPooling2D(), + tf.keras.layers.Dropout(0.25), + + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(512, activation='relu'), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax'), + ]) + + model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=patience, + restore_best_weights=True, verbose=1) + print(model.optimizer.get_config()) + + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + + + + +## 01/24/2024 new CNN that's slightly deeper +def momentum_testing_model(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=20): + # structure from https://www.kaggle.com/code/amyjang/tensorflow-cifar10-cnn-tutorial + # includes nesterov momentum feature, rather than regular momentum + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + + model = tf.keras.models.Sequential([ + tf.keras.layers.Conv2D(32, kernel_size=5, padding='same', input_shape=(65, 65, 1), activation='relu'), + tf.keras.layers.Conv2D(32, kernel_size=5, activation='relu'), + tf.keras.layers.MaxPooling2D(), + tf.keras.layers.Dropout(0.25), + + tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'), + tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'), + tf.keras.layers.MaxPooling2D(), + tf.keras.layers.Dropout(0.25), + + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(512, activation='relu'), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax'), + ]) + + model.compile(optimizer=SGD(momentum=0.9, nesterov=True), loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=patience, + restore_best_weights=True, verbose=1) + + print(model.optimizer.get_config()) + + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + + +# bootstrapping function +def compute_bootstraps(mses, psnrs, ssims, test_set_length, num_bootstraps=100): + bootstrap_mses = [] + bootstrap_psnrs = [] + bootstrap_ssims = [] + for bootstrap_idx in tqdm(range(num_bootstraps), desc='Bootstrapping to compute confidence interval'): + # select indices for sampling + bootstrap_indices = np.random.choice(test_set_length, test_set_length, replace=True) + # take the metric values at those indices + bootstrap_selected_mses = mses[bootstrap_indices] + bootstrap_selected_psnrs = psnrs[bootstrap_indices] + bootstrap_selected_ssims = ssims[bootstrap_indices] + # accumulate the mean of the selected metric values + bootstrap_mses.append(np.mean(bootstrap_selected_mses)) + bootstrap_psnrs.append(np.mean(bootstrap_selected_psnrs)) + bootstrap_ssims.append(np.mean(bootstrap_selected_ssims)) + bootstrap_mses = np.array(bootstrap_mses) + bootstrap_psnrs = np.array(bootstrap_psnrs) + bootstrap_ssims = np.array(bootstrap_ssims) + return bootstrap_mses, bootstrap_psnrs, bootstrap_ssims + +def compute_confidence_interval(list_of_items, confidence_interval=0.95): + # use this one, final version + assert confidence_interval > 0 and confidence_interval < 1 + mean_value = np.mean(list_of_items) + lower_bound = np.percentile(list_of_items, 50 * (1 - confidence_interval)) + upper_bound = np.percentile(list_of_items, 50 * (1 + confidence_interval)) + return mean_value, lower_bound, upper_bound +