diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9e48956 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.*~ +*.pkl +__pycache__ +*.egg-info \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f5229d3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Computational Neuroscience, University of Bern + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..63a9e19 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +# dendritic-opinion-pooling +Library for models implementing the dendritic opinion pooling framework. + +### Installation + +Run `pip install .` from the root directory. + +### Examples + +See `examples/` for example applications. + +### Tests + +Run `pytest` from the root directory (requires [pytest](https://docs.pytest.org/en/stable/index.html)). diff --git a/dopp/__init__.py b/dopp/__init__.py new file mode 100644 index 0000000..f4d85a4 --- /dev/null +++ b/dopp/__init__.py @@ -0,0 +1,11 @@ +__version__ = "0.0.1" +__maintainer__ = "Jakob Jordan" +__author__ = "Jakob Jordan" +__license__ = "MIT" +__description__ = "Library for models implementing the dendritic opinion pooling framework." +__url__ = "https://github.com/unibe-cns/dendritic-opinion-pooling" +__doc__ = f"{__description__} <{__url__}>" + +from .dynamic_feedforward_cell import DynamicFeedForwardCell +from .feedforward_cell import FeedForwardCell +from .feedforward_current_cell import FeedForwardCurrentCell diff --git a/dopp/abstract_convex_cell.py b/dopp/abstract_convex_cell.py new file mode 100644 index 0000000..c3440e9 --- /dev/null +++ b/dopp/abstract_convex_cell.py @@ -0,0 +1,242 @@ +import math +import torch + + +class AbstractConvexCell(torch.nn.Module): + def __init__(self, in_features_per_dendrite, out_features): + super().__init__() + + self.in_features_per_dendrite = in_features_per_dendrite + self.out_features = out_features + + self.EE = 0.0 # mV + self.EI = -85.0 # mV + self.gL0 = 0.166667 # nS + self.gLd = torch.ones(self.n_dendrites) * self.gL0 # nS + self.EL = -70.0 # mV + self.gc = torch.ones(self.n_dendrites) * 50000.0 * self.gL0 # nS + self.input_scale = torch.ones(self.in_features) + self.lambda_e = 1.0 + + self._omegaE = [ + torch.nn.Parameter( + torch.zeros( + self.in_features_per_dendrite[d], + self.out_features, + dtype=torch.double, + ) + ) + for d in range(self.n_dendrites) + ] + self._omegaI = [ + torch.nn.Parameter( + torch.zeros( + self.in_features_per_dendrite[d], + self.out_features, + dtype=torch.double, + ) + ) + for d in range(self.n_dendrites) + ] + + for d in range(self.n_dendrites): + self.register_parameter(f"omegaE{d}", self._omegaE[d]) + self.register_parameter(f"omegaI{d}", self._omegaI[d]) + + self.init_weights() + + self.alpha = 1.0 + theta = self.EL - 0. + self.f = ( + lambda u: 1.0 + / self.alpha + * torch.nn.functional.softplus(self.alpha * (u - theta)) + ) + self.f_inv = ( + lambda r: 1.0 / self.alpha * torch.log(torch.exp(self.alpha * r) - 1.0) + + theta + ) + + @property + def in_features(self): + return sum(self.in_features_per_dendrite) + + @property + def n_dendrites(self): + return len(self.in_features_per_dendrite) + + @property + def min_rate(self): + return self.f(torch.DoubleTensor([self.EI])).item() + + @property + def max_rate(self): + return self.f(torch.DoubleTensor([self.EE])).item() + + def assert_valid_rate(self, r): + assert torch.all(self.min_rate <= r) + assert torch.all(r <= self.max_rate) + + def assert_valid_conductance(self, g): + assert torch.all(0 <= g) + + def assert_valid_voltage(self, v): + assert torch.all(self.EI <= v) + assert torch.all(v <= self.EE) + + def weights_from_omega(self, omega): + return torch.nn.functional.softplus(omega) + + def omega_from_weights(self, weights): + assert torch.all(weights >= 0.0) + return torch.log(torch.exp(weights) - 1.0) + + def weightsE(self, i): + weightsEi = self.weights_from_omega(self._omegaE[i]) + assert torch.all(weightsEi >= 0.0) + return weightsEi + + def weightsI(self, i): + weightsIi = self.weights_from_omega(self._omegaI[i]) + assert torch.all(weightsIi >= 0.0) + return weightsIi + + def set_weightsE(self, d, val): + assert torch.all(val >= 0.0) + assert val.shape == self._omegaE[d].shape + self._omegaE[d].data = self.omega_from_weights(val) + + def set_weightsI(self, d, val): + assert torch.all(val >= 0.0) + assert val.shape == self._omegaI[d].shape + self._omegaI[d].data = self.omega_from_weights(val) + + def scale_weightsE(self, scale, d=None): + if d is None: + for d in range(self.n_dendrites): + self.set_weightsE(d, scale * self.weightsE(d)) + else: + self.set_weightsE(d, scale * self.weightsE(d)) + + def scale_weightsI(self, scale, d=None): + if d is None: + for d in range(self.n_dendrites): + self.set_weightsI(d, scale * self.weightsI(d)) + else: + self.set_weightsI(d, scale * self.weightsI(d)) + + def set_input_scale(self, d, val): + self.input_scale[self._input_slice(d)] = val + + def init_weights(self): + scale = 0.2 + for d, in_features in enumerate(self.in_features_per_dendrite): + if in_features > 0: + stdv = scale * 1.0 / math.sqrt(in_features) + initial_weightE = torch.DoubleTensor( + self.in_features_per_dendrite[d], self.out_features + ).uniform_(0, stdv) * (self.EL - self.EI) / (self.EE - self.EL) + self._omegaE[d].data = self.omega_from_weights(initial_weightE) + initial_weightI = torch.DoubleTensor( + self.in_features_per_dendrite[d], self.out_features + ).uniform_(0, stdv) + self._omegaI[d].data = self.omega_from_weights(initial_weightI) + + def _input_slice(self, d): + if d == 0: + return slice(0, self.in_features_per_dendrite[0]) + else: + return slice( + sum(self.in_features_per_dendrite[:d]), + sum(self.in_features_per_dendrite[: d + 1]), + ) + + def dendritic_input(self, u_in, d): + r_in = self.input_scale * self.f(u_in) + return r_in[:, self._input_slice(d)] + + def forward(self, u_in): + + self.assert_valid_voltage(u_in) + assert u_in.shape[1] == self.in_features + + g0, u0 = self._forward(u_in) + + if g0 is not None: + self.assert_valid_conductance(g0) + self.assert_valid_voltage(u0) + return g0, u0 + + def copy_omegaE_omegaI_from(self, other): + assert self.n_dendrites == other.n_dendrites + + for d in range(self.n_dendrites): + self._omegaE[d].data = other._omegaE[d].data.clone() + self._omegaI[d].data = other._omegaI[d].data.clone() + + def compute_gff0_and_uff0(self): + gff0 = torch.ones(1, self.out_features, dtype=torch.double) * self.gL0 + uff0 = torch.ones(1, self.out_features, dtype=torch.double) * self.EL + + self.assert_valid_conductance(gff0) + self.assert_valid_voltage(uff0) + return gff0, uff0 + + def compute_gEd_gId(self, u_in, d): + gEd = torch.mm(self.dendritic_input(u_in, d), self.weightsE(d)) + gId = torch.mm(self.dendritic_input(u_in, d), self.weightsI(d)) + self.assert_valid_conductance(gEd) + self.assert_valid_conductance(gId) + + return gEd, gId + + def compute_gffd_and_uffd(self, u_in): + gffd = torch.empty(len(u_in), self.out_features, self.n_dendrites, dtype=torch.double) + uffd = torch.empty(len(u_in), self.out_features, self.n_dendrites, dtype=torch.double) + + assert len(self.gLd) == self.n_dendrites + + for d in range(self.n_dendrites): + gEd, gId = self.compute_gEd_gId(u_in, d) + gffd[:, :, d] = self.gLd[d] + gEd + gId + # need to use cloned gffd to avoid pytorch + # inplace-modification error when calling backward() when + # training with backprop + uffd[:, :, d] = ( + self.gLd[d] * self.EL + gEd * self.EE + gId * self.EI + ) / gffd[:, :, d].clone() + + self.assert_valid_conductance(gffd) + self.assert_valid_voltage(uffd) + return gffd, uffd + + def compute_Iffd(self, u_in): + Iffd = torch.empty(len(u_in), self.out_features, self.n_dendrites, dtype=torch.double) + + for d in range(self.n_dendrites): + gEd, gId = self.compute_gEd_gId(u_in, d) + Iffd[:, :, d] = gEd - gId + + return Iffd + + def sample(self, g0, u0): + raise NotImplementedError() + + def _forward(self, u_in): + raise NotImplementedError() + + def energy_target(self, u0_target, g0, u0): + raise NotImplementedError() + + def loss_target(self, u0_target, g0, u0): + raise NotImplementedError() + + def compute_grad_manual_target(self, u0_target, g0, u0, u_in): + raise NotImplementedError() + + def apply_grad_weights(self, lr): + for d in range(self.n_dendrites): + self._omegaE[d].data -= lr * self._omegaE[d]._grad + assert torch.all(self.weightsE(d) > 0.0) + self._omegaI[d].data -= lr * self._omegaI[d]._grad + assert torch.all(self.weightsI(d) > 0.0) diff --git a/dopp/dynamic_feedforward_cell.py b/dopp/dynamic_feedforward_cell.py new file mode 100644 index 0000000..af75761 --- /dev/null +++ b/dopp/dynamic_feedforward_cell.py @@ -0,0 +1,35 @@ +from scipy.integrate import solve_ivp +import torch + +from .feedforward_cell import FeedForwardCell + + +class DynamicFeedForwardCell(FeedForwardCell): + def __init__(self, in_features_per_dendrite, out_features): + super().__init__(in_features_per_dendrite, out_features) + + self.dt = 0.5 # ms + self.cm0 = 250.0 # pF + + self.u0 = torch.ones(1, self.out_features) * self.EL + + def _forward(self, u_in): + + assert len(u_in) == 1, "batch size larger than one not supported" + + gff0, uff0 = self.compute_gff0_and_uff0() + gffd, uffd = self.compute_gffd_and_uffd(u_in) + g0, u0 = self._compute_g0_and_u0(gff0, uff0, gffd, uffd) + + def rhs(t, u): + return (g0[0].numpy() * (u0[0].numpy() - u)) * 1.0 / self.cm0 + + res_ivp = solve_ivp( + rhs, (0.0, self.dt), self.u0[0].numpy(), method="RK23", max_step=self.dt + ).y[:, -1] + + self.u0[0] = torch.Tensor(res_ivp) + return g0, self.u0 + + def initialize_somatic_potential(self, u): + self.u0[0, :] = u diff --git a/dopp/feedforward_cell.py b/dopp/feedforward_cell.py new file mode 100644 index 0000000..cd45937 --- /dev/null +++ b/dopp/feedforward_cell.py @@ -0,0 +1,99 @@ +import math +import torch + +from .abstract_convex_cell import AbstractConvexCell + + +class FeedForwardCell(AbstractConvexCell): + def __init__(self, in_features_per_dendrite, out_features): + super().__init__(in_features_per_dendrite, out_features) + + def sample(self, g0, u0): + return u0 + torch.sqrt(self.lambda_e / g0) * torch.empty_like(u0, dtype=torch.double).normal_() + + def _compute_g0_and_u0(self, gff0, uff0, gffd, uffd): + g0 = torch.empty_like(gff0, dtype=torch.double) + u0 = torch.empty_like(uff0, dtype=torch.double) + if self.gc is None: + g0 = gff0 + torch.sum(gffd, dim=2) + u0 = (gff0 * uff0 + torch.sum(gffd * uffd, dim=2)) / g0 + else: + g0 = gff0 + torch.sum(self.gc * gffd / (gffd + self.gc), dim=2) + u0 = ( + gff0 * uff0 + torch.sum(self.gc * gffd / (gffd + self.gc) * uffd, dim=2) + ) / g0 + + return g0, u0 + + def _forward(self, u_in): + + gff0, uff0 = self.compute_gff0_and_uff0() + gffd, uffd = self.compute_gffd_and_uffd(u_in) + + g0, u0 = self._compute_g0_and_u0(gff0, uff0, gffd, uffd) + + return g0, u0 + + def compute_grad_manual_target(self, u0_target, g0, u0, u_in): + + self.assert_valid_voltage(u0_target) + self.assert_valid_conductance(g0) + self.assert_valid_voltage(u0) + assert u0_target.shape == u0.shape + + omegaE_grad = [ + torch.empty_like(self._omegaE[d], dtype=torch.double) for d in range(self.n_dendrites) + ] + omegaI_grad = [ + torch.empty_like(self._omegaI[d], dtype=torch.double) for d in range(self.n_dendrites) + ] + for d in range(self.n_dendrites): + if self.gc is None: + omegaE_grad_d = ( + 1.0 / self.lambda_e * (u0_target - u0) * (self.EE - u0) + ) + omegaI_grad_d = ( + 1.0 / self.lambda_e * (u0_target - u0) * (self.EI - u0) + ) + else: + raise NotImplementedError() + + if self.gc is None: + quad_error = -0.5 * ( + 1.0 / self.lambda_e * (u0_target - u0) ** 2 + - 1.0 / g0 + ) + else: + raise NotImplementedError() + + omegaE_grad_d += quad_error + omegaI_grad_d += quad_error + + omegaE_grad[d] = ( + torch.einsum( + "ij,ik->jk", [self.dendritic_input(u_in, d), omegaE_grad_d] + ) + * 1.0 + / (1.0 + torch.exp(-self._omegaE[d])) + ) + omegaI_grad[d] = ( + torch.einsum( + "ij,ik->jk", [self.dendritic_input(u_in, d), omegaI_grad_d] + ) + * 1.0 + / (1.0 + torch.exp(-self._omegaI[d])) + ) + + assert not torch.all(torch.isnan(omegaE_grad[d])) + assert not torch.all(torch.isnan(omegaI_grad[d])) + + self._omegaE[d]._grad = -omegaE_grad[d].data + self._omegaI[d]._grad = -omegaI_grad[d].data + + def energy_target(self, u0_target, g0, u0): + return 1.0 / self.lambda_e * g0 / 2.0 * ( + u0_target - u0 + ) ** 2 + 0.5 * torch.log(2 * math.pi * self.lambda_e / g0) + + def loss_target(self, u0_target, g0, u0): + return 0.5 * (u0_target - u0) ** 2 diff --git a/dopp/feedforward_current_cell.py b/dopp/feedforward_current_cell.py new file mode 100644 index 0000000..0fc3acb --- /dev/null +++ b/dopp/feedforward_current_cell.py @@ -0,0 +1,90 @@ +import math +import torch + +from .abstract_convex_cell import AbstractConvexCell + + +class FeedForwardCurrentCell(AbstractConvexCell): + def __init__(self, in_features_per_dendrite, out_features): + super().__init__(in_features_per_dendrite, out_features) + + self.gc = None + + def init_weights(self): + scale = 2.0 + for d, in_features in enumerate(self.in_features_per_dendrite): + if in_features > 0: + stdv = scale * 1.0 / math.sqrt(in_features) + initial_weightE = torch.DoubleTensor( + self.in_features_per_dendrite[d], self.out_features + ).uniform_(0, stdv) * 2.5 * (self.EL - self.EI) / (self.EE - self.EL) + self._omegaE[d].data = self.omega_from_weights(initial_weightE) + initial_weightI = torch.DoubleTensor( + self.in_features_per_dendrite[d], self.out_features + ).uniform_(0, stdv) + self._omegaI[d].data = self.omega_from_weights(initial_weightI) + + def assert_valid_voltage(self, v): + pass + + def _compute_u0(self, uff0, Iffd): + return uff0 + torch.sum(Iffd, dim=2) + + def _forward(self, u_in): + + assert self.gc is None + + gff0, uff0 = self.compute_gff0_and_uff0() + Iffd = self.compute_Iffd(u_in) + u0 = self._compute_u0(uff0, Iffd) + + return None, u0 + + def compute_grad_manual_target(self, u0_target, _, u0, u_in): + + assert self.gc is None + + self.assert_valid_voltage(u0_target) + self.assert_valid_voltage(u0) + self.assert_valid_voltage(u_in) + + assert u0_target.shape == u0.shape + + omegaE_grad = [ + torch.empty_like(self._omegaE[d]) for d in range(self.n_dendrites) + ] + omegaI_grad = [ + torch.empty_like(self._omegaI[d]) for d in range(self.n_dendrites) + ] + for d in range(self.n_dendrites): + omegaE_grad_d = (u0_target - u0) + omegaI_grad_d = -(u0_target - u0) + + omegaE_grad[d] = ( + torch.einsum( + "ij,ik->jk", [self.dendritic_input(u_in, d), omegaE_grad_d] + ) + * 1.0 + / (1.0 + torch.exp(-self._omegaE[d])) + ) + omegaI_grad[d] = ( + torch.einsum( + "ij,ik->jk", [self.dendritic_input(u_in, d), omegaI_grad_d] + ) + * 1.0 + / (1.0 + torch.exp(-self._omegaI[d])) + ) + + assert not torch.all(torch.isnan(omegaE_grad[d])) + assert not torch.all(torch.isnan(omegaI_grad[d])) + + self._omegaE[d]._grad = -omegaE_grad[d].data + self._omegaI[d]._grad = -omegaI_grad[d].data + + def energy_target(self, u0_target, _, u0): + return 1.0 / self.lambda_e * 1. / 2.0 * ( + u0_target - u0 + ) ** 2 + 0.5 * torch.log(2 * math.pi * self.lambda_e / torch.ones_like(u0)) + + def loss_target(self, u0_target, _, u0): + return 0.5 * (u0_target - u0) ** 2 diff --git a/examples/match_target_distribution/figure_config.py b/examples/match_target_distribution/figure_config.py new file mode 100644 index 0000000..237aa30 --- /dev/null +++ b/examples/match_target_distribution/figure_config.py @@ -0,0 +1,57 @@ +import math + +GR = (1. + math.sqrt(5)) / 2. + +# custom colors +colors = { + 'prior': '#909090', + 'priorweak': '#D0D0D0', + 'V': '#346ebf', + 'Vgray': '#698dbf', + 'Vweak': '#b8caec', + 'Valt': '#b8e5ed', + 'T': '#2faf41', + 'Tgray': '#60af6b', + 'Tweak': '#b7edc0', + 'VT': '#ee1d23', + 'VTalt': '#ed1c8b', + 'VTweak': '#eda6a8', + 'energy': '#111111', + 'MAP': '0.3', + 'sym': '0.8', + 'E': '#1F2041', + 'I': '#FFC857', +} + +# custom fontsizes +fontsize_medium = 12 +fontsize_small = 0.8 * fontsize_medium +fontsize_xsmall = 0.7 * fontsize_medium +fontsize_tiny = 0.6 * fontsize_medium +fontsize_xtiny = 0.5 * fontsize_medium +fontsize_xxtiny = 0.4 * fontsize_medium +fontsize_large = 1.2 * fontsize_medium +fontsize_xlarge = 1.4 * fontsize_medium +fontsize_xxlarge = 1.6 * fontsize_medium + +# custom line widths +lw_medium = 2. +lw_narrow = 1. + +# custom figure sizes +fig_scaling = 0.8 +single_figure = (fig_scaling * 6.4, fig_scaling * 3.96) # using default width and golden ratio +double_figure_horizontal = (1.5 * single_figure[0], single_figure[1]) +triple_figure_horizontal = (2.2 * single_figure[0], single_figure[1]) +double_figure_vertical = (single_figure[0], 1.3 * single_figure[1]) +triple_figure_vertical = (single_figure[0], 1.6 * single_figure[1]) +quad_figure = (6.4, 6.4) + +mpl_style = { + 'axes.spines.top': False, + 'axes.spines.right': False, + 'figure.dpi': 300, ## figure dots per inch + 'xtick.labelsize': fontsize_xsmall, + 'ytick.labelsize': fontsize_xsmall, + 'axes.labelsize': fontsize_medium, +} diff --git a/examples/match_target_distribution/plotting.py b/examples/match_target_distribution/plotting.py new file mode 100644 index 0000000..1c33c1c --- /dev/null +++ b/examples/match_target_distribution/plotting.py @@ -0,0 +1,268 @@ +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pickle + +import figure_config + +mpl.rcParams.update(figure_config.mpl_style) + + +def gaussian_density(x, mu, sigma): + return 1. / np.sqrt(2. * np.pi * sigma ** 2) * np.exp(-(x - mu) ** 2 / (2. * sigma ** 2)) + + +def plot_rates(ax_early, ax_late): + ax_early.set_ylabel(r"$r$ (1/s)", fontsize=figure_config.fontsize_tiny) + ax_early.set_ylim(plot_params["ylim_rate"]) + ax_early.set_xlabel("Time (s)", fontsize=figure_config.fontsize_tiny) + ax_early.set_xlim(xlim_early) + + ax_late.set_yticks([]) + ax_late.spines["left"].set_visible(False) + ax_late.set_ylim(plot_params["ylim_rate"]) + ax_late.set_xticks([1897, 1899]) + + ax_early.plot( + times[indices_times_start], res["r_in"][indices_times_start], color="k", + ) + ax_early.plot( + times[indices_times_start], + res["r_in_noisy"][indices_times_start, 1], + color=figure_config.colors["T"], + ) + ax_early.plot( + times[indices_times_start], + res["r_in_noisy"][indices_times_start, 0], + color=figure_config.colors["V"], + ) + + ax_late.plot( + times[indices_times_late], res["r_in"][indices_times_late], color="k", + ) + ax_late.plot( + times[indices_times_late], + res["r_in_noisy"][indices_times_late, 1], + color=figure_config.colors["T"], + ) + ax_late.plot( + times[indices_times_late], + res["r_in_noisy"][indices_times_late, 0], + color=figure_config.colors["V"], + ) + + +def plot_potential(ax_early, ax_late): + + ax_early.set_ylabel("Membrane potential (mV)", fontsize=figure_config.fontsize_tiny) + ax_early.set_ylim(plot_params["ylim_potential"]) + ax_early.set_xlim(xlim_early) + ax_early.set_xticklabels([]) + + ax_late.set_xticklabels([]) + ax_late.set_yticks([]) + ax_late.spines["left"].set_visible(False) + ax_late.set_ylim(plot_params["ylim_potential"]) + + ax_early.plot( + times[indices_times_early][1:], + res["u0_target_sample"][indices_times_early][1:], + color="k", + ) + ax_early.plot( + times[indices_times_before], + res["u0_sample"][indices_times_before], + color=figure_config.colors["VT"], + ls=':', + ) + ax_early.plot( + times[indices_times_early], + res["u0_sample"][indices_times_early], + color=figure_config.colors["VT"], + ls='--', + ) + + ax_early.axvline(0.0, color="k", ls="--") + ax_early.annotate( + "teacher present", + xy=(0.3, 1.05), + xycoords="axes fraction", + xytext=(0.3, 1.05), + textcoords="axes fraction", + fontsize=figure_config.fontsize_tiny, + ) + + ax_late.plot( + times[indices_times_late], + res["u0_target_sample"][indices_times_late], + color="k", + ) + ax_late.plot( + times[indices_times_late], + res["u0_sample"][indices_times_late], + color=figure_config.colors["VT"], + ) + + +def plot_dist(ax): + ax.set_xlim(0.0, 1.0) + ax.set_ylim(plot_params["ylim_potential"]) + ax.set_yticks([]) + + x = np.linspace(-85.0, -50.0, 500) + + mu_target = np.mean(res["u0_target"][indices_times_late]) + sigma_target = 1.0 / np.sqrt(np.mean(res["g0_target"][indices_times_late])) + # print('teacher mu/sigma', mu_target, sigma_target) + mu_before = np.mean(res["u0"][indices_times_before]) + sigma_before = 1.0 / np.sqrt(np.mean(res["g0"][indices_times_before])) + mu_initial = np.mean(res["u0"][indices_times_early]) + sigma_initial = 1.0 / np.sqrt(np.mean(res["g0"][indices_times_early])) + mu_final = np.mean(res["u0"][indices_times_late]) + sigma_final = 1.0 / np.sqrt(np.mean(res["g0"][indices_times_late])) + # print('final mu/sigma', mu_final, sigma_final) + ax.plot( + gaussian_density(x, mu_target, sigma_target), x, color="k", lw=3.0, + ) + ax.plot( + gaussian_density(x, mu_before, sigma_before,), + x, + color=figure_config.colors["VT"], + lw=2, + ls=":", + ) + ax.plot( + gaussian_density(x, mu_initial, sigma_initial,), + x, + color=figure_config.colors["VT"], + lw=2, + ls="--", + ) + ax.plot( + gaussian_density(x, mu_final, sigma_final,), + x, + color=figure_config.colors["VT"], + lw=2, + ) + + +def plot_cond(ax_early, ax_late, ax_rel): + ax_early.set_ylabel(r"$W_d^\mathsf{E}+W_d^\mathsf{I}$", fontsize=figure_config.fontsize_tiny) + ax_early.set_xlim(xlim_early) + ax_early.set_ylim(plot_params["ylim_cond"]) + ax_early.set_xticklabels([]) + + ax_late.set_xticklabels([]) + ax_late.set_yticks([]) + ax_late.spines["left"].set_visible(False) + ax_late.set_ylim(plot_params["ylim_cond"]) + + ax_rel.set_xlabel('Time (s)', fontsize=figure_config.fontsize_tiny) + ax_rel.set_ylabel('Rel. weight.', fontsize=figure_config.fontsize_tiny) + ax_rel.set_ylim(0., 1.) + + wd = res["wEd"] + res["wId"] + ax_early.plot( + times[indices_times_start], wd[:, 0][indices_times_start], color=figure_config.colors["V"], + ) + ax_early.plot( + times[indices_times_start], wd[:, 1][indices_times_start], color=figure_config.colors["T"], + ) + + ax_late.plot( + times[indices_times_late], wd[:, 0][indices_times_late], color=figure_config.colors["V"], + ) + ax_late.plot( + times[indices_times_late], wd[:, 1][indices_times_late], color=figure_config.colors["T"], + ) + + rel_sigma = ( + 1.0 + / params["sigma_0"] ** 2 + / (1.0 / params["sigma_0"] ** 2 + 1.0 / params["sigma_1"] ** 2) + ) + wd0_rel = wd[:, 0] / (wd[:, 0] + wd[:, 1]) + wd1_rel = wd[:, 1] / (wd[:, 0] + wd[:, 1]) + print(wd0_rel[0], "->", wd0_rel[-1], "<->", rel_sigma) + ax_rel.plot( + times, wd0_rel, ls="-", color=figure_config.colors["V"] + ) + ax_rel.plot( + times, wd1_rel, ls="-", color=figure_config.colors["T"] + ) + + +if __name__ == "__main__": + + sigma = 0.5 + + with open(f"params_{sigma}.pkl", "rb") as f: + params = pickle.load(f) + + plot_params = { + "t_before": (-2.0, 0.0), + "t_early": (0.0, 5.0), + "t_late": None, + "ylim_potential": (-85.0, -52.0), + "ylim_cond": (0.0, 1.5), + "ylim_rate": (0.0, 3.5), + } + + with open(f"res_{sigma}.pkl", "rb") as f: + res = pickle.load(f) + + fig = plt.figure(figsize=(5., 2.5)) + + x_early = 0.12 + width_early = 0.38 + x_late = 0.56 + width_late = 0.21 + + ax_rates_early = fig.add_axes([x_early, 0.15, width_early, 0.10], zorder=1) + ax_rates_late = fig.add_axes([x_late, 0.15, width_late, 0.10], zorder=1) + ax_pot_early = fig.add_axes([x_early, 0.55, width_early, 0.38]) + ax_pot_late = fig.add_axes([x_late, 0.55, width_late, 0.38]) + ax_dist = fig.add_axes([0.81, 0.55, 0.16, 0.4], zorder=-1) + ax_cond_early = fig.add_axes([x_early, 0.35, width_early, 0.1]) + ax_cond_late = fig.add_axes([x_late, 0.35, width_late, 0.1]) + ax_cond_rel = fig.add_axes([0.87, 0.15, 0.125, 0.25]) + + step_size = 1 + window_size = 5 + times = np.arange( + -params["relative_time_silent_teacher"] * params["trials"], + (1.0 - params["relative_time_silent_teacher"]) * params["trials"], + params["recording_interval"], + ) + alpha = 0.65 + + # convert trial indices to time + times *= 100.0 / params["recording_interval"] # 1 trial == 100 ms + times *= 1e-3 # ms to s + + plot_params["t_late"] = (times[-1] - 3.5, times[-1]) + + indices_times_before = np.where( + (plot_params["t_before"][0] <= times) & (times < plot_params["t_before"][1]) + )[0][::step_size] + indices_times_early = np.where( + (plot_params["t_early"][0] <= times) & (times < plot_params["t_early"][1]) + )[0][::step_size] + indices_times_late = np.where( + (plot_params["t_late"][0] <= times) & (times < plot_params["t_late"][1]) + )[0][::step_size] + + indices_times_start = np.hstack([indices_times_before, indices_times_early]) + + xlim_early = (plot_params["t_before"][0], plot_params["t_early"][1]) + xlim_late = plot_params["t_late"] + + plot_rates(ax_rates_early, ax_rates_late) + plot_potential(ax_pot_early, ax_pot_late) + plot_dist(ax_dist) + plot_cond(ax_cond_early, ax_cond_late, ax_cond_rel) + + figname = "learning.{ext}" + print(f"creating {figname}") + plt.savefig(figname.format(ext="svg")) + plt.savefig(figname.format(ext="pdf"), dpi=300) diff --git a/examples/match_target_distribution/training.py b/examples/match_target_distribution/training.py new file mode 100644 index 0000000..e888fe4 --- /dev/null +++ b/examples/match_target_distribution/training.py @@ -0,0 +1,168 @@ +import numpy as np +import os +import pickle +import torch + +from dopp import FeedForwardCell + + +def numpyfy_all_torch_tensors_in_dict(d): + for key in d: + try: + d[key] = d[key].detach().numpy() + except AttributeError: + pass + return d + + +def train(params, model, model_target, *, manual_grad=False): + + if not manual_grad: + optimizer = torch.optim.SGD(model.parameters(), lr=params["lr"]) + + res = {} + res["g0"] = torch.empty(params["trials"] // params["recording_interval"]) + res["u0"] = torch.empty(params["trials"] // params["recording_interval"]) + res["u0_sample"] = torch.empty(params["trials"] // params["recording_interval"]) + res["g0_target"] = torch.empty(params["trials"] // params["recording_interval"]) + res["u0_target"] = torch.empty(params["trials"] // params["recording_interval"]) + res["u0_target_sample"] = torch.empty( + params["trials"] // params["recording_interval"] + ) + res["r_in"] = torch.empty(params["trials"] // params["recording_interval"], 2) + res["r_in_noisy"] = torch.empty(params["trials"] // params["recording_interval"], 2) + res["wEd"] = torch.empty( + params["trials"] // params["recording_interval"], model.n_dendrites + ) + res["wId"] = torch.empty( + params["trials"] // params["recording_interval"], model.n_dendrites + ) + + torch.save( + model.state_dict(), os.path.join(params["save_path"], f"checkpoint_0.pkl") + ) + + batch_size = 1 + r_in = torch.ones(batch_size, 2, dtype=torch.double) + for i in range(params["trials"]): + + if i % 1000 == 0: + print(f'{i + 1} / {params["trials"]}', end="\r") + + r_in.zero_() + r_in += torch.empty(batch_size, 1, dtype=torch.double).normal_( + params["r_mu"], params["r_sigma"] + ) + r_in[r_in <= 0.0] = 0.001 + assert torch.all(r_in > 0.0) + + r_in_noisy = r_in.clone() + r_in_noisy[:, 0] += torch.empty(batch_size).normal_(0.0, params["sigma_0"]) + r_in_noisy[:, 1] += torch.empty(batch_size).normal_(0.0, params["sigma_1"]) + r_in_noisy[r_in_noisy <= 0.0] = 0.001 + assert torch.all(r_in_noisy > 0.0) + + u_in = model.f_inv(r_in_noisy) + g0, u0 = model(u_in) + + if i >= params["relative_time_silent_teacher"] * params["trials"]: + u_in_target = model.f_inv(r_in)[:, 0].reshape(batch_size, 1) + g0_target, u0_target = model_target(u_in_target) + u0_target_sample = model_target.sample(g0_target, u0_target) + + model.zero_grad() + if manual_grad: + model.compute_grad_manual_target(u0_target_sample, g0, u0, u_in) + model.apply_grad_weights(params["lr"]) + else: + energy = model.energy_target(u0_target_sample, g0, u0) + energy.sum().backward() + optimizer.step() + + if i % params["recording_interval"] == 0: + idx = i // params["recording_interval"] + + res["r_in"][idx] = r_in[0] + res["r_in_noisy"][idx] = r_in_noisy[0] + if i > params["relative_time_silent_teacher"] * params["trials"]: + res["g0_target"][idx] = g0_target.clone() + res["u0_target"][idx] = u0_target.clone() + res["u0_target_sample"][idx] = u0_target_sample + else: + res["g0_target"][idx] = 0.0 + res["u0_target"][idx] = -70.0 + res["u0_target_sample"][idx] = -70.0 + + g0, u0 = model(u_in) + res["g0"][idx] = g0 + res["u0"][idx] = u0 + res["u0_sample"][idx] = model.sample(g0, u0)[0] + + for d in range(model.n_dendrites): + res["wEd"][idx, d] = model.weightsE(d).clone() + res["wId"][idx, d] = model.weightsI(d).clone() + + if (i + 1) % params["check_point_interval"] == 0: + torch.save( + model.state_dict(), + os.path.join(params["save_path"], f"checkpoint_{i + 1}.pkl"), + ) + + print() + + res = numpyfy_all_torch_tensors_in_dict(res) + + return res + + +if __name__ == "__main__": + + params = { + "seed": 1234, + # "trials": 2_000_000, + "trials": 200_000, + "relative_time_silent_teacher": 0.05, + "recording_interval": 100, + "check_point_interval": 10_000, + "lr": 0.025e-2, + "n_dendrites": 2, + "r_mu": 1.2, + "r_sigma": 0.5 * 1e0, + "sigma_0": 0.05 * 7.5e-1, + "sigma_1": 0.25 * 7.5e-1, + "save_path": "./", + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + scale_factor = 5.0 + + gL = scale_factor * 0.05 + + model_target = FeedForwardCell([1], 1) + model_target.gc = None + model_target.gL0 = gL + model_target.gLd = torch.ones(1) * 0.1 * gL + model_target.scale_weightsE(scale_factor * 2.5) + model_target.scale_weightsI(scale_factor * 3.5) + + model = FeedForwardCell([1, 1], 1) + model.gc = None + model.gL0 = gL + model.gLd = torch.ones(params["n_dendrites"]) * 0.1 * gL + model.scale_weightsE(0.8 * scale_factor * 2.5) + model.scale_weightsI(0.8 * scale_factor * 3.5) + + # start both modalities with comparable initial weights + model.set_weightsE(1, 1.01 * model.weightsE(0)) + model.set_weightsI(1, 0.99 * model.weightsI(0)) + + torch.manual_seed(params["seed"]) + res = train(params, model, model_target, manual_grad=True) + + with open(os.path.join(params["save_path"], f"params_{params['r_sigma']}.pkl"), "wb") as f: + pickle.dump(params, f) + + with open(os.path.join(params["save_path"], f"res_{params['r_sigma']}.pkl"), "wb") as f: + pickle.dump(res, f) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..22960d7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +scipy ~= 1.4.1 +matplotlib ~= 3.3.3 +numpy ~= 1.18.4 +torch ~= 1.7.1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..57bcd5c --- /dev/null +++ b/setup.py @@ -0,0 +1,59 @@ +# encoding: utf8 +import re + +from setuptools import setup + + +def _cut_version_number_from_requirement(req): + return req.split()[0] + + +def read_metadata(metadata_str): + """ + Find __"meta"__ in init file. + """ + with open("./dopp/__init__.py", "r") as f: + meta_match = re.search(fr"^__{metadata_str}__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) + if meta_match: + return meta_match.group(1) + raise RuntimeError(f"Unable to find __{metadata_str}__ string.") + + +def read_requirements(): + requirements = [] + with open("./requirements.txt") as f: + for req in f: + req = req.replace("\n", " ") + requirements.append(req) + return requirements + +def read_long_description(): + with open("README.md", "r") as f: + descr = f.read() + return descr + + +setup( + name="dopp", + version=read_metadata("version"), + maintainer=read_metadata("maintainer"), + author=read_metadata("author"), + description=(read_metadata("description")), + license=read_metadata("license"), + keywords=("simulation", "neuronal networks", "dendritic computation", "synaptic plasticity"), + url=read_metadata("url"), + python_requires=">=3.6, <4", + install_requires=read_requirements(), + packages=["dopp"], + long_description=read_long_description(), + long_description_content_type="text/x-rst", + classifiers=[ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Topic :: Scientific/Engineering", + ], +) diff --git a/test/test_abstract_convex_cell.py b/test/test_abstract_convex_cell.py new file mode 100644 index 0000000..245c629 --- /dev/null +++ b/test/test_abstract_convex_cell.py @@ -0,0 +1,163 @@ +import numpy as np +import pytest +import torch + +from dopp.abstract_convex_cell import AbstractConvexCell + +SEED = np.random.randint(2 ** 32) + + +def test_weights_from_omega_omega_from_weights(): + torch.manual_seed(SEED) + + model = AbstractConvexCell([1], 1) + + assert model.weightsE(0).item() == pytest.approx( + model.weights_from_omega(model.omega_from_weights(model.weightsE(0))).item() + ), SEED + + +def test_set_weights(): + torch.manual_seed(SEED) + + w0 = torch.Tensor([[1.618]]) + w1 = torch.Tensor([[3.141]]) + + model = AbstractConvexCell([1, 1], 1) + model.set_weightsE(0, w0) + model.set_weightsE(1, w1) + model.set_weightsI(0, w0) + model.set_weightsI(1, w1) + + assert model.weightsE(0).item() == pytest.approx(w0.item()) + assert model.weightsE(1).item() == pytest.approx(w1.item()) + assert model.weightsI(0).item() == pytest.approx(w0.item()) + assert model.weightsI(1).item() == pytest.approx(w1.item()) + + +def test_scale_weightsE(): + torch.manual_seed(SEED) + + wE = torch.Tensor([[1.618]]) + wI = torch.Tensor([[3.141]]) + scale = 0.1 + + model = AbstractConvexCell([1], 1) + model.set_weightsE(0, wE) + model.set_weightsI(0, wI) + model.scale_weightsE(scale) + + assert model.weightsE(0).item() == pytest.approx(scale * wE.item()) + assert model.weightsI(0).item() == pytest.approx(wI.item()) + + +def test_scale_weightsI(): + torch.manual_seed(SEED) + + wE = torch.Tensor([[1.618]]) + wI = torch.Tensor([[3.141]]) + scale = 0.1 + + model = AbstractConvexCell([1], 1) + model.set_weightsE(0, wE) + model.set_weightsI(0, wI) + model.scale_weightsI(scale) + + assert model.weightsE(0).item() == pytest.approx(wE.item()) + assert model.weightsI(0).item() == pytest.approx(scale * wI.item()) + + +def test_dendritic_input(): + torch.manual_seed(SEED) + + r_in = torch.Tensor(1, 5).normal_(std=0.1) + 5.0 + + model = AbstractConvexCell([3, 2], 1) + + u_in = model.f_inv(r_in) + assert model.dendritic_input(u_in, 0)[0].tolist() == pytest.approx( + r_in[0, :3].tolist() + ) + assert model.dendritic_input(u_in, 1)[0].tolist() == pytest.approx( + r_in[0, 3 : 3 + 2].tolist() + ) + + +def test_copy_omegaE_omegaI_from(): + torch.manual_seed(SEED) + + model = AbstractConvexCell([1], 1) + model_other = AbstractConvexCell([1], 1) + + assert model.weightsE(0).item() != pytest.approx(model_other.weightsE(0).item()) + assert model.weightsI(0).item() != pytest.approx(model_other.weightsI(0).item()) + + model.copy_omegaE_omegaI_from(model_other) + + assert model.weightsE(0).item() == pytest.approx(model_other.weightsE(0).item()) + assert model.weightsI(0).item() == pytest.approx(model_other.weightsI(0).item()) + + +def test_gff0_and_uff0(): + torch.manual_seed(SEED) + + gL0 = 0.333 + EL = -72.0 + + model = AbstractConvexCell([1], 1) + model.gL0 = gL0 + model.EL = EL + gff0, uff0 = model.compute_gff0_and_uff0() + + assert gff0 == pytest.approx(gL0) + assert uff0 == pytest.approx(EL) + + +def test_gffd_and_uffd_w_input(): + torch.manual_seed(SEED) + + gLd = 0.333 + EL = -68.0 + EE = -10.0 + EI = -86.0 + wE = torch.Tensor([[[1.618]]]) + wI = torch.Tensor([[[3.141]]]) + r_in = torch.Tensor(1, 1).normal_(std=0.1) + 5.0 + + gffd_expected = gLd + wE * r_in + wI * r_in + uffd_expected = (gLd * EL + wE * r_in * EE + wI * r_in * EI) / gffd_expected + + model = AbstractConvexCell([1], 1) + model.gLd[0] = gLd + model.EL = EL + model.EE = EE + model.EI = EI + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + + u_in = model.f_inv(r_in) + gffd, uffd = model.compute_gffd_and_uffd(u_in) + + assert gffd[0][0].tolist() == pytest.approx(gffd_expected[0][0].tolist()) + assert uffd[0][0].tolist() == pytest.approx(uffd_expected[0][0].tolist()) + + +def test_input_scale(): + torch.manual_seed(SEED) + + scale_0 = 0.15 + scale_1 = 1.05 + + r_in = torch.Tensor(1, 5).normal_(std=0.1) + 5.0 + + model = AbstractConvexCell([3, 2], 1) + model.set_input_scale(0, scale_0) + model.set_input_scale(1, scale_1) + + u_in = model.f_inv(r_in) + assert model.dendritic_input(u_in, 0)[0].tolist() == pytest.approx( + (scale_0 * r_in[0, :3]).tolist() + ) + assert model.dendritic_input(u_in, 1)[0].tolist() == pytest.approx( + (scale_1 * r_in[0, 3 : 3 + 2]).tolist() + ) diff --git a/test/test_dynamic_feedforward_cell.py b/test/test_dynamic_feedforward_cell.py new file mode 100644 index 0000000..24250bc --- /dev/null +++ b/test/test_dynamic_feedforward_cell.py @@ -0,0 +1,81 @@ +import numpy as np +import pytest +import torch + +SEED = 1234 + + +from dopp import DynamicFeedForwardCell, FeedForwardCell + + +def test_single_dendrite_single_input_single_output_single_trial(): + """ + Test that dynamic feedforward cell converges to stationary solution. + """ + + params = { + "seed": SEED, + "in_features": [1], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2]]]) + wI = torch.Tensor([[[0.7]]]) + + model_expected = FeedForwardCell(params["in_features"], params["out_features"]) + u_in = torch.Tensor([[model_expected.EL + 10.0]]) + + model_expected.set_weightsE(0, wE[0]) + model_expected.set_weightsI(0, wI[0]) + g0_expected, u0_expected = model_expected(u_in) + + model = DynamicFeedForwardCell(params["in_features"], params["out_features"]) + + # model solution + n_steps = 500 + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + with torch.no_grad(): + for _ in range(n_steps): + g0, u0 = model(u_in) + + assert g0.item() == pytest.approx(g0_expected.item()) + assert u0.item() == pytest.approx(u0_expected.item()) + + +def test_initialize_somatic_potential(): + params = { + "seed": SEED, + "in_features": [1], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2]]]) + wI = torch.Tensor([[[0.7]]]) + + model = DynamicFeedForwardCell(params["in_features"], params["out_features"]) + model.gL0 *= 20.0 + u_init = torch.ones(params['out_features']) * (model.EL + 2.75) + model.initialize_somatic_potential(u_init) + + model.set_input_scale(0, 0.0) # no input, so cell should stay at initial state + u_in = torch.Tensor([[model.EL + 10.0]]) + + # model solution + n_steps = 2000 + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + with torch.no_grad(): + for step in range(n_steps): + g0, u0 = model(u_in) + assert u0.item() < u_init.item() # decay to leak potential + if step == 0: + assert u0.item() == pytest.approx(u_init.item(), rel=0.001) # still close to init + + assert u0.item() == pytest.approx(model.EL, rel=0.001) # decay to leak potential diff --git a/test/test_feedforward_cell.py b/test/test_feedforward_cell.py new file mode 100644 index 0000000..57dc888 --- /dev/null +++ b/test/test_feedforward_cell.py @@ -0,0 +1,715 @@ +import math +import numpy as np +import pytest +import torch + +SEED = 1234 + +from dopp import FeedForwardCell + + +def test_single_dendrite_single_input_single_output_single_trial(): + + params = { + "seed": SEED, + "in_features": [1], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2]]]) + wI = torch.Tensor([[[0.7]]]) + + model = FeedForwardCell(params["in_features"], params["out_features"]) + u_in = torch.Tensor([[model.EL + 10.0]]) + + # hand-crafted solution + r_in = model.f(u_in) + gffd_target = model.gL0 + torch.mm(r_in, wE[0]) + torch.mm(r_in, wI[0]) + uffd_target = ( + model.gL0 * model.EL + + torch.mm(r_in, wE[0]) * model.EE + + torch.mm(r_in, wI[0]) * model.EI + ) / gffd_target + g0_target = model.gL0 + model.gc * gffd_target / (gffd_target + model.gc) + u0_target = ( + model.gL0 * model.EL + + model.gc * gffd_target / (gffd_target + model.gc) * uffd_target + ) / g0_target + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + gffd, uffd = model.compute_gffd_and_uffd(u_in) + g0, u0 = model(u_in) + + assert gffd.shape == (1, params["out_features"], len(params["in_features"])) + assert uffd.shape == (1, params["out_features"], len(params["in_features"])) + assert gffd_target.shape == (1, params["out_features"]) + assert uffd_target.shape == (1, params["out_features"]) + assert gffd[0, 0, 0].tolist() == pytest.approx(gffd_target[0, 0].tolist()) + assert uffd[0, 0, 0].tolist() == pytest.approx(uffd_target[0, 0].tolist()) + + assert g0.shape == (1, params["out_features"]) + assert u0.shape == (1, params["out_features"]) + assert g0.shape == g0_target.shape + assert u0.shape == u0_target.shape + assert g0[0, 0].tolist() == pytest.approx(g0_target[0, 0].tolist()) + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + # test sampling + lambda_e = 1.67 + model.lambda_e = lambda_e + n_samples = 5000 + u0_sample = torch.empty(n_samples, model.out_features) + for i in range(n_samples): + u0_sample[i] = model.sample(g0, u0) + + assert torch.mean(u0_sample).item() == pytest.approx(u0_target.item(), rel=0.0001) + assert torch.std(u0_sample).item() == pytest.approx(torch.sqrt(lambda_e / g0_target).item(), rel=0.01) + + # test energy + u0_target = u0.clone() + 5.0 + g0, u0 = model(u_in) + p_expected = torch.sqrt(g0 / (2 * math.pi * model.lambda_e)) * torch.exp(-g0 / (2. * model.lambda_e) * (u0_target - u0) ** 2) + assert model.energy_target(u0_target, g0, u0).item() == pytest.approx(-torch.log(p_expected).item()) + + # test loss + assert model.loss_target(u0_target, g0, u0).item() == pytest.approx(0.5 * (u0_target - u0).item() ** 2) + + +def test_single_dendrite_two_inputs_single_output_single_trial(): + + params = { + "seed": SEED, + "in_features": [2], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2], [0.5]]]) + wI = torch.Tensor([[[0.7], [0.4]]]) + + model = FeedForwardCell(params["in_features"], params["out_features"]) + u_in = torch.Tensor([[model.EL, model.EL + 5.0]]) + + # hand-crafted solution + r_in = model.f(u_in) + gffd_target = model.gLd + torch.mm(r_in, wE[0]) + torch.mm(r_in, wI[0]) + uffd_target = ( + model.gLd * model.EL + + torch.mm(r_in, wE[0]) * model.EE + + torch.mm(r_in, wI[0]) * model.EI + ) / gffd_target + g0_target = model.gL0 + model.gc * gffd_target / (gffd_target + model.gc) + u0_target = ( + model.gL0 * model.EL + + model.gc * gffd_target / (gffd_target + model.gc) * uffd_target + ) / g0_target + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + gffd, uffd = model.compute_gffd_and_uffd(u_in) + g0, u0 = model(u_in) + + assert gffd.shape == (1, 1, 1) + assert uffd.shape == (1, 1, 1) + assert gffd_target.shape == (1, 1) + assert uffd_target.shape == (1, 1) + assert gffd[0, 0, 0].tolist() == pytest.approx(gffd_target[0, 0].tolist()) + assert uffd[0, 0, 0].tolist() == pytest.approx(uffd_target[0, 0].tolist()) + + assert g0.shape == (1, 1) + assert u0.shape == (1, 1) + assert g0.shape == g0_target.shape + assert u0.shape == u0_target.shape + assert g0[0, 0].tolist() == pytest.approx(g0_target[0, 0].tolist()) + + +def test_two_dendrites_two_inputs_single_output_single_trial(): + + params = { + "seed": SEED, + "in_features": [1, 1], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2]], [[0.5]]]) + wI = torch.Tensor([[[0.7]], [[0.4]]]) + + model = FeedForwardCell(params["in_features"], params["out_features"]) + u_in = torch.Tensor([[model.EL, model.EL + 5.0]]) + + # hand-crafted solution + r_in = model.f(u_in).reshape(1, 2, 1) + gffd_target = torch.Tensor( + [ + model.gLd[d] + torch.mm(r_in[:, d], wE[d]) + torch.mm(r_in[:, d], wI[d]) + for d in range(2) + ] + ).reshape(1, 1, 2) + uffd_target = ( + torch.Tensor( + [ + model.gLd[d] * model.EL + + torch.mm(r_in[:, d], wE[d]) * model.EE + + torch.mm(r_in[:, d], wI[d] * model.EI) + for d in range(2) + ] + ).reshape(1, 1, 2) + / gffd_target + ) + g0_target = model.gL0 + torch.sum( + model.gc * gffd_target / (gffd_target + model.gc), dim=2 + ) + u0_target = ( + model.gL0 * model.EL + + torch.sum( + model.gc * gffd_target / (gffd_target + model.gc) * uffd_target, dim=2 + ) + ) / g0_target + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + gffd, uffd = model.compute_gffd_and_uffd(u_in) + g0, u0 = model(u_in) + + assert gffd.shape == (1, 1, 2) + assert uffd.shape == (1, 1, 2) + assert gffd_target.shape == (1, 1, 2) + assert uffd_target.shape == (1, 1, 2) + assert gffd[0, 0, 0].tolist() == pytest.approx(gffd_target[0, 0, 0].tolist()) + assert gffd[0, 0, 1].tolist() == pytest.approx(gffd_target[0, 0, 1].tolist()) + assert uffd[0, 0, 0].tolist() == pytest.approx(uffd_target[0, 0, 0].tolist()) + assert uffd[0, 0, 1].tolist() == pytest.approx(uffd_target[0, 0, 1].tolist()) + + assert g0.shape == (1, 1) + assert u0.shape == (1, 1) + assert g0.shape == g0_target.shape + assert u0.shape == u0_target.shape + assert g0[0, 0].tolist() == pytest.approx(g0_target[0, 0].tolist()) + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_two_dendrites_two_inputs_single_output_single_trial_heterogeneous_coupling_conductance(): + + params = { + "seed": SEED, + "in_features": [1, 1], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2]], [[0.5]]]) + wI = torch.Tensor([[[0.7]], [[0.4]]]) + gc = torch.Tensor([[[10.0, 5.0]]]) + + model = FeedForwardCell(params["in_features"], params["out_features"]) + u_in = torch.Tensor([[model.EL, model.EL + 5.0]]) + + # hand-crafted solution + r_in = model.f(u_in).reshape(1, 2, 1) + gffd_target = torch.Tensor( + [ + model.gLd[d] + torch.mm(r_in[:, d], wE[d]) + torch.mm(r_in[:, d], wI[d]) + for d in range(2) + ] + ).reshape(1, 1, 2) + uffd_target = ( + torch.Tensor( + [ + model.gLd[d] * model.EL + + torch.mm(r_in[:, d], wE[d]) * model.EE + + torch.mm(r_in[:, d], wI[d] * model.EI) + for d in range(2) + ] + ).reshape(1, 1, 2) + / gffd_target + ) + g0_target = model.gL0 + torch.sum(gc * gffd_target / (gffd_target + gc), dim=2) + u0_target = ( + model.gL0 * model.EL + + torch.sum(gc * gffd_target / (gffd_target + gc) * uffd_target, dim=2) + ) / g0_target + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + model.gc = gc + gffd, uffd = model.compute_gffd_and_uffd(u_in) + g0, u0 = model(u_in) + + assert gffd.shape == (1, 1, 2) + assert uffd.shape == (1, 1, 2) + assert gffd_target.shape == (1, 1, 2) + assert uffd_target.shape == (1, 1, 2) + assert gffd[0, 0, 0].tolist() == pytest.approx(gffd_target[0, 0, 0].tolist()) + assert gffd[0, 0, 1].tolist() == pytest.approx(gffd_target[0, 0, 1].tolist()) + assert uffd[0, 0, 0].tolist() == pytest.approx(uffd_target[0, 0, 0].tolist()) + assert uffd[0, 0, 1].tolist() == pytest.approx(uffd_target[0, 0, 1].tolist()) + + assert g0.shape == (1, 1) + assert u0.shape == (1, 1) + assert g0.shape == g0_target.shape + assert u0.shape == u0_target.shape + assert g0[0, 0].tolist() == pytest.approx(g0_target[0, 0].tolist()) + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_two_dendrites_four_inputs_single_output_single_trial(): + + params = { + "seed": SEED, + "in_features": [2, 2], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = [torch.Tensor([[1.2, 1.1]]).t(), torch.Tensor([[0.5, 0.8]]).t()] + wI = [torch.Tensor([[0.7, 0.2]]).t(), torch.Tensor([[0.4, 0.1]]).t()] + + model = FeedForwardCell(params["in_features"], params["out_features"]) + u_in = torch.Tensor([[model.EL, model.EL - 5.0, model.EL + 5.0, model.EL + 10.0]]) + + # handcrafted solution + r_in = model.f(u_in).reshape(1, 2, 2) + gffd_target = torch.Tensor( + [ + model.gLd[d] + torch.mm(r_in[:, d], wE[d]) + torch.mm(r_in[:, d], wI[d]) + for d in range(2) + ] + ).reshape(1, 1, 2) + uffd_target = ( + torch.Tensor( + [ + model.gLd[d] * model.EL + + torch.mm(r_in[:, d], wE[d]) * model.EE + + torch.mm(r_in[:, d], wI[d] * model.EI) + for d in range(2) + ] + ).reshape(1, 1, 2) + / gffd_target + ) + g0_target = model.gL0 + torch.sum( + model.gc * gffd_target / (gffd_target + model.gc), dim=2 + ) + u0_target = ( + model.gL0 * model.EL + + torch.sum( + model.gc * gffd_target / (gffd_target + model.gc) * uffd_target, dim=2 + ) + ) / g0_target + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + gffd, uffd = model.compute_gffd_and_uffd(u_in) + g0, u0 = model(u_in) + + assert gffd.shape == (1, 1, 2) + assert uffd.shape == (1, 1, 2) + assert gffd_target.shape == (1, 1, 2) + assert uffd_target.shape == (1, 1, 2) + assert gffd[0, 0, 0].tolist() == pytest.approx(gffd_target[0, 0, 0].tolist()) + assert gffd[0, 0, 1].tolist() == pytest.approx(gffd_target[0, 0, 1].tolist()) + assert uffd[0, 0, 0].tolist() == pytest.approx(uffd_target[0, 0, 0].tolist()) + assert uffd[0, 0, 1].tolist() == pytest.approx(uffd_target[0, 0, 1].tolist()) + + assert g0.shape == (1, 1) + assert u0.shape == (1, 1) + assert g0.shape == g0_target.shape + assert u0.shape == u0_target.shape + assert g0[0, 0].tolist() == pytest.approx(g0_target[0, 0].tolist()) + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_two_dendrites_four_inputs_three_outputs_single_trial(): + + params = { + "seed": SEED, + "in_features": [2, 2], + "out_features": 3, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = [ + torch.Tensor([[1.2, 1.1], [1.2, 1.2], [1.2, 1.3]]).t(), + torch.Tensor([[0.5, 0.8], [0.4, 0.8], [0.3, 0.8]]).t(), + ] + wI = [ + torch.Tensor([[0.7, 0.2], [0.7, 0.2], [0.7, 0.2]]).t(), + torch.Tensor([[0.4, 0.1], [0.4, 0.1], [0.4, 0.1]]).t(), + ] + + model = FeedForwardCell(params["in_features"], params["out_features"]) + u_in = torch.Tensor([[model.EL, model.EL - 5.0, model.EL + 5.0, model.EL + 10.0]]) + + # hand-crafted solution + r_in = model.f(u_in).reshape(1, 2, 2) + gffd_target = torch.empty(1, params["out_features"], len(params["in_features"])) + uffd_target = torch.empty(1, params["out_features"], len(params["in_features"])) + for d in range(2): + gffd_target[:, :, d] = ( + model.gLd[d] + torch.mm(r_in[:, d], wE[d]) + torch.mm(r_in[:, d], wI[d]) + ) + uffd_target[:, :, d] = ( + model.gLd[d] * model.EL + + torch.mm(r_in[:, d], wE[d]) * model.EE + + torch.mm(r_in[:, d], wI[d] * model.EI) + ) / gffd_target[:, :, d] + g0_target = model.gL0 + torch.sum( + model.gc * gffd_target / (gffd_target + model.gc), dim=2 + ) + u0_target = ( + model.gL0 * model.EL + + torch.sum( + model.gc * gffd_target / (gffd_target + model.gc) * uffd_target, dim=2 + ) + ) / g0_target + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + gffd, uffd = model.compute_gffd_and_uffd(u_in) + g0, u0 = model(u_in) + + assert gffd.shape == (1, params["out_features"], len(params["in_features"])) + assert uffd.shape == (1, params["out_features"], len(params["in_features"])) + assert gffd_target.shape == (1, params["out_features"], len(params["in_features"])) + assert uffd_target.shape == (1, params["out_features"], len(params["in_features"])) + assert g0.shape == (1, params["out_features"]) + assert u0.shape == (1, params["out_features"]) + assert g0.shape == g0_target.shape + assert u0.shape == u0_target.shape + for n in range(params["out_features"]): + assert gffd[0, n, 0].tolist() == pytest.approx(gffd_target[0, n, 0].tolist()) + assert gffd[0, n, 1].tolist() == pytest.approx(gffd_target[0, n, 1].tolist()) + assert uffd[0, n, 0].tolist() == pytest.approx(uffd_target[0, n, 0].tolist()) + assert uffd[0, n, 1].tolist() == pytest.approx(uffd_target[0, n, 1].tolist()) + + assert g0[0, 0].tolist() == pytest.approx(g0_target[0, 0].tolist()) + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_two_dendrites_four_inputs_three_outputs_multiple_trials(): + + params = { + "seed": SEED, + "in_features": [2, 2], + "out_features": 3, + "batch_size": 4, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = [ + torch.Tensor([[1.2, 1.1], [1.2, 1.2], [1.2, 1.3]]).t(), + torch.Tensor([[0.5, 0.8], [0.4, 0.8], [0.3, 0.8]]).t(), + ] + wI = [ + torch.Tensor([[0.7, 0.2], [0.7, 0.2], [0.7, 0.2]]).t(), + torch.Tensor([[0.4, 0.1], [0.4, 0.1], [0.4, 0.1]]).t(), + ] + + model = FeedForwardCell(params["in_features"], params["out_features"]) + u_in = torch.Tensor( + [ + [model.EL, model.EL - 5.0, model.EL + 5.0, model.EL + 10.0], + [model.EL, model.EL - 5.0, model.EL + 5.0, model.EL + 15.0], + [model.EL, model.EL - 5.0, model.EL - 5.0, model.EL + 10.0], + [model.EL, model.EL + 5.0, model.EL + 5.0, model.EL + 10.0], + ] + ) + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + gffd, uffd = model.compute_gffd_and_uffd(u_in) + g0, u0 = model(u_in) + + assert gffd.shape == ( + params["batch_size"], + params["out_features"], + len(params["in_features"]), + ) + assert uffd.shape == ( + params["batch_size"], + params["out_features"], + len(params["in_features"]), + ) + assert g0.shape == (params["batch_size"], params["out_features"]) + assert u0.shape == (params["batch_size"], params["out_features"]) + for trial in range(params["batch_size"]): + + # hand-crafted solution + r_in = model.f(u_in[trial]).reshape(1, 2, 2) + gffd_target = torch.empty(1, params["out_features"], len(params["in_features"])) + uffd_target = torch.empty(1, params["out_features"], len(params["in_features"])) + for d in range(2): + gffd_target[:, :, d] = ( + model.gLd[d] + torch.mm(r_in[:, d], wE[d]) + torch.mm(r_in[:, d], wI[d]) + ) + uffd_target[:, :, d] = ( + model.gLd[d] * model.EL + + torch.mm(r_in[:, d], wE[d]) * model.EE + + torch.mm(r_in[:, d], wI[d] * model.EI) + ) / gffd_target[:, :, d] + g0_target = model.gL0 + torch.sum( + model.gc * gffd_target / (gffd_target + model.gc), dim=2 + ) + u0_target = ( + model.gL0 * model.EL + + torch.sum( + model.gc * gffd_target / (gffd_target + model.gc) * uffd_target, dim=2 + ) + ) / g0_target + + assert gffd_target.shape == ( + 1, + params["out_features"], + len(params["in_features"]), + ) + assert uffd_target.shape == ( + 1, + params["out_features"], + len(params["in_features"]), + ) + assert g0_target.shape == (1, params["out_features"]) + assert u0_target.shape == (1, params["out_features"]) + + for n in range(params["out_features"]): + assert gffd[trial, n, 0].tolist() == pytest.approx( + gffd_target[0, n, 0].tolist() + ) + assert gffd[trial, n, 1].tolist() == pytest.approx( + gffd_target[0, n, 1].tolist() + ) + assert uffd[trial, n, 0].tolist() == pytest.approx( + uffd_target[0, n, 0].tolist() + ) + assert uffd[trial, n, 1].tolist() == pytest.approx( + uffd_target[0, n, 1].tolist() + ) + + assert g0[trial, 0].tolist() == pytest.approx(g0_target[0, 0].tolist()) + assert u0[trial, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_reduction_to_point_neuron(): + """check that results are independent of dendritic layout for large transfer + conductances and small leak conductances""" + + params = { + "seed": SEED, + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[1.2, 0.5]]).t() + wI = torch.Tensor([[0.7, 0.4]]).t() + + model_single = FeedForwardCell([2], params["out_features"]) + model_single.gL0 = 0.0 + model_single.gLd = torch.Tensor([0.0]) + model_single.set_weightsE(0, wE) + model_single.set_weightsI(0, wI) + + model_double = FeedForwardCell([1, 1], params["out_features"]) + model_double.gL0 = 0.0 + model_double.gLd = torch.Tensor([0.0, 0.0]) + model_double.set_weightsE(0, wE.reshape(2, 1, 1)[0]) + model_double.set_weightsI(0, wI.reshape(2, 1, 1)[0]) + model_double.set_weightsE(1, wE.reshape(2, 1, 1)[1]) + model_double.set_weightsI(1, wI.reshape(2, 1, 1)[1]) + + u_in = torch.Tensor([[model_single.EL, model_single.EL + 10.0]]) + + # potentials should be different for small coupling conductances + model_single.gc = 1.0 + model_double.gc = 1.0 + g0_single, u0_single = model_single(u_in) + g0_double, u0_double = model_double(u_in) + + assert u0_single[0, 0].tolist() != pytest.approx(u0_double[0, 0].tolist()) + + # potentials should be identical for large coupling conductances + model_single.gc = 1_000_000.0 + model_double.gc = 1_000_000.0 + g0_single, u0_single = model_single(u_in) + g0_double, u0_double = model_double(u_in) + + assert u0_single[0, 0].tolist() == pytest.approx(u0_double[0, 0].tolist()) + + +def test_grad_is_identical_to_backprop_grad_point_neuron(): + params = { + "seed": SEED, + "in_features": [2, 2], + "out_features": 3, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + model = FeedForwardCell(params["in_features"], params["out_features"]) + model.lambda_e = 1.67 + + model.gc = None + u0_target = torch.DoubleTensor( + [ + [model.EL, model.EL - 5.0, model.EL + 5.0], + [model.EL, model.EL + 5.0, model.EL - 5.0], + ] + ) + + u_in = torch.DoubleTensor( + [ + [model.EL, model.EL + 5.0, model.EL - 5.0, model.EL + 10.0], + [model.EL, model.EL - 5.0, model.EL + 5.0, model.EL - 10.0], + ] + ) + + # calculate gradient manually + with torch.no_grad(): + g0_manual, u0_manual = model(u_in) + model.zero_grad() + model.compute_grad_manual_target(u0_target, g0_manual, u0_manual, u_in) + omegaE_0_grad_manual = model._omegaE[0]._grad.clone() + omegaI_0_grad_manual = model._omegaI[0]._grad.clone() + omegaE_1_grad_manual = model._omegaE[1]._grad.clone() + omegaI_1_grad_manual = model._omegaI[1]._grad.clone() + + # calculate gradient with autograd + g0_bp, u0_bp = model(u_in) + model.zero_grad() + model.energy_target(u0_target, g0_bp, u0_bp).sum().backward() + omegaE_0_grad_bp = model._omegaE[0]._grad.clone() + omegaI_0_grad_bp = model._omegaI[0]._grad.clone() + omegaE_1_grad_bp = model._omegaE[1]._grad.clone() + omegaI_1_grad_bp = model._omegaI[1]._grad.clone() + + assert g0_manual[0, 0].tolist() == pytest.approx(g0_bp[0, 0].tolist()) + assert u0_manual[0, 0].tolist() == pytest.approx(u0_bp[0, 0].tolist()) + + for n in range(params["out_features"]): + assert omegaE_0_grad_manual[:, n].tolist() == pytest.approx( + omegaE_0_grad_bp[:, n].tolist() + ) + assert omegaI_0_grad_manual[:, n].tolist() == pytest.approx( + omegaI_0_grad_bp[:, n].tolist() + ) + assert omegaE_1_grad_manual[:, n].tolist() == pytest.approx( + omegaE_1_grad_bp[:, n].tolist() + ) + assert omegaI_1_grad_manual[:, n].tolist() == pytest.approx( + omegaI_1_grad_bp[:, n].tolist() + ) + + +# def test_grad_is_identical_to_backprop_grad(): +# params = { +# "seed": SEED, +# "in_features": [2, 2], +# "out_features": 3, +# } + +# np.random.seed(params["seed"]) +# torch.manual_seed(params["seed"]) + +# model = FeedForwardCell(params["in_features"], params["out_features"]) + +# model.gc = torch.Tensor([2.34, 1.87]) +# u0_target = torch.DoubleTensor( +# [ +# [model.EL, model.EL - 5.0, model.EL + 5.0], +# [model.EL, model.EL + 5.0, model.EL - 5.0], +# ] +# ) + +# u_in = torch.DoubleTensor( +# [ +# [model.EL, model.EL + 5.0, model.EL - 5.0, model.EL + 10.0], +# [model.EL, model.EL - 5.0, model.EL + 5.0, model.EL - 10.0], +# ] +# ) + +# with torch.no_grad(): +# g0_manual, u0_manual = model(u_in) +# model.zero_grad() +# model.compute_grad_manual_target(u0_target, g0_manual, u0_manual, u_in) +# omegaE_0_grad_manual = model._omegaE[0]._grad.clone() +# omegaI_0_grad_manual = model._omegaI[0]._grad.clone() +# omegaE_1_grad_manual = model._omegaE[1]._grad.clone() +# omegaI_1_grad_manual = model._omegaI[1]._grad.clone() + +# # calculate gradient with autograd +# g0_bp, u0_bp = model(u_in) +# model.zero_grad() +# model.energy_target(u0_target, g0_bp, u0_bp).sum().backward() +# omegaE_0_grad_bp = model._omegaE[0]._grad.clone() +# omegaI_0_grad_bp = model._omegaI[0]._grad.clone() +# omegaE_1_grad_bp = model._omegaE[1]._grad.clone() +# omegaI_1_grad_bp = model._omegaI[1]._grad.clone() + +# assert g0_manual[0, 0].tolist() == pytest.approx(g0_bp[0, 0].tolist()) +# assert u0_manual[0, 0].tolist() == pytest.approx(u0_bp[0, 0].tolist()) + +# for n in range(params["out_features"]): +# assert omegaE_0_grad_manual[:, n].tolist() == pytest.approx( +# omegaE_0_grad_bp[:, n].tolist() +# ) +# assert omegaI_0_grad_manual[:, n].tolist() == pytest.approx( +# omegaI_0_grad_bp[:, n].tolist() +# ) +# assert omegaE_1_grad_manual[:, n].tolist() == pytest.approx( +# omegaE_1_grad_bp[:, n].tolist() +# ) +# assert omegaI_1_grad_manual[:, n].tolist() == pytest.approx( +# omegaI_1_grad_bp[:, n].tolist() +# ) + + +def test_backprop_multiple_times(): + + torch.manual_seed(SEED) + + model = FeedForwardCell([1], 1) + u0_target = -60.0 + + u_in = torch.DoubleTensor([[model.EL]]) + g0, u0 = model(u_in) + model.zero_grad() + model.energy_target(u0_target, g0, u0).backward() + + u_in = torch.DoubleTensor([[model.EL]]) + g0, u0 = model(u_in) + model.zero_grad() + model.energy_target(u0_target, g0, u0).backward() diff --git a/test/test_feedforward_current_cell.py b/test/test_feedforward_current_cell.py new file mode 100644 index 0000000..8085bd4 --- /dev/null +++ b/test/test_feedforward_current_cell.py @@ -0,0 +1,388 @@ +import math +import numpy as np +import pytest +import torch + +from dopp import FeedForwardCurrentCell + +SEED = 1234 + + +def test_single_dendrite_single_input_single_output_single_trial(): + + params = { + "seed": SEED, + "in_features": [1], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2]]]) + wI = torch.Tensor([[[0.7]]]) + + model = FeedForwardCurrentCell(params["in_features"], params["out_features"]) + + u_in = torch.Tensor([[model.EL + 10.0]]) + + # hand-crafted solution + r_in = model.f(u_in) + Iffd_target = torch.mm(r_in, wE[0]) - torch.mm(r_in, wI[0]) + u0_target = model.EL + Iffd_target + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + Iffd = model.compute_Iffd(u_in) + _, u0 = model(u_in) + + assert Iffd.shape == (1, params["out_features"], len(params["in_features"])) + assert Iffd_target.shape == (1, params["out_features"]) + assert Iffd[0, 0, 0].tolist() == pytest.approx(Iffd_target[0, 0].tolist()) + + assert u0.shape == (1, params["out_features"]) + assert u0.shape == u0_target.shape + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + # test energy + u0_target = u0.clone() + 5.0 + _, u0 = model(u_in) + g0 = torch.ones_like(u0) + p_expected = torch.sqrt(g0 / (2 * math.pi * model.lambda_e)) * torch.exp(-g0 / (2. * model.lambda_e) * (u0_target - u0) ** 2) + assert model.energy_target(u0_target, None, u0).item() == pytest.approx(-torch.log(p_expected).item()) + + # test loss + assert model.loss_target(u0_target, None, u0).item() == pytest.approx(0.5 * (u0_target - u0).item() ** 2) + + +def test_single_dendrite_two_inputs_single_output_single_trial(): + + params = { + "seed": SEED, + "in_features": [2], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2], [0.5]]]) + wI = torch.Tensor([[[0.7], [0.4]]]) + + model = FeedForwardCurrentCell(params["in_features"], params["out_features"]) + + u_in = torch.Tensor([[model.EL, model.EL + 5.0]]) + + # hand-crafted solution + r_in = model.f(u_in) + Iffd_target = torch.mm(r_in, wE[0]) - torch.mm(r_in, wI[0]) + u0_target = model.EL + Iffd_target + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + Iffd = model.compute_Iffd(u_in) + _, u0 = model(u_in) + + assert Iffd.shape == (1, 1, 1) + assert Iffd_target.shape == (1, 1) + assert Iffd[0, 0, 0].tolist() == pytest.approx(Iffd_target[0, 0].tolist()) + + assert u0.shape == (1, 1) + assert u0.shape == u0_target.shape + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_two_dendrites_two_inputs_single_output_single_trial(): + + params = { + "seed": SEED, + "in_features": [1, 1], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = torch.Tensor([[[1.2]], [[0.5]]]) + wI = torch.Tensor([[[0.7]], [[0.4]]]) + + model = FeedForwardCurrentCell(params["in_features"], params["out_features"]) + + u_in = torch.Tensor([[model.EL, model.EL + 5.0]]) + + # hand-crafted solution + r_in = model.f(u_in).reshape(1, 2, 1) + Iffd_target = torch.Tensor( + [ + torch.mm(r_in[:, d], wE[d]) - torch.mm(r_in[:, d], wI[d]) + for d in range(2) + ] + ).reshape(1, 1, 2) + u0_target = model.EL + torch.sum(Iffd_target, dim=2) + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + Iffd = model.compute_Iffd(u_in) + _, u0 = model(u_in) + + assert Iffd.shape == (1, 1, 2) + assert Iffd_target.shape == (1, 1, 2) + assert Iffd[0, 0, 0].tolist() == pytest.approx(Iffd_target[0, 0, 0].tolist()) + assert Iffd[0, 0, 1].tolist() == pytest.approx(Iffd_target[0, 0, 1].tolist()) + + assert u0.shape == (1, 1) + assert u0.shape == u0_target.shape + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_two_dendrites_four_inputs_single_output_single_trial(): + + params = { + "seed": SEED, + "in_features": [2, 2], + "out_features": 1, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = [torch.Tensor([[1.2, 1.1]]).t(), torch.Tensor([[0.5, 0.8]]).t()] + wI = [torch.Tensor([[0.7, 0.2]]).t(), torch.Tensor([[0.4, 0.1]]).t()] + + model = FeedForwardCurrentCell(params["in_features"], params["out_features"]) + + u_in = torch.Tensor([[model.EL, model.EL - 5.0, model.EL + 5.0, model.EL + 10.0]]) + + # handcrafted solution + r_in = model.f(u_in).reshape(1, 2, 2) + Iffd_target = torch.Tensor( + [ + torch.mm(r_in[:, d], wE[d]) - torch.mm(r_in[:, d], wI[d]) + for d in range(2) + ] + ).reshape(1, 1, 2) + u0_target = model.EL + torch.sum(Iffd_target, dim=2) + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + Iffd = model.compute_Iffd(u_in) + _, u0 = model(u_in) + + assert Iffd.shape == (1, 1, 2) + assert Iffd_target.shape == (1, 1, 2) + assert Iffd[0, 0, 0].tolist() == pytest.approx(Iffd_target[0, 0, 0].tolist()) + assert Iffd[0, 0, 1].tolist() == pytest.approx(Iffd_target[0, 0, 1].tolist()) + + assert u0.shape == (1, 1) + assert u0.shape == u0_target.shape + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_two_dendrites_four_inputs_three_outputs_single_trial(): + + params = { + "seed": SEED, + "in_features": [2, 2], + "out_features": 3, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = [ + torch.Tensor([[1.2, 1.1], [1.2, 1.2], [1.2, 1.3]]).t(), + torch.Tensor([[0.5, 0.8], [0.4, 0.8], [0.3, 0.8]]).t(), + ] + wI = [ + torch.Tensor([[0.7, 0.2], [0.7, 0.2], [0.7, 0.2]]).t(), + torch.Tensor([[0.4, 0.1], [0.4, 0.1], [0.4, 0.1]]).t(), + ] + + model = FeedForwardCurrentCell(params["in_features"], params["out_features"]) + + u_in = torch.Tensor([[model.EL, model.EL - 5.0, model.EL + 5.0, model.EL + 10.0]]) + + # hand-crafted solution + r_in = model.f(u_in).reshape(1, 2, 2) + Iffd_target = torch.empty(1, params["out_features"], len(params["in_features"])) + for d in range(2): + Iffd_target[:, :, d] = torch.mm(r_in[:, d], wE[d]) - torch.mm(r_in[:, d], wI[d]) + u0_target = model.EL + torch.sum(Iffd_target, dim=2) + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + Iffd = model.compute_Iffd(u_in) + _, u0 = model(u_in) + + assert Iffd.shape == (1, params["out_features"], len(params["in_features"])) + assert Iffd_target.shape == (1, params["out_features"], len(params["in_features"])) + assert u0.shape == (1, params["out_features"]) + assert u0.shape == u0_target.shape + for n in range(params["out_features"]): + assert Iffd[0, n, 0].tolist() == pytest.approx(Iffd_target[0, n, 0].tolist()) + assert Iffd[0, n, 1].tolist() == pytest.approx(Iffd_target[0, n, 1].tolist()) + assert u0[0, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_two_dendrites_four_inputs_three_outputs_multiple_trials(): + + params = { + "seed": SEED, + "in_features": [2, 2], + "out_features": 3, + "batch_size": 4, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + wE = [ + torch.Tensor([[1.2, 1.1], [1.2, 1.2], [1.2, 1.3]]).t(), + torch.Tensor([[0.5, 0.8], [0.4, 0.8], [0.3, 0.8]]).t(), + ] + wI = [ + torch.Tensor([[0.7, 0.2], [0.7, 0.2], [0.7, 0.2]]).t(), + torch.Tensor([[0.4, 0.1], [0.4, 0.1], [0.4, 0.1]]).t(), + ] + + model = FeedForwardCurrentCell(params["in_features"], params["out_features"]) + + u_in = torch.Tensor( + [ + [model.EL, model.EL - 5.0, model.EL + 5.0, model.EL + 10.0], + [model.EL, model.EL - 5.0, model.EL + 5.0, model.EL + 15.0], + [model.EL, model.EL - 5.0, model.EL - 5.0, model.EL + 10.0], + [model.EL, model.EL + 5.0, model.EL + 5.0, model.EL + 10.0], + ] + ) + + # model solution + model.set_weightsE(0, wE[0]) + model.set_weightsI(0, wI[0]) + model.set_weightsE(1, wE[1]) + model.set_weightsI(1, wI[1]) + Iffd = model.compute_Iffd(u_in) + _, u0 = model(u_in) + + assert Iffd.shape == ( + params["batch_size"], + params["out_features"], + len(params["in_features"]), + ) + assert u0.shape == (params["batch_size"], params["out_features"]) + + for trial in range(params["batch_size"]): + + # hand-crafted solution + r_in = model.f(u_in[trial]).reshape(1, 2, 2) + Iffd_target = torch.empty(1, params["out_features"], len(params["in_features"])) + for d in range(2): + Iffd_target[:, :, d] = torch.mm(r_in[:, d], wE[d]) - torch.mm(r_in[:, d], wI[d]) + u0_target = model.EL + torch.sum(Iffd_target, dim=2) + + assert Iffd_target.shape == ( + 1, + params["out_features"], + len(params["in_features"]), + ) + assert u0_target.shape == (1, params["out_features"]) + + for n in range(params["out_features"]): + assert Iffd[trial, n, 0].tolist() == pytest.approx( + Iffd_target[0, n, 0].tolist() + ) + assert Iffd[trial, n, 1].tolist() == pytest.approx( + Iffd_target[0, n, 1].tolist() + ) + + assert u0[trial, 0].tolist() == pytest.approx(u0_target[0, 0].tolist()) + + +def test_grad_is_identical_to_backprop_grad(): + params = { + "seed": SEED, + "in_features": [2, 2], + "out_features": 3, + } + + np.random.seed(params["seed"]) + torch.manual_seed(params["seed"]) + + model = FeedForwardCurrentCell(params["in_features"], params["out_features"]) + + u0_target = torch.DoubleTensor( + [ + [model.EL, model.EL - 5.0, model.EL + 5.0], + [model.EL, model.EL + 5.0, model.EL - 5.0], + ] + ) + + u_in = torch.DoubleTensor( + [ + [model.EL, model.EL + 5.0, model.EL - 5.0, model.EL + 10.0], + [model.EL, model.EL - 5.0, model.EL + 5.0, model.EL - 10.0], + ] + ) + + with torch.no_grad(): + _, u0_manual = model(u_in) + model.zero_grad() + model.compute_grad_manual_target(u0_target, None, u0_manual, u_in) + omegaE_0_grad_manual = model._omegaE[0]._grad.clone() + omegaI_0_grad_manual = model._omegaI[0]._grad.clone() + omegaE_1_grad_manual = model._omegaE[1]._grad.clone() + omegaI_1_grad_manual = model._omegaI[1]._grad.clone() + + # calculate gradient with autograd + _, u0_bp = model(u_in) + model.zero_grad() + model.loss_target(u0_target, None, u0_bp).sum().backward() + omegaE_0_grad_bp = model._omegaE[0]._grad.clone() + omegaI_0_grad_bp = model._omegaI[0]._grad.clone() + omegaE_1_grad_bp = model._omegaE[1]._grad.clone() + omegaI_1_grad_bp = model._omegaI[1]._grad.clone() + + assert u0_manual[0, 0].tolist() == pytest.approx(u0_bp[0, 0].tolist()) + + for n in range(params["out_features"]): + assert omegaE_0_grad_manual[:, n].tolist() == pytest.approx( + omegaE_0_grad_bp[:, n].tolist() + ) + assert omegaI_0_grad_manual[:, n].tolist() == pytest.approx( + omegaI_0_grad_bp[:, n].tolist() + ) + assert omegaE_1_grad_manual[:, n].tolist() == pytest.approx( + omegaE_1_grad_bp[:, n].tolist() + ) + assert omegaI_1_grad_manual[:, n].tolist() == pytest.approx( + omegaI_1_grad_bp[:, n].tolist() + ) + + +def test_backprop_multiple_times(): + + torch.manual_seed(SEED) + + model = FeedForwardCurrentCell([1], 1) + u0_target = -60.0 + + u_in = torch.DoubleTensor([[model.EL]]) + _, u0 = model(u_in) + model.zero_grad() + model.loss_target(u0_target, None, u0).backward() + + u_in = torch.DoubleTensor([[model.EL]]) + _, u0 = model(u_in) + model.zero_grad() + model.loss_target(u0_target, None, u0).backward()