This repository has been archived by the owner on Mar 26, 2024. It is now read-only.

Commit

Modularization (#90)
* Initial rewrite to use torch.nn.Module

* Minor changes

* Dunno

* Dunno

* Fix

* Dunno what happened

* More changes

* Changes and cleanup

* PMMH fix

* Some changes

* Minor changes

* NB updates

* Black

* Doc
tingiskhan authored Dec 17, 2020
1 parent 413256c commit 4b02a73
Showing 61 changed files with 735 additions and 1,012 deletions.
42 changes: 25 additions & 17 deletions examples/lorenz.ipynb

Large diffs are not rendered by default.

59 changes: 31 additions & 28 deletions examples/nutria.ipynb

Large diffs are not rendered by default.

90 changes: 48 additions & 42 deletions examples/stochastic-volatility.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyfilter/__init__.py
@@ -1 +1 @@
-__version__ = "0.13.5"
+__version__ = "0.14.0"
3 changes: 3 additions & 0 deletions pyfilter/distributions/__init__.py
@@ -0,0 +1,3 @@
from .parameterized import DistributionWrapper
from .prior import Prior
from .empirical import Empirical
19 changes: 19 additions & 0 deletions pyfilter/distributions/mixin.py
@@ -0,0 +1,19 @@
from torch.distributions import Distribution
import torch
from typing import Dict


class BuilderMixin(object):
    def get_parameters(self) -> Dict[str, torch.Tensor]:
        res = dict()

        res.update(self._parameters)
        res.update(self._buffers)

        return res

    def build_distribution(self) -> Distribution:
        return self.base_dist(**self.get_parameters())

    def forward(self) -> Distribution:
        return self.build_distribution()
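
A minimal sketch of how the mixin is meant to be consumed, going by the code above: the host class must be a torch.nn.Module (supplying `_parameters` and `_buffers`) and must set a `base_dist` attribute. `WrappedNormal` is a hypothetical name, not part of this commit:

import torch
from torch.nn import Module
from torch.distributions import Normal
from pyfilter.distributions.mixin import BuilderMixin


# Hypothetical host class: parameters live as buffers, and BuilderMixin
# rebuilds a fresh Normal from them on every forward call.
class WrappedNormal(BuilderMixin, Module):
    def __init__(self, loc: float, scale: float):
        super().__init__()
        self.base_dist = Normal
        self.register_buffer("loc", torch.tensor(loc))
        self.register_buffer("scale", torch.tensor(scale))


dist = WrappedNormal(0.0, 1.0)()  # __call__ -> forward() -> Normal(0., 1.)
samples = dist.sample(torch.Size([100]))
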
22 changes: 22 additions & 0 deletions pyfilter/distributions/parameterized.py
@@ -0,0 +1,22 @@
from torch.distributions import Distribution
import torch
from typing import Type, Dict, Union, Callable
from ..prior_module import PriorModule
from .mixin import BuilderMixin
from .prior import Prior


DistributionType = Union[Type[Distribution], Callable[[Dict], Distribution]]


class DistributionWrapper(BuilderMixin, PriorModule):
    def __init__(self, base_dist: DistributionType, **parameters):
        super().__init__()

        self.base_dist = base_dist

        for k, v in parameters.items():
            if isinstance(v, Prior):
                self.register_prior(k, v)
            else:
                self.register_buffer(k, v if isinstance(v, torch.Tensor) else torch.tensor(v))
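
A usage sketch under pyfilter 0.14.0, with fixed parameter values only. Plain numbers become registered buffers; `Prior` instances would instead go through `PriorModule.register_prior`, which this commit view does not show:

import torch
from torch.distributions import Normal
from pyfilter.distributions import DistributionWrapper

# Floats are converted to tensors and registered as buffers; calling the
# wrapper builds a fresh Normal from them.
wrapped = DistributionWrapper(Normal, loc=0.0, scale=1.0)
samples = wrapped().sample(torch.Size([100]))
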
42 changes: 42 additions & 0 deletions pyfilter/distributions/prior.py
@@ -0,0 +1,42 @@
from torch.distributions import TransformedDistribution, biject_to
import torch
from typing import Tuple
from torch.nn import Module
from .mixin import BuilderMixin


class Prior(BuilderMixin, Module):
    def __init__(self, base_dist, **parameters):
        super().__init__()

        self.base_dist = base_dist

        for k, v in parameters.items():
            self.register_buffer(k, v if isinstance(v, torch.Tensor) else torch.tensor(v))

        self.bijection = biject_to(self().support)
        self.shape = self().event_shape

    @property
    def unconstrained_prior(self):
        return TransformedDistribution(self(), self.bijection.inv)

    def get_unconstrained(self, x: torch.Tensor):
        return self.bijection.inv(x)

    def get_constrained(self, x: torch.Tensor):
        return self.bijection(x)

    def eval_prior(self, x: torch.Tensor, constrained=True) -> torch.Tensor:
        if constrained:
            return self().log_prob(x)

        return self.unconstrained_prior.log_prob(self.get_unconstrained(x))

    def get_numel(self, constrained=True):
        return (self().event_shape if not constrained else self.unconstrained_prior.event_shape).numel()

    def get_slice_for_parameter(self, prev_index, constrained=True) -> Tuple[slice, int]:
        numel = self.get_numel(constrained)

        return slice(prev_index, prev_index + numel), numel
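
A sketch of the constrained/unconstrained round trip, assuming an Exponential base distribution (its positive support makes `biject_to` return an exp-style transform):

import torch
from torch.distributions import Exponential
from pyfilter.distributions import Prior

sigma = Prior(Exponential, rate=1.0)

x = sigma().sample()            # a draw on the constrained (positive) space
u = sigma.get_unconstrained(x)  # mapped to the unconstrained real line
assert torch.isclose(sigma.get_constrained(u), x)

# Log density on either space; the unconstrained one includes the Jacobian.
lp_constrained = sigma.eval_prior(x)
lp_unconstrained = sigma.eval_prior(x, constrained=False)
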
3 changes: 2 additions & 1 deletion pyfilter/filters/__init__.py
@@ -1,6 +1,7 @@
from .apf import APF
from .sisr import SISR
from .ukf import UKF
-from .base import BaseFilter, BaseKalmanFilter, FilterResult
+from .base import BaseFilter, BaseKalmanFilter
from .pf import ParticleFilter
from .state import ParticleState, KalmanState, BaseState
+from .result import FilterResult
19 changes: 6 additions & 13 deletions pyfilter/filters/base.py
@@ -4,8 +4,9 @@
from tqdm import tqdm
import torch
from ..utils import choose
-from ..module import Module
-from .utils import enforce_tensor, FilterResult
+from torch.nn import Module
+from .utils import enforce_tensor
+from .result import FilterResult
from typing import Tuple, Union, Iterable
from .state import BaseState

@@ -28,11 +29,6 @@ def ssm(self) -> StateSpaceModel:
    def n_parallel(self) -> torch.Size:
        return self._n_parallel

-    def viewify_params(self, shape: Union[int, torch.Size]):
-        self.ssm.viewify_params(shape)
-
-        return self
-
    def set_nparallel(self, n: int):
        """
        Sets the number of parallel filters to use
@@ -85,9 +81,8 @@ def longfilter(

        return result

-    def copy(self, view_shape=torch.Size([])):
+    def copy(self):
        res = copy.deepcopy(self)
-        res.viewify_params(view_shape)
        return res

    def predict(self, state: BaseState, steps: int, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -100,7 +95,8 @@ def resample(self, inds: torch.Tensor):
        :return: Self
        """

-        self.ssm.p_apply(lambda u: choose(u.values, inds))
+        for p in self.parameters():
+            p.data[:] = choose(p, inds)

        return self

@@ -117,9 +113,6 @@ def exchange(self, filter_, inds: torch.Tensor):

        return self

-    def populate_state_dict(self):
-        return {"_model": self.ssm.state_dict(), "_n_parallel": self._n_parallel}
-
    def smooth(self, states: Iterable[BaseState]) -> torch.Tensor:
        raise NotImplementedError()

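With parameters living on torch.nn.Module, resampling reduces to gathering rows along the leading (parallel-filter) dimension, which is presumably what `choose` does. An illustrative stand-in using `index_select`:

import torch
from torch.nn import Parameter

# Illustrative only: `choose` is assumed to gather along the leading
# parallel-filter dimension, much like index_select does here.
p = Parameter(torch.randn(4, 3))   # 4 parallel filters, 3-dim parameter
inds = torch.tensor([0, 0, 2, 3])  # duplicate filter 0, drop filter 1
p.data[:] = p.data.index_select(0, inds)
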
8 changes: 1 addition & 7 deletions pyfilter/filters/pf.py
@@ -95,7 +95,7 @@ def set_nparallel(self, n: int):
    def initialize(self) -> ParticleState:
        x = self._model.hidden.i_sample(self.particles)
        w = torch.zeros(self.particles, device=x.device)
-        prev_inds = torch.ones_like(w) * torch.arange(w.shape[-1])
+        prev_inds = torch.ones_like(w) * torch.arange(w.shape[-1], device=x.device)

        return ParticleState(x, w, torch.zeros(self._n_parallel, device=x.device), prev_inds)

@@ -113,12 +113,6 @@ def predict(self, state: ParticleState, steps, aggregate: bool = True, **kwargs)

        return xm[1:], ym[1:]

-    def populate_state_dict(self):
-        base = super(ParticleFilter, self).populate_state_dict()
-        base.update({"_particles": self.particles, "_rsample": self._rsample, "_th": self._th})
-
-        return base
-
    def smooth(self, states: Iterable[ParticleState]):
        hidden_copy = self.ssm.hidden.copy((*self.n_parallel, 1, 1))
        offset = -(2 + self.ssm.hidden_ndim)
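The `initialize` change is a device fix: without `device=x.device`, `torch.arange` allocates on the CPU and the multiplication fails when the particles live on a GPU. A minimal sketch of the fixed form:

import torch

# Sketch: keeping all tensors on one device avoids the RuntimeError that
# CUDA runs would otherwise hit when mixing CPU and GPU tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
w = torch.zeros(500, device=device)
prev_inds = torch.ones_like(w) * torch.arange(w.shape[-1], device=w.device)
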
3 changes: 1 addition & 2 deletions pyfilter/filters/proposals/base.py
@@ -1,10 +1,9 @@
from ...timeseries import StateSpaceModel
from torch.distributions import Distribution
-from ...module import Module
import torch


-class Proposal(Module):
+class Proposal(object):
    def __init__(self):
        """
        Defines a proposal object for how to draw the particles.
8 changes: 5 additions & 3 deletions pyfilter/filters/proposals/linear.py
@@ -59,8 +59,9 @@ def construct(self, y, x):
        h_var_inv = 1 / scale ** 2

        # ===== Observable ===== #
-        c = self._model.observable.parameter_views[0]
-        o_var_inv = 1 / self._model.observable.parameter_views[-1] ** 2
+        params = self._model.observable.functional_parameters()
+        c = params[0]
+        o_var_inv = 1 / params[-1] ** 2

        if self._model.hidden_ndim == 0:
            self._kernel = self._kernel_1d(y, loc, h_var_inv, o_var_inv, c)
@@ -73,7 +74,8 @@ def pre_weight(self, y, x):
        hloc, hscale = self._model.hidden.mean_scale(x)
        oloc, oscale = self._model.observable.mean_scale(hloc)

-        c = self._model.observable.parameter_views[0]
+        params = self._model.observable.functional_parameters()
+        c = params[0]
        ovar = oscale ** 2
        hvar = hscale ** 2

82 changes: 82 additions & 0 deletions pyfilter/filters/result.py
@@ -0,0 +1,82 @@
import torch
from torch.nn import Module
from typing import Tuple
from .state import BaseState
from ..utils import TensorList, ModuleList


class FilterResult(Module):
    def __init__(self, init_state: BaseState):
        """
        Implements a basic object for storing log likelihoods and the filtered means of a filter.
        """
        super().__init__()

        self.register_buffer("_loglikelihood", init_state.get_loglikelihood())

        self._filter_means = TensorList()
        self._states = ModuleList(init_state)

    @property
    def loglikelihood(self) -> torch.Tensor:
        return self._loglikelihood

    @property
    def filter_means(self) -> torch.Tensor:
        return self._filter_means.values()

    @property
    def states(self) -> Tuple[BaseState, ...]:
        return self._states.values()

    @property
    def latest_state(self) -> BaseState:
        return self._states[-1]

    def exchange(self, res, inds: torch.Tensor):
        """
        Exchanges the specified indices of self with res.
        :param res: The other filter result
        :type res: FilterResult
        :param inds: The indices
        """

        # ===== Loglikelihood ===== #
        self._loglikelihood[inds] = res.loglikelihood[inds]

        # ===== Filter means ====== #
        # TODO: Not the best...
        for old_fm, new_fm in zip(self._filter_means, res._filter_means):
            old_fm[inds] = new_fm[inds]

        for ns, os in zip(res.states, self.states):
            os.exchange(ns, inds)

        return self

    def resample(self, inds: torch.Tensor, entire_history=True):
        """
        Resamples the specified indices of self with res.
        """

        self._loglikelihood = self.loglikelihood[inds]

        if entire_history:
            for mean in self._filter_means:
                mean[:] = mean[inds]

        for s in self._states:
            s.resample(inds)

        return self

    def append(self, state: BaseState, only_latest=True):
        self._filter_means.append(state.get_mean())

        self._loglikelihood += state.get_loglikelihood()
        if only_latest:
            self._states = ModuleList(state)
        else:
            self._states.append(state)

        return self
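
In practice the result object comes back from a filter rather than being built by hand. A sketch, assuming a StateSpaceModel `model` and an observation tensor `y` built elsewhere:

from pyfilter.filters import SISR

filt = SISR(model, 500)       # 500 particles; `model` is assumed to exist
result = filt.longfilter(y)   # returns a FilterResult

ll = result.loglikelihood     # accumulated log likelihood
means = result.filter_means   # filtered means, stacked over time
latest = result.latest_state  # most recent state (only_latest=True default)
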
22 changes: 12 additions & 10 deletions pyfilter/filters/state.py
@@ -1,10 +1,10 @@
-from ..uft import UFTCorrectionResult
from torch import Tensor
-from ..utils import choose
-from ..normalization import normalize
+from torch.nn import Module
+from ..uft import UFTCorrectionResult
+from ..utils import choose, normalize


-class BaseState(object):
+class BaseState(Module):
    def get_mean(self) -> Tensor:
        raise NotImplementedError()

@@ -20,8 +20,9 @@ def exchange(self, state, inds: Tensor):

class KalmanState(BaseState):
    def __init__(self, utf: UFTCorrectionResult, ll: Tensor):
-        self.utf = utf
-        self.ll = ll
+        super().__init__()
+        self.add_module("utf", utf)
+        self.register_buffer("ll", ll)

    def get_mean(self):
        return self.utf.xm
@@ -50,10 +51,11 @@ def exchange(self, state, inds):

class ParticleState(BaseState):
    def __init__(self, x: Tensor, w: Tensor, ll: Tensor, prev_inds: Tensor):
-        self.x = x
-        self.w = w
-        self.ll = ll
-        self.prev_inds = prev_inds
+        super().__init__()
+        self.register_buffer("x", x)
+        self.register_buffer("w", w)
+        self.register_buffer("ll", ll)
+        self.register_buffer("prev_inds", prev_inds)

    def get_mean(self):
        normw = self.normalized_weights()
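Registering every tensor as a buffer is what lets FilterResult above treat states uniformly: device moves, state_dict() serialization, and in-place index operations all come from torch.nn.Module. A toy version of the pattern:

import torch
from torch.nn import Module


class ToyState(Module):
    def __init__(self, x: torch.Tensor):
        super().__init__()
        self.register_buffer("x", x)  # tracked by .to(...) and state_dict()

    def resample(self, inds: torch.Tensor):
        self.x[:] = self.x[inds]      # in-place gather, as in resample above


s = ToyState(torch.randn(8))
s.resample(torch.randint(0, 8, (8,)))
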
(Diffs for the remaining changed files are not shown.)
