From f42e72ad3a1fe4494f9ed962413d3e4de8fec1e3 Mon Sep 17 00:00:00 2001 From: Tianqing Zhang Date: Thu, 1 Aug 2024 19:37:41 -0400 Subject: [PATCH] add back pzflow demo --- .../estimation_examples/pzflow_demo.ipynb | 191 ++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 examples/estimation_examples/pzflow_demo.ipynb diff --git a/examples/estimation_examples/pzflow_demo.ipynb b/examples/estimation_examples/pzflow_demo.ipynb new file mode 100644 index 0000000..ce813f3 --- /dev/null +++ b/examples/estimation_examples/pzflow_demo.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "327d391f-58bc-4b6a-9bbe-3987b969c8f4", + "metadata": {}, + "source": [ + "PZFlow Informer and Estimator Demo\n", + "\n", + "Author: Tianqing Zhang\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "916a05ad", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import rail\n", + "from rail.core.data import TableHandle\n", + "from rail.core.stage import RailStage\n", + "import qp\n", + "import tables_io\n", + "\n", + "from rail.estimation.algos.pzflow_nf import PZFlowInformer, PZFlowEstimator\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8ef87d3", + "metadata": {}, + "outputs": [], + "source": [ + "DS = RailStage.data_store\n", + "DS.__class__.allow_overwrite = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f79c3a7b", + "metadata": {}, + "outputs": [], + "source": [ + "from rail.utils.path_utils import find_rail_file\n", + "trainFile = find_rail_file('examples_data/testdata/test_dc2_training_9816.hdf5')\n", + "testFile = find_rail_file('examples_data/testdata/test_dc2_validation_9816.hdf5')\n", + "training_data = DS.read_file(\"training_data\", TableHandle, trainFile)\n", + "test_data = DS.read_file(\"test_data\", TableHandle, testFile)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "756d78a3", + "metadata": {}, + "outputs": [], + "source": [ + "pzflow_dict = dict(hdf5_groupname='photometry',output_mode = 'not_fiducial' )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0857e6bb-18eb-4f89-bc4b-29bed1ffa122", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1042a9f3", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# epoch = 200 gives a reasonable converged loss\n", + "pzflow_train = PZFlowInformer.make_stage(name='inform_pzflow',model='demo_pzflow.pkl',num_training_epochs = 30, **pzflow_dict)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ae0180b-b67f-4ac1-bf90-545c7c2fbb62", + "metadata": {}, + "outputs": [], + "source": [ + "pzflow_train.finalize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c407f45b", + "metadata": {}, + "outputs": [], + "source": [ + "# training of the pzflow\n", + "pzflow_train.inform(training_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "156b6e3d", + "metadata": {}, + "outputs": [], + "source": [ + "pzflow_dict = dict(hdf5_groupname='photometry')\n", + "\n", + "pzflow_estimator = PZFlowEstimator.make_stage(name='estimate_pzflow',model='demo_pzflow.pkl',**pzflow_dict, chunk_size = 20000)\n", + "\n", + "pzflow_estimator._aliases = dict()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00911d60", + "metadata": {}, + "outputs": [], + "source": [ + "# estimate using the test data\n", + "estimate_results = pzflow_estimator.estimate(test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cbdece3", + "metadata": {}, + "outputs": [], + "source": [ + "mode = estimate_results.data.ancil['zmode']\n", + "truth = np.array(training_data.data['photometry']['redshift'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba076bab-c5ab-4292-8de9-415e7b30af5c", + "metadata": {}, + "outputs": [], + "source": [ + "# visualize the prediction. \n", + "plt.figure(figsize = (8,8))\n", + "plt.scatter(truth, mode, s = 0.5)\n", + "plt.xlabel('True Redshift')\n", + "plt.ylabel('Mode of Estimated Redshift')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed5bf266-3b5c-4d9b-8428-77a2833cafef", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rail_env", + "language": "python", + "name": "rail_env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}