diff --git a/lensless_imager/01_09_2024_mi_vs_classification_plots_updated_mi_cifar10.ipynb b/lensless_imager/01_09_2024_mi_vs_classification_plots_updated_mi_cifar10.ipynb index 9596c5e..5ef078d 100644 --- a/lensless_imager/01_09_2024_mi_vs_classification_plots_updated_mi_cifar10.ipynb +++ b/lensless_imager/01_09_2024_mi_vs_classification_plots_updated_mi_cifar10.ipynb @@ -30,7 +30,7 @@ "import sys\n", "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments')\n", - "from leyla_fns import *\n", + "from lensless_helpers import *\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", diff --git a/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb b/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb deleted file mode 100644 index b78d352..0000000 --- a/lensless_imager/01_17_2024_pixelcnn_cifar10_all_lenses.ipynb +++ /dev/null @@ -1,175 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "2f0168a5", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import os\n", - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "import sys\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", - "sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments/')\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '2'\n", - "from encoding_information.gpu_utils import limit_gpu_memory_growth\n", - "limit_gpu_memory_growth()\n", - "\n", - "# import tensorflow_datasets as tfds # TFDS for MNIST #TODO INSTALL AGAIN LATER\n", - "#import tensorflow as tf # TensorFlow operations\n", - "\n", - "\n", - "\n", - "# from image_distribution_models import PixelCNN\n", - "\n", - "from cleanplots import *\n", - "import jax.numpy as np\n", - "from jax.scipy.special import logsumexp\n", - "import numpy as onp\n", - "\n", - "from leyla_fns import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34552381", - "metadata": {}, - "outputs": [], - "source": [ - "from encoding_information.image_utils import add_noise, extract_patches\n", - "from encoding_information.models.gaussian_process import StationaryGaussianProcess\n", - "from encoding_information.models.pixel_cnn import PixelCNN\n", - "from encoding_information.information_estimation import estimate_mutual_information" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Sweep Photon Count and Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# load the PSFs\n", - "\n", - "one_psf = load_single_lens_uniform(32)\n", - "two_psf = load_two_lens_uniform(32)\n", - "three_psf = load_three_lens_uniform(32)\n", - "four_psf = load_four_lens_uniform(32)\n", - "five_psf = load_five_lens_uniform(32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# set seed values for reproducibility\n", - "seed_values_full = np.arange(1, 5)\n", - "\n", - "# set photon properties \n", - "bias = 10 # in photons\n", - "mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300]\n", - "\n", - "# set eligible psfs\n", - "\n", - "# psf_patterns = [None, one_psf, four_psf, diffuser_psf]\n", - "# psf_names = ['uc', 'one', 'four', 'diffuser']\n", - "psf_patterns = [one_psf, two_psf, three_psf, four_psf, five_psf]\n", - "psf_names = ['one', 'two', 'three', 'four', 'five']\n", - "# MI estimator parameters \n", - "patch_size = 32\n", - "num_patches = 10000\n", - "bs = 500\n", - "max_epochs = 50" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for photon_count in mean_photon_count_list:\n", - " for index, psf_pattern in enumerate(psf_patterns):\n", - " gaussian_mi_estimates = []\n", - " pixelcnn_mi_estimates = []\n", - " print('Mean photon count: {}, PSF: {}'.format(photon_count, psf_names[index]))\n", - " for seed_value in seed_values_full:\n", - " # load dataset\n", - " (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()\n", - " data = onp.concatenate((x_train, x_test), axis=0) # make one big glob of data\n", - " data = data.astype(np.float32)\n", - " data /= onp.mean(data)\n", - " data *= photon_count # convert to photons with mean photon_count\n", - " labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. \n", - " # for CIFAR 100, need to convert images to grayscale\n", - " if len(data.shape) == 4:\n", - " data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale\n", - " data = data.squeeze()\n", - " # make tiled data\n", - " random_data, random_labels = generate_random_tiled_data(data, labels, seed_value)\n", - " \n", - " if psf_pattern is None:\n", - " start_idx = data.shape[-1] // 2\n", - " end_idx = data.shape[-1] // 2 - 1 \n", - " psf_data = random_data[:, start_idx:-end_idx, start_idx:-end_idx]\n", - " else:\n", - " psf_data = convolved_dataset(psf_pattern, random_data)\n", - " # add bias to data \n", - " psf_data += bias\n", - " # make patches and add noise\n", - " psf_data_patch = extract_patches(psf_data, patch_size=patch_size, num_patches=num_patches, seed=seed_value)\n", - " psf_data_shot_patch = add_noise(psf_data_patch, seed=seed_value, batch_size=bs)\n", - " # compute gaussian MI estimate, use comparison clean images\n", - " mi_gaussian_psf = estimate_mutual_information(psf_data_shot_patch, clean_images=psf_data_patch, entropy_model='gaussian',\n", - " max_epochs=max_epochs, verbose=True)\n", - " # compute PixelCNN MI estimate, use comparison clean images\n", - " mi_pixelcnn_psf = estimate_mutual_information(psf_data_shot_patch, clean_images=psf_data_patch, entropy_model='pixelcnn', num_val_samples=1000,\n", - " max_epochs=max_epochs, do_lr_decay=True, verbose=True)\n", - " gaussian_mi_estimates.append(mi_gaussian_psf)\n", - " pixelcnn_mi_estimates.append(mi_pixelcnn_psf)\n", - " #np.save('cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(pixelcnn_mi_estimates))\n", - " #np.save('cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(gaussian_mi_estimates))\n", - " # save the results once the seeds are done, file includes photon count and psf name\n", - " #np.save('cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(pixelcnn_mi_estimates))\n", - " #np.save('cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(gaussian_mi_estimates))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "phenotypes", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb b/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb deleted file mode 100644 index 15c463a..0000000 --- a/lensless_imager/01_26_2024_sweep_deconvolution_procedure_per_lens.ipynb +++ /dev/null @@ -1,335 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Sweeping Wiener Deconvolution, 01/24/2024\n", - "\n", - "When you randomly tile, you can make the problem much harder for deconvolution. Info is getting pushed out of the FOV and info is getting pulled into the FOV without knowing where it came from. Cropped convolution ends up being a compressive sensing problem. Instead, doing the reconstruction on the padded FOV including the center 32x32 region with a black border. \n", - "\n", - "There is no bias in this system. However, poisson noise is being added at each photon count." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import os\n", - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "import sys\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", - "sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments/')\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", - "from encoding_information.gpu_utils import limit_gpu_memory_growth\n", - "limit_gpu_memory_growth()\n", - "\n", - "# from image_distribution_models import PixelCNN\n", - "\n", - "from cleanplots import *\n", - "#import jax.numpy as np\n", - "from jax.scipy.special import logsumexp\n", - "import numpy as np\n", - "\n", - "from leyla_fns import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from encoding_information.image_utils import add_noise, extract_patches\n", - "from encoding_information.models.gaussian_process import StationaryGaussianProcess\n", - "from encoding_information.models.pixel_cnn import PixelCNN\n", - "from encoding_information.information_estimation import estimate_mutual_information" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from skimage.restoration import wiener, unsupervised_wiener, richardson_lucy\n", - "import skimage.metrics as skm\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# load the PSFs\n", - "\n", - "diffuser_psf = load_diffuser_32()\n", - "one_psf = load_single_lens_uniform(32)\n", - "two_psf = load_two_lens_uniform(32)\n", - "three_psf = load_three_lens_uniform(32)\n", - "four_psf = load_four_lens_uniform(32)\n", - "five_psf = load_five_lens_uniform(32)\n", - "aperture_psf = np.copy(diffuser_psf)\n", - "aperture_psf[:5] = 0\n", - "aperture_psf[-5:] = 0\n", - "aperture_psf[:,:5] = 0\n", - "aperture_psf[:,-5:] = 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def compute_skm_metrics(gt, recon):\n", - " # takes in already normalized gt\n", - " mse = skm.mean_squared_error(gt, recon)\n", - " psnr = skm.peak_signal_noise_ratio(gt, recon)\n", - " nmse = skm.normalized_root_mse(gt, recon)\n", - " ssim = skm.structural_similarity(gt, recon, data_range=1)\n", - " return mse, psnr, nmse, ssim" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# set seed values for reproducibility\n", - "seed_values_full = np.arange(1, 4)\n", - "\n", - "# set photon properties \n", - "mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300]\n", - "\n", - "# set eligible psfs\n", - "\n", - "psf_patterns = [None, one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf]\n", - "psf_names = ['uc', 'one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture']\n", - "\n", - "# MI estimator parameters \n", - "patch_size = 32\n", - "num_patches = 10000\n", - "bs = 500\n", - "max_epochs = 50" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "psf_patterns_use = [one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf]\n", - "psf_names_use = ['one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture']\n", - "\n", - "mean_photon_count_list = [300, 250, 200, 150, 100, 80, 60, 40, 20]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for photon_count in mean_photon_count_list:\n", - " for psf_idx, psf_use in enumerate(psf_patterns_use):\n", - " print('PSF: {}, Photon Count: {}'.format(psf_names_use[psf_idx], photon_count))\n", - " seed_value = 1\n", - " # make the data and scale by the photon count \n", - " (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()\n", - " data = np.concatenate((x_train, x_test), axis=0) # make one big glob of data\n", - " data = data.astype(np.float64)\n", - " data /= np.mean(data)\n", - " data *= photon_count # convert to photons with mean value photon_count\n", - " max_val = np.max(data)\n", - " labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. \n", - " # for CIFAR 100, need to convert images to grayscale\n", - " if len(data.shape) == 4:\n", - " data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale\n", - " data = data.squeeze()\n", - " # zero pad data to be 96 x 96\n", - " data_padded = np.zeros((data.shape[0], 96, 96))\n", - " data_padded[:, 32:64, 32:64] = data\n", - "\n", - " convolved_data = convolved_dataset(psf_use, data_padded)\n", - " convolved_data_noise = add_noise(convolved_data)\n", - " # output of this noisy data is a jax array of float32, correct to regular numpy and float64\n", - " convolved_data_noise = np.array(convolved_data_noise).astype(np.float64)\n", - "\n", - " mse_psf = []\n", - " psnr_psf = []\n", - " for i in range(convolved_data_noise.shape[0]):\n", - " recon, _ = unsupervised_wiener(convolved_data_noise[i] / max_val, psf_use)\n", - " recon = recon[17:49, 17:49] #this is the crop window to look at\n", - " mse = skm.mean_squared_error(data[i] / max_val, recon)\n", - " psnr = skm.peak_signal_noise_ratio(data[i] / max_val, recon)\n", - " mse_psf.append(mse)\n", - " psnr_psf.append(psnr)\n", - " print('PSF: {}, Mean MSE: {}, Mean PSNR: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf)))\n", - " #np.save('unsupervised_wiener_deconvolution/recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf])\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Repeating Wiener Deconvolution including fixed seed=10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "psf_patterns_use = [one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf]\n", - "psf_names_use = ['one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture']\n", - "\n", - "mean_photon_count_list = [300, 250, 200, 150, 100, 80, 60, 40, 20]\n", - "\n", - "seed_value = 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for photon_count in mean_photon_count_list:\n", - " for psf_idx, psf_use in enumerate(psf_patterns_use):\n", - " print('PSF: {}, Photon Count: {}'.format(psf_names_use[psf_idx], photon_count))\n", - " # make the data and scale by the photon count \n", - " (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()\n", - " data = np.concatenate((x_train, x_test), axis=0) # make one big glob of data\n", - " data = data.astype(np.float64)\n", - " data /= np.mean(data)\n", - " data *= photon_count # convert to photons with mean value photon_count\n", - " max_val = np.max(data)\n", - " labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. \n", - " # for CIFAR 100, need to convert images to grayscale\n", - " if len(data.shape) == 4:\n", - " data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale\n", - " data = data.squeeze()\n", - " # zero pad data to be 96 x 96\n", - " data_padded = np.zeros((data.shape[0], 96, 96))\n", - " data_padded[:, 32:64, 32:64] = data\n", - "\n", - " convolved_data = convolved_dataset(psf_use, data_padded)\n", - " convolved_data_noise = add_noise(convolved_data, seed=seed_value)\n", - " # output of this noisy data is a jax array of float32, correct to regular numpy and float64\n", - " convolved_data_noise = np.array(convolved_data_noise).astype(np.float64)\n", - "\n", - " mse_psf = []\n", - " psnr_psf = []\n", - " for i in range(convolved_data_noise.shape[0]):\n", - " recon, _ = unsupervised_wiener(convolved_data_noise[i] / max_val, psf_use)\n", - " recon = recon[17:49, 17:49] #this is the crop window to look at\n", - " mse = skm.mean_squared_error(data[i] / max_val, recon)\n", - " psnr = skm.peak_signal_noise_ratio(data[i] / max_val, recon)\n", - " mse_psf.append(mse)\n", - " psnr_psf.append(psnr)\n", - " print('PSF: {}, Mean MSE: {}, Mean PSNR: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf)))\n", - " #np.save('unsupervised_wiener_deconvolution_fixed_seed/recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Archive: Detour to figure out jax types" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "type(convolved_data_noise)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "convolved_data = convolved_dataset(psf_use, data_padded)\n", - "convolved_data_noise = add_noise(convolved_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(type(convolved_data), convolved_data.dtype)\n", - "print(type(convolved_data_noise), convolved_data_noise.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "np" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "convolved_data_noise_test = np.array(convolved_data_noise).astype(np.float64)\n", - "print(type(convolved_data_noise_test))\n", - "recon, _ = unsupervised_wiener(convolved_data_noise_test[0] / max_val, psf_use) #TODO change to convolved_data_noise\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "convolved_data_noise_test = convolved_data_noise.astype(np.float64)\n", - "recon, _ = unsupervised_wiener(convolved_data_noise_test[0] / max_val, psf_use)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "info_jax_flax_23", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/lensless_imager/02_12_2024_make_lenses_3D_edges.ipynb b/lensless_imager/02_12_2024_make_lenses_3D_edges.ipynb index 899af57..bfbd03a 100644 --- a/lensless_imager/02_12_2024_make_lenses_3D_edges.ipynb +++ b/lensless_imager/02_12_2024_make_lenses_3D_edges.ipynb @@ -17,7 +17,7 @@ "import numpy as np \n", "import matplotlib.pyplot as plt\n", "import plotly\n", - "from leyla_fns import *\n", + "from lensless_helpers import *\n", "from cleanplots import *" ] }, diff --git a/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb b/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb deleted file mode 100644 index 9e57d71..0000000 --- a/lensless_imager/02_13_2024_pixelcnn_cifar10_extra_photon_counts.ipynb +++ /dev/null @@ -1,185 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "2f0168a5", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import os\n", - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "import sys\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", - "sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments/')\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '2'\n", - "from encoding_information.gpu_utils import limit_gpu_memory_growth\n", - "limit_gpu_memory_growth()\n", - "\n", - "# import tensorflow_datasets as tfds # TFDS for MNIST #TODO INSTALL AGAIN LATER\n", - "#import tensorflow as tf # TensorFlow operations\n", - "\n", - "\n", - "\n", - "# from image_distribution_models import PixelCNN\n", - "\n", - "from cleanplots import *\n", - "import jax.numpy as np\n", - "from jax.scipy.special import logsumexp\n", - "import numpy as onp\n", - "\n", - "from leyla_fns import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34552381", - "metadata": {}, - "outputs": [], - "source": [ - "from encoding_information.image_utils import add_noise, extract_patches\n", - "from encoding_information.models.gaussian_process import StationaryGaussianProcess\n", - "from encoding_information.models.pixel_cnn import PixelCNN\n", - "from encoding_information.information_estimation import estimate_mutual_information" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Sweep Photon Count and Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "48df0226", - "metadata": {}, - "outputs": [], - "source": [ - "diffuser_psf = load_diffuser_32()\n", - "one_psf = load_single_lens_uniform(32)\n", - "two_psf = load_two_lens_uniform(32)\n", - "three_psf = load_three_lens_uniform(32)\n", - "four_psf = load_four_lens_uniform(32)\n", - "five_psf = load_five_lens_uniform(32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# set seed values for reproducibility\n", - "seed_values_full = np.arange(1, 5)\n", - "\n", - "# set photon properties \n", - "bias = 10 # in photons\n", - "#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300]\n", - "mean_photon_count_list = [160, 320]\n", - "\n", - "# set eligible psfs\n", - "\n", - "# psf_patterns = [None, one_psf, four_psf, diffuser_psf]\n", - "# psf_names = ['uc', 'one', 'four', 'diffuser']\n", - "psf_patterns = [one_psf, four_psf, diffuser_psf]\n", - "psf_names = ['one', 'four', 'diffuser']\n", - "\n", - "# MI estimator parameters \n", - "patch_size = 32\n", - "num_patches = 10000\n", - "bs = 500\n", - "max_epochs = 50" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for photon_count in mean_photon_count_list:\n", - " for index, psf_pattern in enumerate(psf_patterns):\n", - " gaussian_mi_estimates = []\n", - " pixelcnn_mi_estimates = []\n", - " print('Mean photon count: {}, PSF: {}'.format(photon_count, psf_names[index]))\n", - " for seed_value in seed_values_full:\n", - " # load dataset\n", - " (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()\n", - " data = onp.concatenate((x_train, x_test), axis=0) # make one big glob of data\n", - " data = data.astype(np.float32)\n", - " data /= onp.mean(data)\n", - " data *= photon_count # convert to photons with mean value of photon_count\n", - " labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. \n", - " # for CIFAR 100, need to convert images to grayscale\n", - " if len(data.shape) == 4:\n", - " data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale\n", - " data = data.squeeze()\n", - " # make tiled data\n", - " random_data, random_labels = generate_random_tiled_data(data, labels, seed_value)\n", - " \n", - " if psf_pattern is None:\n", - " start_idx = data.shape[-1] // 2\n", - " end_idx = data.shape[-1] // 2 - 1 \n", - " psf_data = random_data[:, start_idx:-end_idx, start_idx:-end_idx]\n", - " else:\n", - " psf_data = convolved_dataset(psf_pattern, random_data)\n", - " # add small bias to data \n", - " psf_data += bias\n", - " # make patches and add noise\n", - " psf_data_patch = extract_patches(psf_data, patch_size=patch_size, num_patches=num_patches, seed=seed_value)\n", - " psf_data_shot_patch = add_noise(psf_data_patch, seed=seed_value, batch_size=bs)\n", - " # compute gaussian MI estimate, use comparison clean images\n", - " mi_gaussian_psf = estimate_mutual_information(psf_data_shot_patch, clean_images=psf_data_patch, entropy_model='gaussian',\n", - " max_epochs=max_epochs, verbose=True)\n", - " # compute PixelCNN MI estimate, use comparison clean images\n", - " mi_pixelcnn_psf = estimate_mutual_information(psf_data_shot_patch, clean_images=psf_data_patch, entropy_model='pixelcnn', num_val_samples=1000,\n", - " max_epochs=max_epochs, do_lr_decay=True, verbose=True)\n", - " gaussian_mi_estimates.append(mi_gaussian_psf)\n", - " pixelcnn_mi_estimates.append(mi_pixelcnn_psf)\n", - " #np.save('cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(pixelcnn_mi_estimates))\n", - " #np.save('cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(gaussian_mi_estimates))\n", - " # save the results once the seeds are done, file includes photon count and psf name\n", - " #np.save('cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(pixelcnn_mi_estimates))\n", - " #np.save('cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(gaussian_mi_estimates))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f667a120", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "phenotypes", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py b/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py deleted file mode 100644 index 3da899f..0000000 --- a/lensless_imager/02_13_2024_sweep_wiener_deconvolution_per_lens_extra_photon_counts.py +++ /dev/null @@ -1,195 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.15.2 -# kernelspec: -# display_name: info_jax_flax_23 -# language: python -# name: python3 -# --- - -# %% [markdown] -# ## Sweeping non-unsupervised Wiener Deconvolution with hand-tuned parameter, 01/29/2024 -# -# Using a fixed seed (10) for consistency. - -# %% -# %load_ext autoreload -# %autoreload 2 - -import os -from jax import config -config.update("jax_enable_x64", True) -import sys -sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/') -sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments/') - -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["CUDA_VISIBLE_DEVICES"] = '3' -from encoding_information.gpu_utils import limit_gpu_memory_growth -limit_gpu_memory_growth() - -# from image_distribution_models import PixelCNN - -from cleanplots import * -#import jax.numpy as np -from jax.scipy.special import logsumexp -import numpy as np - -from leyla_fns import * - -# %% -from encoding_information.image_utils import add_noise, extract_patches -from encoding_information.models.gaussian_process import StationaryGaussianProcess -from encoding_information.models.pixel_cnn import PixelCNN -from encoding_information.information_estimation import estimate_mutual_information - -# %% -from skimage.restoration import wiener, unsupervised_wiener, richardson_lucy -import skimage.metrics as skm - - -# %% -# load the PSFs - -diffuser_psf = load_diffuser_32() -one_psf = load_single_lens_uniform(32) -two_psf = load_two_lens_uniform(32) -three_psf = load_three_lens_uniform(32) -four_psf = load_four_lens_uniform(32) -five_psf = load_five_lens_uniform(32) -aperture_psf = np.copy(diffuser_psf) -aperture_psf[:5] = 0 -aperture_psf[-5:] = 0 -aperture_psf[:,:5] = 0 -aperture_psf[:,-5:] = 0 - - -# %% -def compute_skm_metrics(gt, recon): - # takes in already normalized gt - mse = skm.mean_squared_error(gt, recon) - psnr = skm.peak_signal_noise_ratio(gt, recon) - nmse = skm.normalized_root_mse(gt, recon) - ssim = skm.structural_similarity(gt, recon, data_range=1) - return mse, psnr, nmse, ssim - - -# %% -# set seed values for reproducibility -seed_values_full = np.arange(1, 4) - -# set photon properties -#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300] -mean_photon_count_list = [160, 320] - -# set eligible psfs - -psf_patterns = [None, one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf] -psf_names = ['uc', 'one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture'] - -# MI estimator parameters -patch_size = 32 -num_patches = 10000 -bs = 500 -max_epochs = 50 - -# %% -reg_value_best = 10**-2 -print(reg_value_best) - -# %% [markdown] -# ## Regular Wiener Deconvolution including fixed seed 10 - -# %% -psf_patterns_use = [one_psf, two_psf, three_psf, four_psf, five_psf, diffuser_psf, aperture_psf] -psf_names_use = ['one', 'two', 'three', 'four', 'five', 'diffuser', 'aperture'] - -#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300] -mean_photon_count_list = [160, 320] - -seed_value = 10 - - -for photon_count in mean_photon_count_list: - for psf_idx, psf_use in enumerate(psf_patterns_use): - print('PSF: {}, Photon Count: {}'.format(psf_names_use[psf_idx], photon_count)) - # make the data and scale by the photon count - (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() - data = np.concatenate((x_train, x_test), axis=0) # make one big glob of data - data = data.astype(np.float64) - data /= np.mean(data) - data *= photon_count # convert to photons with mean value of photon_count - max_val = np.max(data) - labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. - # for CIFAR 100, need to convert images to grayscale - if len(data.shape) == 4: - data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale - data = data.squeeze() - # zero pad data to be 96 x 96 - data_padded = np.zeros((data.shape[0], 96, 96)) - data_padded[:, 32:64, 32:64] = data - - convolved_data = convolved_dataset(psf_use, data_padded) - convolved_data_noise = add_noise(convolved_data, seed=seed_value) - # output of this noisy data is a jax array of float32, correct to regular numpy and float64 - convolved_data_noise = np.array(convolved_data_noise).astype(np.float64) - - mse_psf = [] - psnr_psf = [] - for i in range(convolved_data_noise.shape[0]): - recon, _ = unsupervised_wiener(convolved_data_noise[i] / max_val, psf_use) - recon = recon[17:49, 17:49] #this is the crop window to look at - mse = skm.mean_squared_error(data[i] / max_val, recon) - psnr = skm.peak_signal_noise_ratio(data[i] / max_val, recon) - mse_psf.append(mse) - psnr_psf.append(psnr) - print('PSF: {}, Mean MSE: {}, Mean PSNR: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf))) - #np.save('unsupervised_wiener_deconvolution_fixed_seed/recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf]) - - - -# %% -for photon_count in mean_photon_count_list: - for psf_idx, psf_use in enumerate(psf_patterns_use): - print('PSF: {}, Photon Count: {}'.format(psf_names_use[psf_idx], photon_count)) - # make the data and scale by the photon count - (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() - data = np.concatenate((x_train, x_test), axis=0) # make one big glob of data - data = data.astype(np.float64) - data /= np.mean(data) - data *= photon_count # convert to photons with mean value of photon_count - max_val = np.max(data) - labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. - # for CIFAR 100, need to convert images to grayscale - if len(data.shape) == 4: - data = tf.image.rgb_to_grayscale(data).numpy() # convert to grayscale - data = data.squeeze() - # zero pad data to be 96 x 96 - data_padded = np.zeros((data.shape[0], 96, 96)) - data_padded[:, 32:64, 32:64] = data - - convolved_data = convolved_dataset(psf_use, data_padded) - convolved_data_noise = add_noise(convolved_data, seed=seed_value) - # output of this noisy data is a jax array of float32, correct to regular numpy and float64 - convolved_data_noise = np.array(convolved_data_noise).astype(np.float64) - - mse_psf = [] - psnr_psf = [] - for i in range(convolved_data_noise.shape[0]): - recon = wiener(convolved_data_noise[i] / max_val, psf_use, reg_value_best) - recon = recon[17:49, 17:49] #this is the crop window to look at - mse = skm.mean_squared_error(data[i] / max_val, recon) - psnr = skm.peak_signal_noise_ratio(data[i] / max_val, recon) - mse_psf.append(mse) - psnr_psf.append(psnr) - print('PSF: {}, Mean MSE: {}, Mean PSNR: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf))) - #np.save('regular_wiener_deconvolution_fixed_seed/recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf]) - -# %% - - diff --git a/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb b/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb deleted file mode 100644 index 166c75e..0000000 --- a/lensless_imager/02_14_2024_mi_vs_deconvolution_plots_cifar10_all_systems_log_photons.ipynb +++ /dev/null @@ -1,708 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Make the plot for MI and deconvolution relationship, 01/29/2024" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload \n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "import numpy as np\n", - "\n", - "import sys\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')\n", - "sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments')\n", - "from leyla_fns import *\n", - "import os\n", - "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", - "print(os.environ.get('PYTHONPATH'))\n", - "from cleanplots import * " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "seed_value = 10\n", - "\n", - "# set photon properties \n", - "bias = 10 # in photons\n", - "mean_photon_count_list = [20, 40, 60, 80, 100, 150, 160, 200, 250, 300, 320]\n", - "max_photon_count = mean_photon_count_list[-1]\n", - "\n", - "# set eligible psfs\n", - "\n", - "psf_names = ['one', 'four', 'diffuser'] # later make it all of them, but haven't gotten diffuser and aperture yet\n", - "\n", - "# MI estimator parameters \n", - "patch_size = 32\n", - "num_patches = 10000\n", - "bs = 500" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load MI data and make plots of it\n", - "Using updated MI data from 01/17/2024 which is run for the uniform data\n", - "\n", - "The plot has essentially invisible error bars. No more outlier issues" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from cleanplots import *\n", - "get_color_cycle()[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mi_folder = ''" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Minimum plot with no error bars" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gaussian_mi_estimates_across_psfs = [] # only keeps the minimum values, no outliers\n", - "pixelcnn_mi_estimates_across_psfs = [] # only keeps the minimum values, no outliers\n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", - "for psf_name in psf_names:\n", - " gaussian_across_photons = [] \n", - " pixelcnn_across_photons = []\n", - " for photon_count in mean_photon_count_list:\n", - " #gaussian_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " pixelcnn_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " #gaussian_across_photons.append(gaussian_mi_estimate)\n", - " pixelcnn_across_photons.append(pixelcnn_mi_estimate)\n", - " assert pixelcnn_mi_estimate.shape[0] == 4\n", - " #gaussian_mins = np.min(gaussian_across_photons, axis=1)\n", - " pixelcnn_mins = np.min(pixelcnn_across_photons, axis=1)\n", - " ax.plot(mean_photon_count_list, gaussian_mins, '-', label='Gaussian {}'.format(psf_name))\n", - " ax.plot(mean_photon_count_list, pixelcnn_mins, '-', label='PixelCNN {}'.format(psf_name))\n", - " gaussian_mi_estimates_across_psfs.append(gaussian_mins) # only keep mean dataset for use\n", - " pixelcnn_mi_estimates_across_psfs.append(pixelcnn_mins) # only keep mean datas\n", - "plt.legend()\n", - "plt.title(\"Gaussian vs. PixelCNN MI Estimates Across Photon Count, CIFAR10, 4 Seeds, Minimums\")\n", - "plt.ylabel('Estimated Mutual Information')\n", - "plt.xlabel('Mean Photon Count')\n", - "\n", - "gaussian_mi_estimates_across_psfs = np.array(gaussian_mi_estimates_across_psfs)\n", - "pixelcnn_mi_estimates_across_psfs = np.array(pixelcnn_mi_estimates_across_psfs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "psf_names_verbose = ['One Lens', 'Two Lens', 'Three Lens', 'Four Lens', 'Five Lens', 'Diffuser', 'Aperture']\n", - "plt.figure(figsize=(6, 5))\n", - "ax = plt.axes()\n", - "for i, modality in enumerate(psf_names_verbose):\n", - " #plt.plot(mean_photon_count_list, gaussian_mi_estimates_across_psfs[i], label = '{} Gaussian'.format(modality), color = get_color_cycle()[i], linestyle='--')\n", - " plt.plot(mean_photon_count_list, pixelcnn_mi_estimates_across_psfs[i], label = '{}'.format(modality), color = get_color_cycle()[i-1]) # manual color correct\n", - "plt.legend()\n", - "plt.xlabel('Mean Photon Count')\n", - "plt.ylabel(\"Mutual Information (bits per pixel)\")\n", - "#plt.title('Estimated Mutual Information vs. Mean Photon Count, CIFAR10')\n", - "clear_spines(ax)\n", - "#plt.savefig('mi_vs_photon_count.pdf', bbox_inches='tight', transparent=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Mean plot with error bars included" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", - "for psf_name in psf_names:\n", - " gaussian_across_photons = [] \n", - " pixelcnn_across_photons = []\n", - " for photon_count in mean_photon_count_list:\n", - " #gaussian_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " pixelcnn_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " if np.max(pixelcnn_mi_estimate) / np.min(pixelcnn_mi_estimate) > 2:\n", - " pixelcnn_mi_estimate[pixelcnn_mi_estimate > 2 * np.min(pixelcnn_mi_estimate)] = np.min(pixelcnn_mi_estimate)\n", - " #gaussian_across_photons.append(gaussian_mi_estimate)\n", - " pixelcnn_across_photons.append(pixelcnn_mi_estimate)\n", - " #error_lo, error_hi, mean = confidence_bars(gaussian_across_photons, 9)\n", - " error_lo_2, error_hi_2, mean_2 = confidence_bars(pixelcnn_across_photons, 11)\n", - " #ax.plot(mean_photon_count_list, mean, '-', label='Gaussian {}'.format(psf_name))\n", - " ax.plot(mean_photon_count_list, mean_2, '-', label='PixelCNN {}'.format(psf_name))\n", - " #ax.fill_between(mean_photon_count_list, error_lo, error_hi, alpha=0.4)\n", - " ax.fill_between(mean_photon_count_list, error_lo_2, error_hi_2, alpha=0.4)\n", - "plt.legend()\n", - "plt.title(\"Gaussian vs. PixelCNN MI Estimates Across Photon Count, CIFAR10, 4 Seeds, Means, Outliers Removed\")\n", - "plt.ylabel('Estimated Mutual Information')\n", - "plt.xlabel('Mean Photon Count')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load deconvolution data and make plots of it\n", - "Use means" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "deconvolution_folder = 'unsupervised_wiener_deconvolution_fixed_seed/'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mse_across_psfs = [] #5 x 9 x 1 array, 5 psfs, 9 photon counts, one value on each \n", - "psnr_across_psfs = [] #5 x 9 x 1 array, 5 psfs, 9 photon counts, one value on each\n", - "mse_lists_across_psfs = []\n", - "psnr_lists_across_psfs = []\n", - "for psf_name in psf_names:\n", - " mse_across_photons = []\n", - " psnr_across_photons = []\n", - " mse_lists_across_photons = []\n", - " psnr_lists_across_photons = []\n", - " for photon_count in mean_photon_count_list:\n", - " mse_list, psnr_list = np.load(deconvolution_folder + 'recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", - " mse_list = np.array(mse_list)\n", - " psnr_list = np.array(psnr_list)\n", - " mean_mse = np.mean(mse_list)\n", - " mean_psnr = np.mean(psnr_list)\n", - " mse_across_photons.append(mean_mse)\n", - " psnr_across_photons.append(mean_psnr)\n", - " mse_lists_across_photons.append(mse_list)\n", - " psnr_lists_across_photons.append(psnr_list)\n", - " mse_across_psfs.append(mse_across_photons)\n", - " psnr_across_psfs.append(psnr_across_photons)\n", - " mse_lists_across_psfs.append(mse_lists_across_photons)\n", - " psnr_lists_across_psfs.append(psnr_lists_across_photons)\n", - "mse_across_psfs = np.array(mse_across_psfs)\n", - "psnr_across_psfs = np.array(psnr_across_psfs)\n", - "mse_lists_across_psfs = np.array(mse_lists_across_psfs)\n", - "psnr_lists_across_psfs = np.array(psnr_lists_across_psfs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for idx, psf_name in enumerate(psf_names):\n", - " plt.plot(mean_photon_count_list, mse_across_psfs[idx], label='{}'.format(psf_name))\n", - "plt.legend()\n", - "plt.title(\"Deconvolution MSE vs. Mean Photon Count, CIFAR10\")\n", - "plt.ylabel('Mean Squared Error')\n", - "plt.xlabel('Mean Photon Count')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Make figures, include classifier error bars" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def marker_for_psf(psf_name):\n", - " if psf_name =='one':\n", - " marker = 'o'\n", - " elif psf_name == 'four':\n", - " marker = 's' \n", - " elif psf_name == 'diffuser':\n", - " marker = '*'\n", - " elif psf_name == 'uc':\n", - " marker = 'x'\n", - " elif psf_name =='two':\n", - " marker = 'd'\n", - " elif psf_name == 'three':\n", - " marker = 'v'\n", - " elif psf_name == 'five':\n", - " marker = 'p'\n", - " elif psf_name == 'aperture':\n", - " marker = 'P'\n", - " return marker" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Choose a base colormap\n", - "base_colormap = plt.cm.get_cmap('inferno')\n", - "# Define the start and end points--used so that high values aren't too light against white background\n", - "start, end = 0, 0.88 # making end point 0.8\n", - "from matplotlib.colors import LinearSegmentedColormap\n", - "# Create a new colormap from the portion of the original colormap\n", - "colormap = LinearSegmentedColormap.from_list(\n", - " 'trunc({n},{a:.2f},{b:.2f})'.format(n=base_colormap.name, a=start, b=end),\n", - " base_colormap(np.linspace(start, end, 256))\n", - ")\n", - "\n", - "min_photons_per_pixel = min(mean_photon_count_list)\n", - "max_photons_per_pixel = max(mean_photon_count_list)\n", - "\n", - "min_log_photons = np.log(min_photons_per_pixel)\n", - "max_log_photons = np.log(max_photons_per_pixel)\n", - "\n", - "def color_for_photon_level(photons_per_pixel):\n", - " log_photons = np.log(photons_per_pixel)\n", - " return colormap((log_photons - min_log_photons) / (max_log_photons - min_log_photons) )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Update parameters in below block to display the things you want to display, then run the block after to make the figure" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "estimator_type = 1 # 0 for gaussian, 1 for pixelcnn\n", - "metric_type = 1\n", - "valid_psfs = [0, 1, 2]\n", - "valid_photon_counts = [20, 40, 80, 150, 160, 300, 320]\n", - "print(psf_names)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "confidence_level = 0.9 \n", - "# using min-valued MI estimates \n", - "mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]\n", - "deconv_estimate_lists = [mse_across_psfs, psnr_across_psfs]\n", - "metric_names = ['MSE', 'PSNR']\n", - "# classifier array is classifier_all_trials_across_psfs, 4x9x10 array. 4 psfs, 9 photon counts, 10 trials on each one \n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n", - "\n", - "mi_list_use = mi_estimate_lists[estimator_type] # use pixelcnn or gaussian, choose pixelcnn \n", - "deconv_list_use = deconv_estimate_lists[metric_type] # use mse or psnr, choose psnr\n", - "metric_name = metric_names[metric_type]\n", - "\n", - "for psf_idx, psf_name in enumerate(psf_names):\n", - " if psf_idx in valid_psfs:\n", - " mi_means_across_photons = [] # track mean MI values to make trendline \n", - " deconv_means_across_photons = [] # track mean MI values to make trendline\n", - " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", - " if photon_count in valid_photon_counts:\n", - " # load mean values and colors to plot \n", - " color = color_for_photon_level(photon_count)\n", - " mi_value = mi_list_use[psf_idx][photon_idx] # only use an MI value if the psf is valid, correctly indexed \n", - " deconv_value = deconv_list_use[psf_idx][photon_idx]\n", - " ax.scatter(mi_value, deconv_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", - " # add to lists to track later \n", - " mi_means_across_photons.append(mi_value)\n", - " deconv_means_across_photons.append(deconv_value)\n", - " mi_means_across_photons = np.array(mi_means_across_photons)\n", - " deconv_means_across_photons = np.array(deconv_means_across_photons)\n", - " ax.plot(mi_means_across_photons, deconv_means_across_photons, '--', color='grey', alpha=1, linewidth=2)\n", - "\n", - "ax.set_xlabel('Mutual Information (bits per pixel)')\n", - "ax.set_ylabel('Deconvolution Performance, {}'.format(metric_name))\n", - "clear_spines(ax)\n", - "\n", - "\n", - "# legend\n", - "# ax.scatter([], [], color='k', marker='x', label='No PSF')\n", - "ax.scatter([], [], color='k', marker='o', label='One Lens')\n", - "# ax.scatter([], [], color='k', marker='d', label='Two Lens')\n", - "# ax.scatter([], [], color='k', marker='v', label='Three Lens')\n", - "ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", - "# ax.scatter([], [], color='k', marker='p', label='Five Lens')\n", - "ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", - "# ax.scatter([], [], color='k', marker='P', label='Aperture')\n", - "\n", - "ax.legend(loc='lower right', frameon=True)\n", - "ax.set_xlim([0, None])\n", - "\n", - "\n", - "\n", - "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", - "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", - "sm.set_array([])\n", - "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", - "# set tick labels\n", - "cbar.ax.set_yticklabels(valid_photon_counts)\n", - "\n", - "\n", - "cbar.set_label('Mean Photon Count')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "estimator_type = 1 # 0 for gaussian, 1 for pixelcnn\n", - "metric_type = 0 # 0 for MSE, 1 for PSNR\n", - "valid_psfs = [0, 1, 2, 3, 4, 5, 6]\n", - "valid_photon_counts = [20, 40, 80, 150, 160, 300, 320]\n", - "print(psf_names)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "confidence_level = 0.9 \n", - "# using min-valued MI estimates \n", - "mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]\n", - "deconv_estimate_lists = [mse_across_psfs, psnr_across_psfs]\n", - "metric_names = ['MSE', 'PSNR']\n", - "# classifier array is classifier_all_trials_across_psfs, 4x9x10 array. 4 psfs, 9 photon counts, 10 trials on each one \n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n", - "\n", - "mi_list_use = mi_estimate_lists[estimator_type] # use pixelcnn or gaussian, choose pixelcnn \n", - "deconv_list_use = deconv_estimate_lists[metric_type] # use mse or psnr, choose psnr\n", - "metric_name = metric_names[metric_type]\n", - "\n", - "for psf_idx, psf_name in enumerate(psf_names):\n", - " if psf_idx in valid_psfs:\n", - " mi_means_across_photons = [] # track mean MI values to make trendline \n", - " deconv_means_across_photons = [] # track mean MI values to make trendline\n", - " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", - " if photon_count in valid_photon_counts:\n", - " # load mean values and colors to plot \n", - " color = color_for_photon_level(photon_count)\n", - " mi_value = mi_list_use[psf_idx][photon_idx] # only use an MI value if the psf is valid, correctly indexed \n", - " deconv_value = deconv_list_use[psf_idx][photon_idx]\n", - " ax.scatter(mi_value, deconv_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", - " # add to lists to track later \n", - " mi_means_across_photons.append(mi_value)\n", - " deconv_means_across_photons.append(deconv_value)\n", - " mi_means_across_photons = np.array(mi_means_across_photons)\n", - " deconv_means_across_photons = np.array(deconv_means_across_photons)\n", - " ax.plot(mi_means_across_photons, deconv_means_across_photons, '--', color='grey', alpha=1, linewidth=2)\n", - "\n", - "ax.set_xlabel('Mutual Information (bits per pixel)')\n", - "ax.set_ylabel('Deconvolution Performance, {}'.format(metric_name))\n", - "clear_spines(ax)\n", - "\n", - "\n", - "# legend\n", - "# ax.scatter([], [], color='k', marker='x', label='No PSF')\n", - "ax.scatter([], [], color='k', marker='o', label='One Lens')\n", - "ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", - "ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", - "\n", - "ax.legend(loc='upper right', frameon=True)\n", - "ax.set_xlim([0, None])\n", - "\n", - "\n", - "\n", - "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", - "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", - "sm.set_array([])\n", - "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", - "# set tick labels\n", - "cbar.ax.set_yticklabels(valid_photon_counts)\n", - "\n", - "\n", - "cbar.set_label('Mean Photon Count')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Repeat same thing for just one lens, four lens and diffuser, include error bars - final figure" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "estimator_type = 1 # 0 for gaussian, 1 for pixelcnn\n", - "metric_type = 1\n", - "valid_psfs = [0, 1, 2]\n", - "valid_photon_counts = [20, 40, 80, 160, 320]\n", - "print([psf_names[i] for i in valid_psfs])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "confidence_level = 0.9\n", - "# using min-valued MI estimates \n", - "mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]\n", - "#deconv_estimate_lists = [mse_across_psfs, psnr_across_psfs]\n", - "deconv_estimate_lists = [mse_lists_across_psfs, psnr_lists_across_psfs] # use full list versions\n", - "metric_names = ['MSE', 'PSNR']\n", - "# classifier array is classifier_all_trials_across_psfs, 4x9x10 array. 4 psfs, 9 photon counts, 10 trials on each one \n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n", - "\n", - "mi_list_use = mi_estimate_lists[estimator_type] # use pixelcnn or gaussian, choose pixelcnn \n", - "deconv_list_use = deconv_estimate_lists[metric_type] # use mse or psnr, choose psnr\n", - "metric_name = metric_names[metric_type]\n", - "\n", - "for psf_idx, psf_name in enumerate(psf_names):\n", - " if psf_idx in valid_psfs:\n", - " mi_means_across_photons = [] # track mean MI values to make trendline \n", - " deconv_means_across_photons = [] # track mean deconvolution values to make trendline\n", - " deconv_lower_across_photons = [] # track lower bounds\n", - " deconv_upper_across_photons = [] # track upper bounds\n", - " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", - " if photon_count in valid_photon_counts:\n", - " # load mean values and colors to plot \n", - " color = color_for_photon_level(photon_count)\n", - " mi_value = mi_list_use[psf_idx][photon_idx] # only use an MI value if the psf is valid, correctly indexed \n", - " deconv_value = np.mean(deconv_list_use[psf_idx][photon_idx])\n", - " ax.scatter(mi_value, deconv_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", - " # add to lists to track later \n", - " mi_means_across_photons.append(mi_value)\n", - " deconv_means_across_photons.append(deconv_value)\n", - " # calculate error bars \n", - " deconv_lower_across_photons.append(np.percentile(deconv_list_use[psf_idx][photon_idx], 100 - 100 * (1 + confidence_level) / 2))\n", - " deconv_upper_across_photons.append(np.percentile(deconv_list_use[psf_idx][photon_idx], 100 * (1 + confidence_level) / 2))\n", - " mi_means_across_photons = np.array(mi_means_across_photons)\n", - " deconv_means_across_photons = np.array(deconv_means_across_photons)\n", - " ax.plot(mi_means_across_photons, deconv_means_across_photons, '--', color='grey', alpha=1, linewidth=2)\n", - " ax.fill_between(mi_means_across_photons, deconv_lower_across_photons, deconv_upper_across_photons, color='grey', alpha=0.3, linewidth=0, zorder=-100)\n", - "\n", - "ax.set_xlabel('Mutual Information (bits per pixel)')\n", - "ax.set_ylabel('Deconvolution Performance, {}'.format(metric_name))\n", - "clear_spines(ax)\n", - "\n", - "\n", - "# legend\n", - "# ax.scatter([], [], color='k', marker='x', label='No PSF')\n", - "ax.scatter([], [], color='k', marker='o', label='One Lens')\n", - "#ax.scatter([], [], color='k', marker='d', label='Two Lens')\n", - "#ax.scatter([], [], color='k', marker='v', label='Three Lens')\n", - "ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", - "#ax.scatter([], [], color='k', marker='p', label='Five Lens')\n", - "ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", - "#ax.scatter([], [], color='k', marker='P', label='Aperture')\n", - "\n", - "ax.legend(loc='lower right', frameon=True)\n", - "ax.set_xlim([0, None])\n", - "\n", - "\n", - "\n", - "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", - "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", - "sm.set_array([])\n", - "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", - "# set tick labels\n", - "cbar.ax.set_yticklabels(valid_photon_counts)\n", - "\n", - "\n", - "cbar.set_label('Mean Photon Count')\n", - "\n", - "\n", - "#plt.savefig('{}_vs_MI_with_confidence_intervals_log_photons.pdf'.format(metric_name), bbox_inches='tight', transparent=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "estimator_type = 1 # 0 for gaussian, 1 for pixelcnn\n", - "metric_type = 0 # 0 for MSE, 1 for PSNR\n", - "valid_psfs = [0, 1, 2]\n", - "valid_photon_counts = [20, 40, 80, 160, 320]\n", - "print([psf_names[i] for i in valid_psfs])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "confidence_level = 0.9\n", - "# using min-valued MI estimates \n", - "mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]\n", - "#deconv_estimate_lists = [mse_across_psfs, psnr_across_psfs]\n", - "deconv_estimate_lists = [mse_lists_across_psfs, psnr_lists_across_psfs] # use full list versions\n", - "metric_names = ['MSE', 'PSNR']\n", - "# classifier array is classifier_all_trials_across_psfs, 4x9x10 array. 4 psfs, 9 photon counts, 10 trials on each one \n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n", - "\n", - "mi_list_use = mi_estimate_lists[estimator_type] # use pixelcnn or gaussian, choose pixelcnn \n", - "deconv_list_use = deconv_estimate_lists[metric_type] # use mse or psnr, choose psnr\n", - "metric_name = metric_names[metric_type]\n", - "\n", - "for psf_idx, psf_name in enumerate(psf_names):\n", - " if psf_idx in valid_psfs:\n", - " mi_means_across_photons = [] # track mean MI values to make trendline \n", - " deconv_means_across_photons = [] # track mean deconvolution values to make trendline\n", - " deconv_lower_across_photons = [] # track lower bounds\n", - " deconv_upper_across_photons = [] # track upper bounds\n", - " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", - " if photon_count in valid_photon_counts:\n", - " # load mean values and colors to plot \n", - " color = color_for_photon_level(photon_count)\n", - " mi_value = mi_list_use[psf_idx][photon_idx] # only use an MI value if the psf is valid, correctly indexed \n", - " deconv_value = np.mean(deconv_list_use[psf_idx][photon_idx])\n", - " ax.scatter(mi_value, deconv_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", - " # add to lists to track later \n", - " mi_means_across_photons.append(mi_value)\n", - " deconv_means_across_photons.append(deconv_value)\n", - " # calculate error bars \n", - " deconv_lower_across_photons.append(np.percentile(deconv_list_use[psf_idx][photon_idx], 100 - 100 * (1 + confidence_level) / 2))\n", - " deconv_upper_across_photons.append(np.percentile(deconv_list_use[psf_idx][photon_idx], 100 * (1 + confidence_level) / 2))\n", - " mi_means_across_photons = np.array(mi_means_across_photons)\n", - " deconv_means_across_photons = np.array(deconv_means_across_photons)\n", - " ax.plot(mi_means_across_photons, deconv_means_across_photons, '--', color='grey', alpha=1, linewidth=2)\n", - " ax.fill_between(mi_means_across_photons, deconv_lower_across_photons, deconv_upper_across_photons, color='grey', alpha=0.3, linewidth=0, zorder=-100)\n", - "\n", - "ax.set_xlabel('Mutual Information (bits per pixel)')\n", - "ax.set_ylabel('Deconvolution Performance, {}'.format(metric_name))\n", - "clear_spines(ax)\n", - "\n", - "\n", - "# legend\n", - "# ax.scatter([], [], color='k', marker='x', label='No PSF')\n", - "ax.scatter([], [], color='k', marker='o', label='One Lens')\n", - "#ax.scatter([], [], color='k', marker='d', label='Two Lens')\n", - "#ax.scatter([], [], color='k', marker='v', label='Three Lens')\n", - "ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", - "#ax.scatter([], [], color='k', marker='p', label='Five Lens')\n", - "ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", - "#ax.scatter([], [], color='k', marker='P', label='Aperture')\n", - "\n", - "ax.legend(loc='upper right', frameon=True)\n", - "ax.set_xlim([0, None])\n", - "\n", - "\n", - "\n", - "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", - "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", - "sm.set_array([])\n", - "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", - "# set tick labels\n", - "cbar.ax.set_yticklabels(valid_photon_counts)\n", - "\n", - "\n", - "cbar.set_label('Mean Photon Count')\n", - "\n", - "\n", - "#plt.savefig('{}_vs_MI_with_confidence_intervals_log_photons.pdf'.format(metric_name), bbox_inches='tight', transparent=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "info_jax", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/lensless_imager/04_04_2024_make_lenses_3D_edges_IDEAL.ipynb b/lensless_imager/04_04_2024_make_lenses_3D_edges_IDEAL.ipynb index 37f9cf8..64e76e3 100644 --- a/lensless_imager/04_04_2024_make_lenses_3D_edges_IDEAL.ipynb +++ b/lensless_imager/04_04_2024_make_lenses_3D_edges_IDEAL.ipynb @@ -18,7 +18,7 @@ "import plotly\n", "import sys\n", "sys.path.append('/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments')\n", - "from leyla_fns import *\n", + "from lensless_helpers import *\n", "from cleanplots import *\n" ] }, diff --git a/lensless_imager/11_14_2023_run_classifier_cifar10.py b/lensless_imager/11_14_2023_run_classifier_cifar10.py index 2f1d528..7893d27 100644 --- a/lensless_imager/11_14_2023_run_classifier_cifar10.py +++ b/lensless_imager/11_14_2023_run_classifier_cifar10.py @@ -33,7 +33,7 @@ from jax.scipy.special import logsumexp import numpy as onp -from leyla_fns import * +from lensless_helpers import * from encoding_information.image_utils import add_noise # %% diff --git a/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py b/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py new file mode 100644 index 0000000..15fc594 --- /dev/null +++ b/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py @@ -0,0 +1,176 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: infotransformer +# language: python +# name: python3 +# --- + +# %% [markdown] +# ## Sweeping both unsupervised Wiener Deconvolution and non-unsupervised Wiener Deconvolution with hand-tuned paramete +# +# Using a fixed seed (10) for consistency. + +# %% +# %load_ext autoreload +# %autoreload 2 + +import os +from jax import config +config.update("jax_enable_x64", True) +import sys +sys.path.append('/home/lakabuli/workspace/EncodingInformation/src') + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = '1' +from encoding_information.gpu_utils import limit_gpu_memory_growth +limit_gpu_memory_growth() + + +from cleanplots import * +import numpy as np +import tensorflow as tf +import tensorflow.keras as tfk + +from lensless_helpers import * +from tqdm import tqdm + +# %% +from encoding_information.image_utils import add_noise +from skimage.restoration import wiener, unsupervised_wiener, richardson_lucy +import skimage.metrics as skm + +# %% +# load the PSFs + +diffuser_psf = load_diffuser_32() +one_psf = load_single_lens_uniform(32) +two_psf = load_two_lens_uniform(32) +three_psf = load_three_lens_uniform(32) +four_psf = load_four_lens_uniform(32) +five_psf = load_five_lens_uniform(32) +aperture_psf = np.copy(diffuser_psf) +aperture_psf[:5] = 0 +aperture_psf[-5:] = 0 +aperture_psf[:,:5] = 0 +aperture_psf[:,-5:] = 0 + + +# %% +def compute_skm_metrics(gt, recon): + # takes in already normalized gt + mse = skm.mean_squared_error(gt, recon) + psnr = skm.peak_signal_noise_ratio(gt, recon) + nmse = skm.normalized_root_mse(gt, recon) + ssim = skm.structural_similarity(gt, recon, data_range=1) + return mse, psnr, nmse, ssim + + +# %% +# set seed values for reproducibility +seed_values_full = np.arange(1, 4) + +# set photon properties +#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300] +mean_photon_count_list = [20, 40, 80, 160, 320] + +# set eligible psfs + +psf_patterns_use = [one_psf, four_psf, diffuser_psf] +psf_names_use = ['one', 'four', 'diffuser'] + +save_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/deconvolutions/' + + +# MI estimator parameters +patch_size = 32 +num_patches = 10000 +test_set_size = 1500 +bs = 500 +max_epochs = 50 + +seed_value = 10 + +reg_value_best = 10**-2 + +# %% +# data generation process + +for photon_count in mean_photon_count_list: + for psf_idx, psf_pattern in enumerate(psf_patterns_use): + # load dataset + (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() + data = np.concatenate((x_train, x_test), axis=0) + data = data.astype(np.float64) + labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. + # convert data to grayscale before converting to photons + if len(data.shape) == 4: + data = tf.image.rgb_to_grayscale(data).numpy() + data = data.squeeze() + # convert to photons with mean value of photon_count + data /= np.mean(data) + data *= photon_count + # get maximum value in this data + max_val = np.max(data) + # make tiled data + random_data, random_labels = generate_random_tiled_data(data, labels, seed_value) + # only keep the middle part of the data + data_padded = np.zeros((data.shape[0], 96, 96)) + data_padded[:, 32:64, 32:64] = random_data[:, 32:64, 32:64] + # save the middle part of the data as the gt for metric computation, include only the test set portion. + gt_data = data_padded[:, 32:64, 32:64] + gt_data = gt_data[-test_set_size:] + # extract the test set before doing convolution + test_data = data_padded[-test_set_size:] + # convolve the data + convolved_data = convolved_dataset(psf_pattern, test_data) + convolved_data_noisy = add_noise(convolved_data, seed=seed_value) + # output of add_noise is a jax array that's float32, convert to regular numpy array and float64. + convolved_data_noisy = np.array(convolved_data_noisy).astype(np.float64) + + # compute metrics using unsupervised wiener deconvolution + mse_psf = [] + psnr_psf = [] + ssim_psf = [] + for i in tqdm(range(convolved_data_noisy.shape[0])): + recon, _ = unsupervised_wiener(convolved_data_noisy[i] / max_val, psf_pattern) + recon = recon[17:49, 17:49] #this is the crop window to look at + mse = skm.mean_squared_error(gt_data[i] / max_val, recon) + psnr = skm.peak_signal_noise_ratio(gt_data[i] / max_val, recon) + ssim = skm.structural_similarity(gt_data[i] / max_val, recon, data_range=1) + mse_psf.append(mse) + psnr_psf.append(psnr) + ssim_psf.append(ssim) + + print('PSF: {}, Mean MSE: {}, Mean PSNR: {}, Mean SSIM: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf), np.mean(ssim_psf))) + np.save(save_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf, ssim_psf]) + + # repeat with regular deconvolution + mse_psf = [] + psnr_psf = [] + ssim_psf = [] + for i in tqdm(range(convolved_data_noisy.shape[0])): + recon = wiener(convolved_data_noisy[i] / max_val, psf_pattern, reg_value_best) + recon = recon[17:49, 17:49] #this is the crop window to look at + mse = skm.mean_squared_error(gt_data[i] / max_val, recon) + psnr = skm.peak_signal_noise_ratio(gt_data[i] / max_val, recon) + ssim = skm.structural_similarity(gt_data[i] / max_val, recon, data_range=1) + mse_psf.append(mse) + psnr_psf.append(psnr) + ssim_psf.append(ssim) + print('PSF: {}, Mean MSE: {}, Mean PSNR: {}, Mean SSIM: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf), np.mean(ssim_psf))) + np.save(save_dir + 'regular_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf, ssim_psf]) + + + + + +# %% + + diff --git a/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb b/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb new file mode 100644 index 0000000..97c9575 --- /dev/null +++ b/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb @@ -0,0 +1,591 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Make the plot for MI and deconvolution relationship for paper figure, 2024/10/23" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload \n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import config\n", + "config.update(\"jax_enable_x64\", True)\n", + "import numpy as np\n", + "\n", + "import sys \n", + "sys.path.append('/home/lakabuli/workspace/EncodingInformation/src')\n", + "from lensless_helpers import *\n", + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", + "print(os.environ.get('PYTHONPATH'))\n", + "from cleanplots import * " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seed_value = 10\n", + "\n", + "# set photon properties \n", + "bias = 10 # in photons\n", + "mean_photon_count_list = [20, 40, 80, 160, 320]\n", + "max_photon_count = mean_photon_count_list[-1]\n", + "\n", + "# set eligible psfs\n", + "\n", + "psf_names = ['one', 'four', 'diffuser']\n", + "\n", + "# MI estimator parameters \n", + "patch_size = 32\n", + "num_patches = 10000\n", + "val_set_size = 1000\n", + "test_set_size = 1500\n", + "\n", + "mi_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/mi_estimates_smaller_lr/'\n", + "recon_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/deconvolutions/'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load MI data and make plots of it\n", + "\n", + "The plot has essentially invisible error bars. No outlier issues" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from cleanplots import *\n", + "get_color_cycle()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", + "mis_across_psfs = []\n", + "lowers_across_psfs = []\n", + "uppers_across_psfs = []\n", + "for psf_name in psf_names:\n", + " mis = []\n", + " lowers = []\n", + " uppers = []\n", + " for photon_count in mean_photon_count_list:\n", + " mi_estimates = np.load(mi_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", + " mi_values = mi_estimates[0]\n", + " print(np.max(mi_values) - np.min(mi_values))\n", + " lower_bounds = mi_estimates[1]\n", + " upper_bounds = mi_estimates[2]\n", + " # get index that has smallest mi value across the different model runs.\n", + " min_mi_index = np.argmin(mi_values)\n", + " mis.append(mi_values[min_mi_index])\n", + " lowers.append(lower_bounds[min_mi_index])\n", + " uppers.append(upper_bounds[min_mi_index])\n", + " ax.plot(mean_photon_count_list, mis, label=psf_name) \n", + " ax.fill_between(mean_photon_count_list, lowers, uppers, alpha=0.3)\n", + " mis_across_psfs.append(mis)\n", + " lowers_across_psfs.append(lowers)\n", + " uppers_across_psfs.append(uppers)\n", + "plt.legend()\n", + "plt.title(\"PixelCNN MI estimates across Photon Count, CIFAR10\")\n", + "plt.xlabel(\"Mean Photon Count\")\n", + "plt.ylabel(\"Estimated Mutual Information\")\n", + "mis_across_psfs = np.array(mis_across_psfs)\n", + "lowers_across_psfs = np.array(lowers_across_psfs)\n", + "uppers_across_psfs = np.array(uppers_across_psfs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load recon data and make plots of it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mses_across_psfs = []\n", + "mse_lowers_across_psfs = []\n", + "mse_uppers_across_psfs = []\n", + "psnrs_across_psfs = []\n", + "psnr_lowers_across_psfs = []\n", + "psnr_uppers_across_psfs = []\n", + "ssims_across_psfs = []\n", + "ssim_lowers_across_psfs = []\n", + "ssim_uppers_across_psfs = []\n", + "\n", + "for psf_name in psf_names: \n", + " mse_vals = []\n", + " mse_lowers = []\n", + " mse_uppers = []\n", + " psnr_vals = []\n", + " psnr_lowers = []\n", + " psnr_uppers = []\n", + " ssim_vals = []\n", + " ssim_lowers = []\n", + " ssim_uppers = []\n", + " for photon_count in mean_photon_count_list:\n", + " metrics = np.load(recon_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", + " mse = metrics[0]\n", + " psnr = metrics[1] \n", + " ssim = metrics[2]\n", + " bootstrap_mse, bootstrap_psnr, bootstrap_ssim = compute_bootstraps(mse, psnr, ssim, test_set_size)\n", + " mean_mse, lower_bound_mse, upper_bound_mse = compute_confidence_interval(bootstrap_mse, confidence_interval=0.95)\n", + " mean_psnr, lower_bound_psnr, upper_bound_psnr = compute_confidence_interval(bootstrap_psnr, confidence_interval=0.95)\n", + " mean_ssim, lower_bound_ssim, upper_bound_ssim = compute_confidence_interval(bootstrap_ssim, confidence_interval=0.95)\n", + " mse_vals.append(mean_mse)\n", + " mse_lowers.append(lower_bound_mse)\n", + " mse_uppers.append(upper_bound_mse)\n", + " psnr_vals.append(mean_psnr)\n", + " psnr_lowers.append(lower_bound_psnr)\n", + " psnr_uppers.append(upper_bound_psnr)\n", + " ssim_vals.append(mean_ssim)\n", + " ssim_lowers.append(lower_bound_ssim)\n", + " ssim_uppers.append(upper_bound_ssim)\n", + " mses_across_psfs.append(mse_vals)\n", + " mse_lowers_across_psfs.append(mse_lowers)\n", + " mse_uppers_across_psfs.append(mse_uppers)\n", + " psnrs_across_psfs.append(psnr_vals)\n", + " psnr_lowers_across_psfs.append(psnr_lowers)\n", + " psnr_uppers_across_psfs.append(psnr_uppers)\n", + " ssims_across_psfs.append(ssim_vals)\n", + " ssim_lowers_across_psfs.append(ssim_lowers)\n", + " ssim_uppers_across_psfs.append(ssim_uppers)\n", + "mses_across_psfs = np.array(mses_across_psfs)\n", + "mse_lowers_across_psfs = np.array(mse_lowers_across_psfs)\n", + "mse_uppers_across_psfs = np.array(mse_uppers_across_psfs)\n", + "psnrs_across_psfs = np.array(psnrs_across_psfs)\n", + "psnr_lowers_across_psfs = np.array(psnr_lowers_across_psfs)\n", + "psnr_uppers_across_psfs = np.array(psnr_uppers_across_psfs)\n", + "ssims_across_psfs = np.array(ssims_across_psfs)\n", + "ssim_lowers_across_psfs = np.array(ssim_lowers_across_psfs)\n", + "ssim_uppers_across_psfs = np.array(ssim_uppers_across_psfs)\n", + "plt.figure(figsize=(20, 5))\n", + "plt.subplot(1, 3, 1)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, mses_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, mse_lowers_across_psfs[i], mse_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"MSE\")\n", + "plt.legend()\n", + "plt.subplot(1, 3, 2)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, psnrs_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, psnr_lowers_across_psfs[i], psnr_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"PSNR\")\n", + "plt.subplot(1, 3, 3)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, ssims_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, ssim_lowers_across_psfs[i], ssim_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"SSIM\")\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make figures, omitting error bars since smaller than marker size and reverting to circular markers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def marker_for_psf(psf_name):\n", + " if psf_name =='one':\n", + " marker = 'o'\n", + " elif psf_name == 'four':\n", + " marker = 'o'\n", + " #marker = 's' \n", + " elif psf_name == 'diffuser':\n", + " #marker = '*'\n", + " marker = 'o'\n", + " elif psf_name == 'uc':\n", + " marker = 'x'\n", + " elif psf_name =='two':\n", + " marker = 'd'\n", + " elif psf_name == 'three':\n", + " marker = 'v'\n", + " elif psf_name == 'five':\n", + " marker = 'p'\n", + " elif psf_name == 'aperture':\n", + " marker = 'P'\n", + " return marker" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Choose a base colormap\n", + "base_colormap = plt.get_cmap('inferno')\n", + "# Define the start and end points--used so that high values aren't too light against white background\n", + "start, end = 0, 0.88 # making end point 0.8\n", + "from matplotlib.colors import LinearSegmentedColormap\n", + "# Create a new colormap from the portion of the original colormap\n", + "colormap = LinearSegmentedColormap.from_list(\n", + " 'trunc({n},{a:.2f},{b:.2f})'.format(n=base_colormap.name, a=start, b=end),\n", + " base_colormap(np.linspace(start, end, 256))\n", + ")\n", + "\n", + "min_photons_per_pixel = min(mean_photon_count_list)\n", + "max_photons_per_pixel = max(mean_photon_count_list)\n", + "\n", + "min_log_photons = np.log(min_photons_per_pixel)\n", + "max_log_photons = np.log(max_photons_per_pixel)\n", + "\n", + "def color_for_photon_level(photons_per_pixel):\n", + " log_photons = np.log(photons_per_pixel)\n", + " return colormap((log_photons - min_log_photons) / (max_log_photons - min_log_photons) )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# old format for selecting target indices, now not used much\n", + "metric_type = 1 # 0 for MSE, 1 for PSNR \n", + "valid_psfs = [0, 1, 2]\n", + "valid_photon_counts = [20, 40, 80, 160, 320]\n", + "psf_names = [psf_names[i] for i in valid_psfs]\n", + "print(psf_names)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mse_error_lower = np.abs(mses_across_psfs - mse_lowers_across_psfs)\n", + "mse_error_upper = np.abs(mse_uppers_across_psfs - mses_across_psfs)\n", + "psnr_error_lower = np.abs(psnrs_across_psfs - psnr_lowers_across_psfs)\n", + "psnr_error_upper = np.abs(psnr_uppers_across_psfs - psnrs_across_psfs)\n", + "mi_error_lower = np.abs(mis_across_psfs - lowers_across_psfs)\n", + "mi_error_upper = np.abs(uppers_across_psfs - mis_across_psfs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = mses_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], mses_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[mse_error_lower[psf_idx], mse_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Mean Squared Error\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='upper right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + "cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('mse_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = psnrs_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], psnrs_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[psnr_error_lower[psf_idx], psnr_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Peak Signal-to-Noise Ratio (dB)\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='lower right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + "cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('psnr_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = ssims_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], ssims_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[ssim_error_lower[psf_idx], ssim_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Structural Similarity Index Measure (SSIM)\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='lower right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + "cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('ssim_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Put all 3 into one figure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from cleanplots import *\n", + "from matplotlib.ticker import ScalarFormatter\n", + "\n", + "figs, axs = plt.subplots(1, 3, figsize=(12, 4), sharex=True)\n", + "\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = mses_across_psfs[psf_idx][photon_idx] \n", + " axs[0].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], mses_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[mse_error_lower[psf_idx], mse_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[0].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "#axs[0].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[0].set_title(\"Mean Squared Error\")\n", + "clear_spines(axs[0])\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = ssims_across_psfs[psf_idx][photon_idx] \n", + " axs[1].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], ssims_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[ssim_error_lower[psf_idx], ssim_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[1].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "axs[1].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[1].set_title(\"Structural Similarity Index Measure (SSIM)\")\n", + "clear_spines(axs[1])\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = psnrs_across_psfs[psf_idx][photon_idx] \n", + " axs[2].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], psnrs_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[psnr_error_lower[psf_idx], psnr_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[2].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "#axs[2].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[2].set_title(\"Peak Signal-to-Noise Ratio (dB)\")\n", + "clear_spines(axs[2])\n", + "\n", + "# norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "# sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "# sm.set_array([])\n", + "# cbar = plt.colorbar(sm, ax=axs[2], ticks=(np.log(valid_photon_counts)))\n", + "# # set tick labels\n", + "# cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "# cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig(\"metrics_vs_MI_with_confidence_intervals_log_photons.pdf\", bbox_inches='tight', transparent=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "infotransformer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py b/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py new file mode 100644 index 0000000..bd9148a --- /dev/null +++ b/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py @@ -0,0 +1,154 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: infotransformer +# language: python +# name: python3 +# --- + +# %% +# %load_ext autoreload +# %autoreload 2 + +# Final MI estimation script for lensless imager, used in paper. + +import os +from jax import config +config.update("jax_enable_x64", True) +import sys +sys.path.append('/home/lakabuli/workspace/EncodingInformation/src') + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = '0' +from encoding_information.gpu_utils import limit_gpu_memory_growth +limit_gpu_memory_growth() + +from cleanplots import * +import jax.numpy as np +import numpy as onp +import tensorflow as tf +import tensorflow.keras as tfk + + +from lensless_helpers import * + +# %% +from encoding_information import extract_patches +from encoding_information.models import PixelCNN +from encoding_information.plot_utils import plot_samples +from encoding_information.models import PoissonNoiseModel +from encoding_information.image_utils import add_noise +from encoding_information import estimate_information + +# %% [markdown] +# ### Sweep Photon Count and Diffusers + +# %% +diffuser_psf = load_diffuser_32() +one_psf = load_single_lens_uniform(32) +two_psf = load_two_lens_uniform(32) +three_psf = load_three_lens_uniform(32) +four_psf = load_four_lens_uniform(32) +five_psf = load_five_lens_uniform(32) + +# %% +# set seed values for reproducibility +seed_values_full = np.arange(1, 5) + +# set photon properties +bias = 10 # in photons +mean_photon_count_list = [20, 40, 80, 160, 320] + +# set eligible psfs + +psf_patterns = [diffuser_psf, four_psf, one_psf] +psf_names = ['diffuser', 'four', 'one'] + +# MI estimator parameters +patch_size = 32 +num_patches = 10000 +val_set_size = 1000 +test_set_size = 1500 +num_samples = 8 +learning_rate = 1e-3 # using 5x iterations per epoch, using smaller lr, and using less patience since it should be a smoother curve. +num_iters_per_epoch = 500 +patience_val = 20 + + +save_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/mi_estimates_smaller_lr/' + + +# %% +for photon_count in mean_photon_count_list: + for index, psf_pattern in enumerate(psf_patterns): + val_loss_log = [] + mi_estimates = [] + lower_bounds = [] + upper_bounds = [] + for seed_value in seed_values_full: + # load dataset + (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() + data = onp.concatenate((x_train, x_test), axis=0) + labels = np.concatenate((y_train, y_test), axis=0) + data = data.astype(np.float32) + # convert data to grayscale before converting to photons + if len(data.shape) == 4: + data = tf.image.rgb_to_grayscale(data).numpy() + data = data.squeeze() + # convert to photons with mean value of photon_count + data /= onp.mean(data) + data *= photon_count + # make tiled data + random_data, random_labels = generate_random_tiled_data(data, labels, seed_value) + + if psf_pattern is None: + start_idx = data.shape[-1] // 2 + end_idx = data.shape[-1] // 2 - 1 + psf_data = random_data[:, start_idx:-end_idx, start_idx:-end_idx] + else: + psf_data = convolved_dataset(psf_pattern, random_data) + # add small bias to data + psf_data += bias + # make patches for training and testing splits, random patching + patches = extract_patches(psf_data[:-test_set_size], patch_size=patch_size, num_patches=num_patches, seed=seed_value, verbose=True) + test_patches = extract_patches(psf_data[-test_set_size:], patch_size=patch_size, num_patches=test_set_size, seed=seed_value, verbose=True) + # put all the clean patches together for use in MI estimatino function later + full_clean_patches = onp.concatenate([patches, test_patches]) + # add noise to both sets + patches_noisy = add_noise(patches, seed=seed_value) + test_patches_noisy = add_noise(test_patches, seed=seed_value) + + # initialize pixelcnn + pixel_cnn = PixelCNN() + # fit pixelcnn to noisy patches. defaults to 10% val samples which will be 1k as desired. + # using smaller lr this time and adding seeding, letting it go for full training time. + val_loss_history = pixel_cnn.fit(patches_noisy, seed=seed_value, learning_rate=learning_rate, do_lr_decay=False, steps_per_epoch=num_iters_per_epoch, patience=patience_val) + # generate samples, not necessary for MI sweeps + # pixel_cnn_samples = pixel_cnn.generate_samples(num_samples=num_samples) + # # visualize samples + # plot_samples([pixel_cnn_samples], test_patches, model_names=['PixelCNN']) + + # instantiate noise model + noise_model = PoissonNoiseModel() + # estimate information using the fit pixelcnn and noise model, with clean data + pixel_cnn_info, pixel_cnn_lower_bound, pixel_cnn_upper_bound = estimate_information(pixel_cnn, noise_model, patches_noisy, + test_patches_noisy, clean_data=full_clean_patches, + confidence_interval=0.95) + print("PixelCNN estimated information: ", pixel_cnn_info) + print("PixelCNN lower bound: ", pixel_cnn_lower_bound) + print("PixelCNN upper bound: ", pixel_cnn_upper_bound) + # append results to lists + val_loss_log.append(val_loss_history) + mi_estimates.append(pixel_cnn_info) + lower_bounds.append(pixel_cnn_lower_bound) + upper_bounds.append(pixel_cnn_upper_bound) + np.save(save_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array([mi_estimates, lower_bounds, upper_bounds])) + np.save(save_dir + 'pixelcnn_val_loss_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(val_loss_log, dtype=object)) + np.save(save_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array([mi_estimates, lower_bounds, upper_bounds])) + np.save(save_dir + 'pixelcnn_val_loss_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(val_loss_log, dtype=object)) diff --git a/lensless_imager/leyla_fns.py b/lensless_imager/lensless_helpers.py similarity index 94% rename from lensless_imager/leyla_fns.py rename to lensless_imager/lensless_helpers.py index 131b4fe..0693b28 100644 --- a/lensless_imager/leyla_fns.py +++ b/lensless_imager/lensless_helpers.py @@ -578,4 +578,35 @@ def momentum_testing_model(train_data, train_labels, test_data, test_labels, val history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data test_loss, test_acc = model.evaluate(test_data, test_labels) - return history, model, test_loss, test_acc \ No newline at end of file + return history, model, test_loss, test_acc + + +# bootstrapping function +def compute_bootstraps(mses, psnrs, ssims, test_set_length, num_bootstraps=100): + bootstrap_mses = [] + bootstrap_psnrs = [] + bootstrap_ssims = [] + for bootstrap_idx in tqdm(range(num_bootstraps), desc='Bootstrapping to compute confidence interval'): + # select indices for sampling + bootstrap_indices = np.random.choice(test_set_length, test_set_length, replace=True) + # take the metric values at those indices + bootstrap_selected_mses = mses[bootstrap_indices] + bootstrap_selected_psnrs = psnrs[bootstrap_indices] + bootstrap_selected_ssims = ssims[bootstrap_indices] + # accumulate the mean of the selected metric values + bootstrap_mses.append(np.mean(bootstrap_selected_mses)) + bootstrap_psnrs.append(np.mean(bootstrap_selected_psnrs)) + bootstrap_ssims.append(np.mean(bootstrap_selected_ssims)) + bootstrap_mses = np.array(bootstrap_mses) + bootstrap_psnrs = np.array(bootstrap_psnrs) + bootstrap_ssims = np.array(bootstrap_ssims) + return bootstrap_mses, bootstrap_psnrs, bootstrap_ssims + +def compute_confidence_interval(list_of_items, confidence_interval=0.95): + # use this one, final version + assert confidence_interval > 0 and confidence_interval < 1 + mean_value = np.mean(list_of_items) + lower_bound = np.percentile(list_of_items, 50 * (1 - confidence_interval)) + upper_bound = np.percentile(list_of_items, 50 * (1 + confidence_interval)) + return mean_value, lower_bound, upper_bound + diff --git a/src/encoding_information/models/__init__.py b/src/encoding_information/models/__init__.py index b9d5bdb..ef91f66 100644 --- a/src/encoding_information/models/__init__.py +++ b/src/encoding_information/models/__init__.py @@ -1,4 +1,5 @@ from .pixel_cnn import PixelCNN +from .multichannel_pixel_cnn import MultiChannelPixelCNN from .gaussian_process import FullGaussianProcess, StationaryGaussianProcess from .conditional_entropy_models import AnalyticGaussianNoiseModel, PoissonNoiseModel, AnalyticComplexPixelGaussianNoiseModel \ No newline at end of file diff --git a/src/encoding_information/models/multichannel_pixel_cnn.py b/src/encoding_information/models/multichannel_pixel_cnn.py new file mode 100644 index 0000000..26826da --- /dev/null +++ b/src/encoding_information/models/multichannel_pixel_cnn.py @@ -0,0 +1,762 @@ +""" +MultichannelPixelCNN in Jax/Flax. Adapted from single channel PixelCNN implementation in Flax.: +https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial12/Autoregressive_Image_Modeling.html + +Univariate mixture density output adapted from: +https://github.com/hardmaru/mdn_jax_tutorial/blob/master/mixture_density_networks_jax.ipynb +""" + +## Standard libraries +import os +import numpy as onp +from typing import Any +from tqdm import tqdm +import warnings + +## JAX +import jax +import jax.numpy as np +from jax import random +from jax.scipy.special import logsumexp + +from flax import linen as nn +from flax.training.train_state import TrainState +import optax + + +from .model_base_class import MeasurementModel, MeasurementType, \ + train_model, _evaluate_nll, make_dataset_generators + + + +class PreprocessLayer(nn.Module): + """ + A layer that normalizes the input images using the provided mean and standard deviation. + + Attributes + ---------- + mean : np.ndarray + The mean to subtract from the input images. + std : np.ndarray + The standard deviation to divide the input images by. + """ + mean: np.ndarray + std: np.ndarray + + def __call__(self, x): + return (x - self.mean) / (self.std + 1e-5) + +class MaskedConvolution(nn.Module): + """ + A convolutional layer with a mask to ensure autoregressive behavior. + + This layer ensures that during the convolution, the current pixel does not + have access to any future pixels (either to the right or below in the image). + + Attributes + ---------- + c_out : int + The number of output channels. + mask : np.ndarray + The mask to apply to the convolution, determining which pixels are visible. + dilation : int, optional + The dilation factor for the convolution (default is 1). + """ + c_out : int + mask : np.ndarray + dilation : int = 1 + + @nn.compact + def __call__(self, x): + # Flax's convolution module already supports masking + # The mask must be the same size as kernel + # => extend over input and output feature channels + if len(self.mask.shape) == 2: + mask_ext = self.mask[...,None,None] + mask_ext = np.tile(mask_ext, (1, 1, x.shape[-1], self.c_out)) + else: + mask_ext = self.mask + # Convolution with masking + x = nn.Conv(features=self.c_out, + kernel_size=self.mask.shape[:2], + kernel_dilation=self.dilation, + mask=mask_ext)(x) + return x + + +class VerticalStackConvolution(nn.Module): + """ + A vertical convolutional layer that processes the pixels above the current pixel in an image. + + This layer creates a vertical stack by masking the convolution kernel, ensuring that the pixels + below the current pixel are not visible during the convolution. + + Attributes + ---------- + c_out : int + The number of output channels. + kernel_size : int + The size of the convolution kernel. + mask_center : bool, optional + Whether to mask out the center pixel in the kernel (default is False). + dilation : int, optional + The dilation factor for the convolution (default is 1). + """ + c_out : int + kernel_size : int + mask_center : bool = False + dilation : int = 1 + + def setup(self): + # Mask out all pixels below. For efficiency, we could also reduce the kernel + # size in height, but for simplicity, we stick with masking here. + mask = onp.ones((self.kernel_size, self.kernel_size), dtype=onp.float32) + mask[self.kernel_size//2+1:,:] = 0 + # For the very first convolution, we will also mask the center row + if self.mask_center: + mask[self.kernel_size//2,:] = 0 + # Our convolution module + self.conv = MaskedConvolution(c_out=self.c_out, + mask=mask, + dilation=self.dilation) + + def __call__(self, x): + return self.conv(x) + + +class HorizontalStackConvolution(nn.Module): + """ + A horizontal convolutional layer that processes the pixels to the left of the current pixel in an image. + + This layer creates a horizontal stack by masking the convolution kernel, ensuring that the pixels + to the right of the current pixel are not visible during the convolution. + + Attributes + ---------- + c_out : int + The number of output channels. + kernel_size : int + The size of the convolution kernel. + mask_center : bool, optional + Whether to mask out the center pixel in the kernel (default is False). + dilation : int, optional + The dilation factor for the convolution (default is 1). + """ + c_out : int + kernel_size : int + mask_center : bool = False + dilation : int = 1 + + def setup(self): + # Mask out all pixels on the left. Note that our kernel has a size of 1 + # in height because we only look at the pixel in the same row. + mask = onp.ones((1, self.kernel_size), dtype=onp.float32) + mask[0,self.kernel_size//2+1:] = 0 + # For the very first convolution, we will also mask the center pixel + if self.mask_center: + mask[0,self.kernel_size//2] = 0 + # Our convolution module + self.conv = MaskedConvolution(c_out=self.c_out, + mask=mask, + dilation=self.dilation) + + def __call__(self, x): + return self.conv(x) + + +class GatedMaskedConv(nn.Module): + """ + A gated masked convolution layer used in PixelCNN. This layer uses gated activation functions + to improve gradient flow during training. + + It combines information from a vertical stack and a horizontal stack, each being passed through + masked convolutions, and optionally conditioned on an external vector (such as class labels). + + Attributes + ---------- + dilation : int, optional + The dilation factor for the convolutions (default is 1). + id : int, optional + The layer ID, used for parameter naming. + condition_vector_size : int, optional + The size of the condition vector for conditional PixelCNN. + """ + dilation : int = 1 + id: int = None + condition_vector_size : int = None + + @nn.compact + def __call__(self, v_stack, h_stack, condition_vector=None): + c_in = v_stack.shape[-1] + + # Layers (depend on input shape) + conv_vert = VerticalStackConvolution(c_out=2*c_in, + kernel_size=3, + mask_center=False, + dilation=self.dilation) + conv_horiz = HorizontalStackConvolution(c_out=2*c_in, + kernel_size=3, + mask_center=False, + dilation=self.dilation) + conv_vert_to_horiz = nn.Conv(2*c_in, + kernel_size=(1, 1)) + conv_horiz_1x1 = nn.Conv(c_in, + kernel_size=(1, 1)) + + + + # Vertical stack (left) + v_stack_feat = conv_vert(v_stack) + v_val, v_gate = np.split(v_stack_feat, 2, axis=-1) + + if condition_vector is not None: + weights = self.param(f'conditioning_weights_vert_{self.id}', jax.nn.initializers.lecun_normal(), (1, self.condition_vector_size,)) + y = np.dot(weights, condition_vector.T).reshape(-1,1,1,1) + weights_gate = self.param(f'conditioning_weights_vert_gate{self.id}', jax.nn.initializers.lecun_normal(), (1, self.condition_vector_size,)) + y_gate = np.dot(weights_gate, condition_vector.T).reshape(-1,1,1,1) + v_stack_out = nn.tanh(v_val + y) * nn.sigmoid(v_gate + y_gate) + else: + v_stack_out = nn.tanh(v_val) * nn.sigmoid(v_gate) + + # Horizontal stack (right) + h_stack_feat = conv_horiz(h_stack) + h_stack_feat = h_stack_feat + conv_vert_to_horiz(v_stack_feat) + h_val, h_gate = np.split(h_stack_feat, 2, axis=-1) + if condition_vector is not None: + weights = self.param(f'conditioning_weights_horz{self.id}', jax.nn.initializers.lecun_normal(), (1, self.condition_vector_size,)) + y = np.dot(weights, condition_vector.T).reshape(-1,1,1,1) + weights_gate = self.param(f'conditioning_weights_horz_gate{self.id}', jax.nn.initializers.lecun_normal(), (1, self.condition_vector_size,)) + y_gate = np.dot(weights_gate, condition_vector.T).reshape(-1,1,1,1) + h_stack_feat = nn.tanh(h_val + y) * nn.sigmoid(h_gate + y_gate) + else: + h_stack_feat = nn.tanh(h_val) * nn.sigmoid(h_gate) + h_stack_out = conv_horiz_1x1(h_stack_feat) + h_stack_out = h_stack_out + h_stack + + return v_stack_out, h_stack_out + + +class _MultiChannelPixelCNNFlaxImpl(nn.Module): + """ + The core implementation of the PixelCNN model in Flax. + + This module defines the structure of the PixelCNN, including the vertical and horizontal + masked convolutions, gated activation functions, and a mixture density output layer. + + Attributes + ---------- + data_shape : tuple + The shape of the input data (height, width, channels). + num_hidden_channels : int, optional + The number of hidden channels in the model (default is 64). + num_mixture_components : int, optional + The number of components in the mixture density output (default is 40). + train_data_mean : float + The mean of the training data used for normalization. Multichannel considers a float for each channel. + train_data_std : float + The standard deviation of the training data used for normalization. Multichannel considers a float for each channel. + train_data_min : float + The minimum value of the training data. Multichannel considers a float for each channel. + train_data_max : float + The maximum value of the training data. Multichannel considers a float for each channel. + sigma_min : float, optional + The minimum standard deviation for the mixture density output (default is 1). + condition_vector_size : int, optional + The size of the condition vector for conditional PixelCNN. + use_positional_embedding : bool, optional + Whether to use learned positional embeddings for each pixel (default is False). + """ + data_shape : tuple + num_hidden_channels : int = 64 + num_mixture_components : int = 40 + train_data_mean : float = None + train_data_std : float = None + train_data_min : float = None + train_data_max : float = None + sigma_min : float = 1 + condition_vector_size : int = None + use_positional_embedding : bool = False + + def setup(self): + if None in [self.train_data_mean, self.train_data_std, self.train_data_min, self.train_data_max]: + raise Exception('Must pass in training data statistics constructor') + + if self.train_data_max.dtype != np.float32 or self.train_data_min.dtype != np.float32 or \ + self.train_data_mean.dtype != np.float32 or self.train_data_std.dtype != np.float32: + raise Exception('Must pass in training data statistics as float32') + + self.normalize = PreprocessLayer(mean=self.train_data_mean, std=self.train_data_std) + + if not isinstance(self.num_hidden_channels, int): + raise ValueError("num_hidden_channels must be an integer") + # Initial convolutions skipping the center pixel + self.conv_vstack = VerticalStackConvolution(self.num_hidden_channels, kernel_size=3, mask_center=True) + self.conv_hstack = HorizontalStackConvolution(self.num_hidden_channels, kernel_size=3, mask_center=True) + # Convolution block of PixelCNN. We use dilation instead of downscaling + self.conv_layers = [ + GatedMaskedConv(dilation=1, id=0, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=2, id=1, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=1, id=2, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=4, id=3, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=1, id=4, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=2, id=5, condition_vector_size=self.condition_vector_size), + GatedMaskedConv(dilation=1, id=6, condition_vector_size=self.condition_vector_size), + ] + # Output classification convolution (1x1) + self.conv_out = nn.Conv(self.num_hidden_channels, kernel_size=(1, 1)) + + # parameters for mixture density + def my_bias_init(rng, shape, dtype): + return random.uniform(rng, shape, dtype=dtype, + minval=np.min(self.train_data_min), maxval=np.max(self.train_data_max)) # just initializing a learnable parameter so using absolute values across channels + + # Parameters for learned positional embedding + if self.use_positional_embedding: + self.positional_embedding = nn.Embed(num_embeddings=self.data_shape[0] * self.data_shape[1], features=self.num_hidden_channels) + # generate unique index for each pixel + self.position_indices = np.arange(self.data_shape[0] * self.data_shape[1]).reshape(*self.data_shape[:2]) + + self.mu_dense = nn.Dense(self.num_mixture_components * self.data_shape[2], bias_init=my_bias_init) # scale by number of channels + self.sigma_dense = nn.Dense(self.num_mixture_components * self.data_shape[2] * self.data_shape[2]) # scale by squared number of channels since matrix + self.mix_logit_dense = nn.Dense(self.num_mixture_components) # mixture components are scalars for each pixel + + def __call__(self, x, condition_vectors=None): + """ + Do forward pass output the parameters of the gaussian mixture output + """ + # add trailing channel dimension if necessary + if x.ndim == 3: + x = x[..., np.newaxis] + + return self.forward_pass(x, condition_vectors=condition_vectors) + + def compute_gaussian_nll(self, mu, sigma, mix_logit, x): + # numerically efficient implementation of mixture density, slightly modified + # see https://github.com/hardmaru/mdn_jax_tutorial/blob/master/mixture_density_networks_jax.ipynb + # compute per-pixel negative log-likelihood + + # one-by-one step version for debugging. + # lognormal = self.lognormal(x, mu, sigma) + # jax.debug.print("Number of nans in lognormal {test}", test=np.sum(np.isnan(lognormal))) + # logit_normalized = mix_logit - logsumexp(mix_logit, axis=-1, keepdims=True) + # jax.debug.print("Number of nans in logit_normalized {test}", test=np.sum(np.isnan(logit_normalized))) + # nll = - logsumexp(logit_normalized + lognormal, axis=-1) + + #all in one step + nll = - logsumexp(mix_logit - logsumexp(mix_logit, axis=-1, keepdims=True) + self.lognormal(x, mu, sigma), axis=-1) + return nll + + def compute_loss(self, mu, sigma, mix_logit, x): + """ + Compute average negative log likelihood per pixel averaged over batch and pixels + """ + return self.compute_gaussian_nll(mu, sigma, mix_logit, x).mean() + + + def lognormal(self, y, mean, sigma): + # expand the data in the n_components dimension and tile + y = np.expand_dims(y, axis=-2) + y = np.tile(y, (1, 1, 1, self.num_mixture_components, 1)) + logRootDTwoPI = np.log(2.0 * np.pi) * self.data_shape[2] / 2.0 # d / 2 log 2pi + covarianceDeterminant = np.linalg.det(sigma) + matrix_sum = np.einsum('...i, ...ij, ...j->...', y - mean, np.linalg.inv(sigma), y - mean) + # -d/2 log(2pi) - 1/2 log det covariance - 0.5 (x - mu)T covaraince^-1 (x - mu) + return -1.0 * logRootDTwoPI - 0.5 * np.log(covarianceDeterminant) - 0.5 * matrix_sum + #return -0.5 * ((y - mean) / sigma) ** 2 - np.log(sigma) - logSqrtTwoPI # previous version for 1D + + def forward_pass(self, x, condition_vectors=None): + """ + Forward pass of the MultiChannelPixelCNN model. + + The image is passed through the vertical and horizontal masked convolutions, followed by + gated convolutions, and finally a mixture density output layer. The model outputs the parameters + of the mixture density for each pixel (mean, standard deviation, and mixture logits). + + Parameters + ---------- + x : ndarray + The input image, with shape (batch_size, height, width, channels). + condition_vectors : ndarray, optional + A vector to condition the image generation process (e.g., class labels). + + Returns + ------- + mu : ndarray + The mean of the Gaussian components for each pixel. + sigma : ndarray + The standard deviation of the Gaussian components for each pixel. + mix_logit : ndarray + The logits for the mixture components. + """ + # check shape + if x.ndim != 4: + raise ValueError("Input image must have shape BxHxWxC") + + # rescale to 0-1ish + x = self.normalize(x) + # Initial convolutions + v_stack = self.conv_vstack(x) + h_stack = self.conv_hstack(x) + # Gated Convolutions + for layer in self.conv_layers: + v_stack, h_stack = layer(v_stack, h_stack, condition_vector=condition_vectors) + # 1x1 classification convolution + # Apply ELU before 1x1 convolution for non-linearity on residual connection + out = self.conv_out(nn.elu(h_stack)) + + if self.use_positional_embedding: + # add positional embedding + indices = self.position_indices + # apply positional embedding + out = out + self.positional_embedding(indices) + # must be positive and within data range + #mu = np.clip(self.mu_dense(out), self.train_data_min, self.train_data_max) # 1D version + # mu items need to be reshaped and clipped + mu_out = self.mu_dense(out) + mu_out = np.reshape(mu_out, (out.shape[0], out.shape[1], out.shape[2], self.num_mixture_components, self.data_shape[2])) # reshape from b x h x w x components*num_channels to b x h x w x components x num_channels + mu = np.clip(mu_out, self.train_data_min, self.train_data_max) + + #sigma = nn.activation.softplus(self.sigma_dense(out)) # 1D version + # avoid having tiny components that overly concentrate mass, and don't need components larger than data standard deviation + #sigma = np.clip(sigma, self.sigma_min, self.train_data_std) # previous version + + # sigma items need to be reshaped to be a covariance matrix, and clipped to be a valid cholesky decomposition + sigma_out = self.sigma_dense(out) + # reshape to covariance matrix dimensions + sigma_out = np.reshape(sigma_out, (out.shape[0], out.shape[1], out.shape[2], self.num_mixture_components, self.data_shape[2], self.data_shape[2])) # reshape from b x h x w x components*num_channels**2 to b x h x w x components x num_channels x num_channels + # make a lower triangular matrix L for L L^T + sigma_out = np.tril(sigma_out) + # manually loop through the channel components to clip the diagonals TODO could be more intelligently done maybe? + for channel_idx in range(self.data_shape[2]): + # apply softplus to this diagonal + sigma_out = sigma_out.at[..., channel_idx, channel_idx].set(nn.softplus(sigma_out[..., channel_idx, channel_idx])) + # then clip the components TODO need to change sigma_min to not be 1 in the future when it's not image data + sigma_out = sigma_out.at[..., channel_idx, channel_idx].set(np.clip(sigma_out[..., channel_idx, channel_idx], self.sigma_min, self.train_data_std[channel_idx])) # TODO think about if there needs to be an absolute train_data_std + # now turn this into a covariance matrix + # transpose, swap the last two dimensions + sigma_out_transpose = np.einsum('...ij->...ji', sigma_out) + # multiply cov = L L^T + sigma = np.einsum('...ij, ...jk->...ik', sigma_out, sigma_out_transpose) + # add a small amount to the diagonal to make sure it's positive definite, 1e-6 + sigma = sigma + 1e-6 * np.eye(self.data_shape[2]) + + mix_logit = self.mix_logit_dense(out) # stays as b x h x w x n_components. there isn't a channel dimension for this one + + return mu, sigma, mix_logit + + + +class MultiChannelPixelCNN(MeasurementModel): + """ + The PixelCNN model for autoregressive image modeling. + + This class handles the training and evaluation of the PixelCNN model and wraps the Flax implementation + in a higher-level interface that conforms to the MeasurementModel class. It provides methods for fitting + the model to data, computing the negative log-likelihood of images, and generating new images. + + Attributes + ---------- + num_hidden_channels : int + The number of hidden channels in the model. + num_mixture_components : int + The number of components in the mixture density output. + """ + + def __init__(self, num_hidden_channels=64, num_mixture_components=40): + """ + Initialize the PixelCNN model with image shape, number of hidden channels, and mixture components. + + Parameters + ---------- + num_hidden_channels : int + Number of hidden channels in the convolutional layers. + num_mixture_components : int + Number of mixture components for the output layer. + """ + + super().__init__([MeasurementType.HW, MeasurementType.HWC], measurement_dtype=float) + self.num_hidden_channels = num_hidden_channels + self.num_mixture_components = num_mixture_components + self._flax_model = None + + def fit(self, train_images, condition_vectors=None, learning_rate=1e-2, max_epochs=200, steps_per_epoch=100, patience=40, + sigma_min=1, batch_size=64, num_val_samples=None, percent_samples_for_validation=0.1, do_lr_decay=False, verbose=True, + add_gaussian_noise=False, add_uniform_noise=True, model_seed=None, data_seed=None, use_positional_embedding=False, + # deprecated + seed=None,): + """ + Train the PixelCNN model on a dataset of images. + + Parameters + ---------- + train_images : ndarray + The input dataset, with shape (N, H, W, C). + condition_vectors : ndarray, optional + Vectors to condition the image generation process (e.g., class labels). + learning_rate : float, optional + The learning rate for optimization (default is 1e-2). + max_epochs : int, optional + The maximum number of training epochs (default is 200). + steps_per_epoch : int, optional + The number of steps per epoch (default is 100). + patience : int, optional + The number of epochs to wait before early stopping (default is 40). + sigma_min : float, optional + The minimum standard deviation for the mixture density output (default is 1). + batch_size : int, optional + The batch size for training (default is 64). + num_val_samples : int, optional + The number of validation samples. If None, a percentage is used (default is None). + percent_samples_for_validation : float, optional + The percentage of samples to use for validation (default is 0.1). + do_lr_decay : bool, optional + Whether to apply learning rate decay during training (default is False). + verbose : bool, optional + Whether to print progress during training (default is True). + add_gaussian_noise : bool, optional + Whether to add Gaussian noise to the training images (default is False). + add_uniform_noise : bool, optional + Whether to add uniform noise to the training images (default is True). + model_seed : int, optional + Seed for model initialization. + data_seed : int, optional + Seed for data shuffling. + + Returns + ------- + val_loss_history : list + A list of validation loss values for each epoch. + """ + if seed is not None: + warnings.warn("seed argument is deprecated. Use model_seed and data_seed instead") + model_seed = seed + data_seed = seed + + if model_seed is not None: + onp.random.seed(model_seed) + model_key = jax.random.PRNGKey(onp.random.randint(0, 100000)) + + if condition_vectors is not None: + warnings.warn("For multi-channel PixelCNN condition vectors have not been implemented or double checked.") + + self._validate_data(train_images) + + train_images = train_images.astype(np.float32) + + # check that only one type of noise is added + if add_gaussian_noise and add_uniform_noise: + raise ValueError("Only one type of noise can be added to the training data") + + num_val_samples = int(train_images.shape[0] * percent_samples_for_validation) if num_val_samples is None else num_val_samples + + # add trailing channel dimension if necessary + if train_images.ndim == 3: + train_images = train_images[..., np.newaxis] + + self.image_shape = train_images.shape[1:4] # 3D to include the image channels + + # Use the make dataset generators function because training data may be modified here during training + # (i.e. adding small amounts of noise to account for discrete data and continuous model) + _, dataset_fn = make_dataset_generators(train_images, batch_size=400, num_val_samples=train_images.shape[0], + add_gaussian_noise=add_gaussian_noise, add_uniform_noise=add_uniform_noise, + seed=data_seed) + example_images = dataset_fn().next() # TODO can make this batch size bigger if needed just to get the settings for the values in the following model initialization, currently at 400 + + if self._flax_model is None: + self.add_gaussian_noise = add_gaussian_noise + self.add_uniform_noise = add_uniform_noise + self._flax_model = _MultiChannelPixelCNNFlaxImpl(num_hidden_channels=self.num_hidden_channels, num_mixture_components=self.num_mixture_components, + train_data_mean=np.mean(example_images, axis=(0, 1, 2)), train_data_std=np.std(example_images, axis=(0, 1, 2)), + train_data_min=np.min(example_images, axis=(0, 1, 2)), train_data_max=np.max(example_images, axis=(0, 1, 2)), sigma_min=sigma_min, + condition_vector_size=None if condition_vectors is None else condition_vectors.shape[-1], + data_shape=train_images.shape[1:], use_positional_embedding=use_positional_embedding) + + # pass in an intial batch + initial_params = self._flax_model.init(model_key, train_images[:3], + condition_vectors[:3] if condition_vectors is not None else None) + + if do_lr_decay: + lr_schedule = optax.exponential_decay(init_value=learning_rate, + transition_steps=steps_per_epoch, + decay_rate=0.99,) + + self._optimizer = optax.adam(lr_schedule) + else: + self._optimizer = optax.adam(learning_rate) + + def apply_fn(params, x, condition_vector=None): + output = self._flax_model.apply(params, x, condition_vector) + return self._flax_model.compute_loss(*output, x) + + self._state = TrainState.create(apply_fn=apply_fn, params=initial_params, tx=self._optimizer) + + if condition_vectors is None: + + def loss_fn(params, state, imgs): + return state.apply_fn(params, imgs) + grad_fn = jax.value_and_grad(loss_fn) + + @jax.jit + def train_step(state, imgs): + """ + A standard gradient descent training step + """ + loss, grads = grad_fn(state.params, state, imgs) + state = state.apply_gradients(grads=grads) + return state, loss + else: + + def loss_fn(params, state, imgs, condition_vecs): + return state.apply_fn(params, imgs, condition_vecs) + grad_fn = jax.value_and_grad(loss_fn) + + @jax.jit + def train_step(state, imgs, condition_vecs): + """ + A standard gradient descent training step + """ + loss, grads = grad_fn(state.params, state, imgs, condition_vecs) + state = state.apply_gradients(grads=grads) + return state, loss + + + best_params, val_loss_history = train_model(train_images=train_images, condition_vectors=condition_vectors, train_step=train_step, + state=self._state, batch_size=batch_size, num_val_samples=int(num_val_samples), + add_gaussian_noise=add_gaussian_noise, add_uniform_noise=add_uniform_noise, + steps_per_epoch=steps_per_epoch, num_epochs=max_epochs, patience=patience, seed=data_seed, + verbose=verbose) + self._state = self._state.replace(params=best_params) + self.val_loss_history = val_loss_history + return val_loss_history + + + + def compute_negative_log_likelihood(self, data, conditioning_vecs=None, data_seed=None, average=True, verbose=True, seed=None): + """ + Compute the negative log-likelihood (NLL) of images under the trained PixelCNN model. + + Parameters + ---------- + data : ndarray + The input images for which to compute the NLL. + conditioning_vecs : ndarray, optional + Vectors to condition the image generation process (e.g., class labels). + data_seed : int, optional + Seed for data shuffling. + average : bool, optional + If True, return the average NLL over all images (default is True). + verbose : bool, optional + Whether to print progress (default is True). + seed : int, optional + Deprecated. Use data_seed instead. + + Returns + ------- + nll : float + The negative log-likelihood of the input images. + """ + # See superclass for docstring + if seed is not None: + warnings.warn("seed argument is deprecated. Use data_seed instead") + data_seed = seed + + if data.ndim == 3: + # add a trailing channel dimension if necessary + data = data[..., np.newaxis] + elif data.ndim == 2: + # add trailing channel and batch dimensions + data = data[np.newaxis, ..., np.newaxis] + + # check if data shape is different than image shape + if data.shape[1:4] != self.image_shape: + raise ValueError("Data shape is different than image shape of trained model. This is not yet supported" + "Expected {}, got {}".format(self.image_shape, data.shape[1:4])) + + # get test data generator. Here all data is "validation", because the data passed into this should already be + # (in the typical case) a test set + _, dataset_fn = make_dataset_generators(data, batch_size=32 if average else 1, num_val_samples=data.shape[0], + add_gaussian_noise=self.add_gaussian_noise, add_uniform_noise=self.add_uniform_noise, + condition_vectors=conditioning_vecs, seed=data_seed) + @jax.jit + def conditional_eval_step(state, imgs, condition_vecs): + return state.apply_fn(state.params, imgs, condition_vecs) + + return _evaluate_nll(dataset_fn(), self._state, return_average=average, + eval_step=conditional_eval_step if conditioning_vecs is not None else None, verbose=verbose) + + + def generate_samples(self, num_samples, conditioning_vecs=None, sample_shape=None, ensure_nonnegative=True, seed=None, verbose=True): + """ + Generate new images from the trained PixelCNN model by sampling pixel by pixel. + + Parameters + ---------- + num_samples : int + Number of images to generate. + conditioning_vecs : jax.Array, optional + Optional conditioning vectors. If provided, the shape should match + (num_samples, condition_vector_size). Default is None. + sample_shape : tuple of int or int, optional + Shape of the images to generate. If None, the model's image_shape is used. + If a single int is provided, it will be treated as a square shape. Default is None. + ensure_nonnegative : bool, optional + If True, ensure that the generated pixel values are non-negative. Default is True. + seed : int, optional + Random seed for reproducibility. Default is 123 if not provided. + verbose : bool, optional + If True, display progress during the generation process. Default is True. + + Returns + ------- + jax.Array + Generated images with the specified shape. + """ + if seed is None: + seed = 123 + key = jax.random.PRNGKey(seed) + if sample_shape is None: + sample_shape = self.image_shape + if type(sample_shape) == int: + sample_shape = (sample_shape, sample_shape) + + if conditioning_vecs is not None: + assert conditioning_vecs.shape[0] == num_samples + assert conditioning_vecs.shape[1] == self._flax_model.condition_vector_size + + sampled_images = onp.zeros((num_samples, *sample_shape)) + for i in tqdm(onp.arange(sample_shape[0]), desc='Generating PixelCNN samples') if verbose else np.arange(sample_shape[0]): + for j in onp.arange(sample_shape[1]): + i_limits = max(0, i - self.image_shape[0] + 1), max(self.image_shape[0], i+1) + j_limits = max(0, j - self.image_shape[1] + 1), max(self.image_shape[1], j+1) + + conditioning_images = sampled_images[:, i_limits[0]:i_limits[1], j_limits[0]:j_limits[1]] + i_in_cropped_image = i - i_limits[0] + j_in_cropped_image = j - j_limits[0] + + assert conditioning_images.shape[1:] == self.image_shape + + key, key2 = jax.random.split(key) + if conditioning_vecs is None: + mu, sigma, mix_logit = self._flax_model.apply(self._state.params, conditioning_images) + else: + mu, sigma, mix_logit = self._flax_model.apply(self._state.params, conditioning_images, conditioning_vecs) + # only sampling one pixel at a time + # make onp arrays for range checking + mu = onp.array(mu)[:, i_in_cropped_image, j_in_cropped_image, :] + sigma = onp.array(sigma)[:, i_in_cropped_image, j_in_cropped_image, :] + mix_logit = onp.array(mix_logit)[:, i_in_cropped_image, j_in_cropped_image, :] + + # mix_probs = np.exp(mix_logit - logsumexp(mix_logit, axis=-1, keepdims=True)) # this was commented out in 1D pixelcnn as well + component_indices = jax.random.categorical(key, mix_logit, axis=-1) + # draw categorical sample + sample_mus = mu[np.arange(num_samples), component_indices] + sample_sigmas = sigma[np.arange(num_samples), component_indices] + #sample = jax.random.normal(key2, shape=sample_mus.shape) * sample_sigmas + sample_mus # 1D pixelcnn version + # switching to a multivariate normal distribution for the sigmas + sample = jax.random.multivariate_normal(key2, sample_mus, sample_sigmas) + sampled_images[:, i, j] = sample + + if ensure_nonnegative: + sampled_images = np.where(sampled_images < 0, 0, sampled_images) + return sampled_images +