diff --git a/examples/hypoAL.ipynb b/examples/hypoAL.ipynb
deleted file mode 100644
index cbd9920..0000000
--- a/examples/hypoAL.ipynb
+++ /dev/null
@@ -1,1521 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "hypoAL_paper_Algo1.ipynb",
- "provenance": [],
- "collapsed_sections": [],
- "mount_file_id": "1-ixeuqsSo4uHHfnuQ7zq54N3i_k4H3wf",
- "authorship_tag": "ABX9TyNqOwrUIy0s6MQz1n+a+fhU",
- "include_colab_link": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Hypothesis learning: toy data example"
- ],
- "metadata": {
- "id": "KswNo4REitip"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "This notebook demonstrates how to apply the hypothesis learning to toy data. The [hypothesis learning](https://arxiv.org/abs/2112.06649) is based on the idea that in active learning, the correct model of the system’s behavior leads to a faster decrease in the overall Bayesian uncertainty about the system under study. In the hypothesis learning setup, probabilistic models of the possible system’s behaviors (hypotheses) are wrapped into structured GPs, and a basic reinforcement learning policy is used to select a correct model from several competing hypotheses.\n",
- "\n",
- "*Prepared by Maxim Ziatdinov (March 2022)*"
- ],
- "metadata": {
- "id": "4M3tFS1hiebL"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Installations:"
- ],
- "metadata": {
- "id": "H7NQ_pNfi2pe"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "9cNHFhBbKjzz"
- },
- "source": [
- "!pip install -q git+https://github.com/ziatdinovmax/gpax"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "3RmuZCkiH01r",
- "cellView": "form"
- },
- "source": [
- "#@title Imports\n",
- "from typing import Union, Dict, Type\n",
- "\n",
- "import gpax\n",
- "\n",
- "import jax.numpy as jnp\n",
- "import jax.random as jra\n",
- "import numpy as onp\n",
- "import numpyro\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "gpax.utils.enable_x64()"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "ukjdW3IYiHeu",
- "cellView": "form"
- },
- "source": [
- "#@title Plotting and data utilities { form-width: \"20%\" }\n",
- "\n",
- "def get_training_data(X, Y, num_seed_points=2, rng_seed=42, **kwargs):\n",
- " onp.random.seed(rng_seed)\n",
- " indices = jnp.arange(len(X))\n",
- " idx = kwargs.get(\"list_of_indices\")\n",
- " if idx is not None:\n",
- " idx = onp.array(idx)\n",
- " else:\n",
- " idx = onp.random.randint(0, len(X), num_seed_points)\n",
- " idx = onp.unique(idx)\n",
- " X_train, y_train = X[idx], Y[idx]\n",
- " indices_train = indices[idx]\n",
- " X_test = jnp.delete(X, idx)\n",
- " y_test = jnp.delete(Y, idx)\n",
- " indices_test = jnp.delete(indices, idx)\n",
- " return X_train, y_train, X_test, y_test, indices_train, indices_test\n",
- " \n",
- "\n",
- "def plot_results(X_measured, y_measured, X_unmeasured, y_pred, y_sampled, obj, model_idx, rewards, **kwargs):\n",
- " X = jnp.concatenate([X_measured, X_unmeasured], axis=0).sort()\n",
- " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
- " ax1.scatter(X_measured, y_measured, marker='x', s=100, c='k', label=\"Measured points\", zorder=1)\n",
- " ax1.plot(X, y_pred, c='red', label='Model reconstruction', zorder=0)\n",
- " ax1.fill_between(X, y_pred - y_sampled.std(0), y_pred + y_sampled.std(0),\n",
- " color='r', alpha=0.2, label=\"Model uncertainty\", zorder=0)\n",
- " ax1.set_xlabel(\"$x$\", fontsize=18)\n",
- " ax1.set_ylabel(\"$y$\", fontsize=18)\n",
- " ax2.plot(X_unmeasured, obj, c='k')\n",
- " ax2.vlines(X_unmeasured[obj.argmax()], obj.min(), obj.max(), linestyles='dashed', label= \"Next point\")\n",
- " ax2.set_xlabel(\"$x$\", fontsize=18)\n",
- " ax2.set_ylabel(\"Acquisition function\", fontsize=18)\n",
- " ax1.legend(loc=\"upper left\")\n",
- " ax2.legend(loc=\"upper left\")\n",
- " step = kwargs.get(\"e\", 0)\n",
- " plt.suptitle(\"Step: {}, Sampled Model: {}, Rewards: {}\".format(\n",
- " step+1, model_idx, onp.around(rewards, 3).tolist()), fontsize=24)\n",
- " fig.savefig(\"./{}.png\".format(step))\n",
- " plt.show() \n",
- " \n",
- "\n",
- "def plot_acq(x, obj, idx):\n",
- " plt.plot(x.squeeze(), obj, c='k')\n",
- " plt.vlines(x[idx], obj.min(), obj.max(), linestyles='dashed')\n",
- " plt.xlabel(\"$x$\", fontsize=18)\n",
- " plt.ylabel(\"Acquisition function\", fontsize=18)\n",
- " plt.show()\n",
- " \n",
- "\n",
- "def plot_final_result(X, y, X_test, y_pred, y_sampled, seed_points):\n",
- " plt.figure(dpi=100)\n",
- " plt.scatter(X[seed_points:], y[seed_points:], c=jnp.arange(1, len(X[seed_points:])+1),\n",
- " cmap='viridis', label=\"Sampled points\", zorder=2)\n",
- " cbar = plt.colorbar(label=\"Exploration step\")\n",
- " cbar_ticks = jnp.arange(2, len(X[seed_points:]) + 1, 2)\n",
- " cbar.set_ticks(cbar_ticks)\n",
- " plt.scatter(X[:seed_points], y[:seed_points], marker='x', s=64,\n",
- " c='k', label=\"Seed points\", zorder=1)\n",
- " plt.plot(X_test, y_pred, '--', c='red', label='Model reconstruction', zorder=1)\n",
- " plt.plot(X_test, truefunc, c='k', label=\"Ground truth\", zorder=0)\n",
- " plt.fill_between(X_test, y_pred - y_sampled.std(0), y_pred + y_sampled.std(0),\n",
- " color='r', alpha=0.2, label=\"Model uncertainty\", zorder=0)\n",
- " plt.xlabel(\"$x$\", fontsize=12)\n",
- " plt.ylabel(\"$y$\", fontsize=12)\n",
- " plt.legend(fontsize=9, loc='upper left')\n",
- " #plt.ylim(1.8, 6.6)\n",
- " plt.show()"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "IdXQ-tJXhmUC"
- },
- "source": [
- "First, let's generate some data. As a practical example chosen here, we are interested in the active learning of phase\n",
- "diagram that has a transition between different phases. The phase transition manifests in discontinuity of a measurable system’s property, such as heat capacity. However, we usually do not know where a phase transition occurs precisely, nor are we aware of the exact behavior of the property of interest in different phases. We note that using a standard Gaussian process-based active learning is not an optimal choice in such a case as simple GP struggles around the discontinuity point."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 400
- },
- "id": "Onu_jtyMH3q2",
- "outputId": "70b13847-8d70-427b-b1d8-debc43fb2d9d"
- },
- "source": [
- "def function_(x: jnp.ndarray, params: Dict[str, float]) -> jnp.ndarray:\n",
- " return jnp.piecewise(\n",
- " x, [x < params[\"t\"], x >= params[\"t\"]],\n",
- " [lambda x: x**params[\"beta1\"], lambda x: x**params[\"beta2\"]])\n",
- "\n",
- "\n",
- "X = jnp.linspace(0.0, 2.5, 100)\n",
- "params = {\"t\": 1.6, \"beta1\": 4, \"beta2\": 2.5}\n",
- "\n",
- "truefunc = function_(X, params)\n",
- "Y = truefunc + 0.33 * jra.normal(jra.PRNGKey(0), shape=truefunc.shape)\n",
- "\n",
- "_, ax = plt.subplots(dpi=100) \n",
- "ax.scatter(X, Y, alpha=0.5, c='k', label=\"Noisy observations\")\n",
- "ax.plot(X, truefunc, lw=2, c='k', label=\"True function\")\n",
- "ax.legend()\n",
- "ax.set_xlabel(\"$x$\")\n",
- "ax.set_ylabel(\"$y$\")"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "Text(0, 0.5, '$y$')"
- ]
- },
- "metadata": {},
- "execution_count": 4
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "