This repository has been archived by the owner on Mar 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial rewrite to use torch.nn.Module * Minor changes * Dunno * Dunno * Fix * Dunno what happened * More changes * Changes and cleanup * PMMH fix * Some changes * Minor changes * NB updates * Black * Doc
- Loading branch information
1 parent
413256c
commit 4b02a73
Showing
61 changed files
with
735 additions
and
1,012 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.13.5" | ||
__version__ = "0.14.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .parameterized import DistributionWrapper | ||
from .prior import Prior | ||
from .empirical import Empirical |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from torch.distributions import Distribution | ||
import torch | ||
from typing import Dict | ||
|
||
|
||
class BuilderMixin(object):
    """
    Mixin that knows how to materialize a ``torch.distributions.Distribution``
    from the parameters and buffers registered on the module it is mixed into.
    Expects the host object to expose ``base_dist``, ``_parameters`` and
    ``_buffers`` (the latter two are provided by ``torch.nn.Module``).
    """

    def get_parameters(self) -> Dict[str, torch.Tensor]:
        """
        Collect all registered parameters and buffers into one name → tensor dict.
        Buffers are merged last, so a buffer shadows a parameter of the same name
        (same precedence as sequential ``dict.update`` calls).
        """
        return {**self._parameters, **self._buffers}

    def build_distribution(self) -> Distribution:
        """Instantiate ``base_dist`` with the collected parameters as kwargs."""
        return self.base_dist(**self.get_parameters())

    def forward(self) -> Distribution:
        """``nn.Module`` entry point; simply delegates to :meth:`build_distribution`."""
        return self.build_distribution()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from torch.distributions import Distribution | ||
import torch | ||
from typing import Type, Dict, Union, Callable | ||
from ..prior_module import PriorModule | ||
from .mixin import BuilderMixin | ||
from .prior import Prior | ||
|
||
|
||
# Either a distribution class or any callable that builds a distribution
# from keyword parameters.
DistributionType = Union[Type[Distribution], Callable[[Dict], Distribution]]


class DistributionWrapper(BuilderMixin, PriorModule):
    """
    Couples a distribution type (or factory callable) with its parameters.
    ``Prior`` instances are registered as priors on the module; every other
    value is registered as a buffer (converted to a tensor when necessary).
    """

    def __init__(self, base_dist: DistributionType, **parameters):
        """
        :param base_dist: distribution class or callable producing a distribution.
        :param parameters: keyword parameters later passed to ``base_dist``.
        """
        super().__init__()

        self.base_dist = base_dist

        for name, value in parameters.items():
            if isinstance(value, Prior):
                self.register_prior(name, value)
                continue

            tensor = value if isinstance(value, torch.Tensor) else torch.tensor(value)
            self.register_buffer(name, tensor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from torch.distributions import TransformedDistribution, biject_to | ||
import torch | ||
from typing import Tuple | ||
from torch.nn import Module | ||
from .mixin import BuilderMixin | ||
|
||
|
||
class Prior(BuilderMixin, Module):
    """
    Represents a prior distribution over a parameter, with helpers for moving
    between the distribution's constrained support and an unconstrained space
    via ``torch.distributions.biject_to``.
    """

    def __init__(self, base_dist, **parameters):
        """
        :param base_dist: distribution class (or callable) defining the prior.
        :param parameters: keyword parameters of ``base_dist``; non-tensor
            values are converted with ``torch.tensor``.
        """
        super().__init__()

        self.base_dist = base_dist

        for k, v in parameters.items():
            self.register_buffer(k, v if isinstance(v, torch.Tensor) else torch.tensor(v))

        # ``self()`` builds the distribution (via BuilderMixin.forward);
        # ``biject_to`` yields a transform from unconstrained space onto the
        # distribution's support.
        self.bijection = biject_to(self().support)
        self.shape = self().event_shape

    @property
    def unconstrained_prior(self):
        # The prior pushed through the inverse bijection, i.e. expressed in
        # unconstrained space.
        return TransformedDistribution(self(), self.bijection.inv)

    def get_unconstrained(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` from the constrained support to unconstrained space."""
        return self.bijection.inv(x)

    def get_constrained(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` from unconstrained space onto the constrained support."""
        return self.bijection(x)

    def eval_prior(self, x: torch.Tensor, constrained=True) -> torch.Tensor:
        """
        Evaluate the prior log density of ``x``, either on the constrained
        support (default) or in unconstrained space (which includes the
        Jacobian correction via ``TransformedDistribution``).
        """
        if constrained:
            return self().log_prob(x)

        return self.unconstrained_prior.log_prob(self.get_unconstrained(x))

    def get_numel(self, constrained=True):
        # NOTE(review): the branches look inverted — ``constrained=True``
        # reads the *unconstrained* event_shape and vice versa. For most
        # bijections the two shapes coincide, so this may be latent; confirm
        # against callers before changing.
        return (self().event_shape if not constrained else self.unconstrained_prior.event_shape).numel()

    def get_slice_for_parameter(self, prev_index, constrained=True) -> Tuple[slice, int]:
        """
        Return the slice covering this parameter in a flat parameter vector
        starting at ``prev_index``, together with its number of elements.
        """
        numel = self.get_numel(constrained)

        return slice(prev_index, prev_index + numel), numel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from .apf import APF | ||
from .sisr import SISR | ||
from .ukf import UKF | ||
from .base import BaseFilter, BaseKalmanFilter, FilterResult | ||
from .base import BaseFilter, BaseKalmanFilter | ||
from .pf import ParticleFilter | ||
from .state import ParticleState, KalmanState, BaseState | ||
from .result import FilterResult |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import torch | ||
from torch.nn import Module | ||
from typing import Tuple | ||
from .state import BaseState | ||
from ..utils import TensorList, ModuleList | ||
|
||
|
||
class FilterResult(Module):
    def __init__(self, init_state: BaseState):
        """
        Implements a basic object for storing log likelihoods and the filtered means of a filter.

        :param init_state: the filter's initial state; its log likelihood seeds
            the running total.
        """
        super().__init__()

        # Registered as a buffer so it moves with the module (.to/.cuda) and is
        # included in the state dict.
        self.register_buffer("_loglikelihood", init_state.get_loglikelihood())

        # Project containers (see ..utils); presumably TensorList stacks
        # appended tensors and ModuleList here wraps a single state — TODO confirm.
        self._filter_means = TensorList()
        self._states = ModuleList(init_state)

    @property
    def loglikelihood(self) -> torch.Tensor:
        # Accumulated log likelihood over all appended states.
        return self._loglikelihood

    @property
    def filter_means(self) -> torch.Tensor:
        # Materialized view of the stored filter means.
        return self._filter_means.values()

    @property
    def states(self) -> Tuple[BaseState, ...]:
        return self._states.values()

    @property
    def latest_state(self) -> BaseState:
        return self._states[-1]

    def exchange(self, res, inds: torch.Tensor):
        """
        Exchanges the specified indices of self with res.
        :param res: The other filter result
        :type res: FilterResult
        :param inds: The indices
        :return: self, for chaining
        """

        # ===== Loglikelihood ===== #
        self._loglikelihood[inds] = res.loglikelihood[inds]

        # ===== Filter means ====== #
        # TODO: Not the best...
        # Pairwise in-place swap of each stored mean tensor at ``inds``.
        for old_fm, new_fm in zip(self._filter_means, res._filter_means):
            old_fm[inds] = new_fm[inds]

        # Delegate per-state exchange to the state objects themselves.
        for ns, os in zip(res.states, self.states):
            os.exchange(ns, inds)

        return self

    def resample(self, inds: torch.Tensor, entire_history=True):
        """
        Resamples the specified indices of self, optionally over the entire
        stored history of filter means.

        :param inds: indices to select.
        :param entire_history: if True, also resample all stored filter means.
        :return: self, for chaining
        """

        # Reassigning a registered buffer with a tensor updates the buffer.
        self._loglikelihood = self.loglikelihood[inds]

        if entire_history:
            # In-place so the TensorList keeps referencing the same tensors.
            for mean in self._filter_means:
                mean[:] = mean[inds]

        for s in self._states:
            s.resample(inds)

        return self

    def append(self, state: BaseState, only_latest=True):
        """
        Record a new filter state: store its mean, add its log likelihood to
        the running total, and either replace the stored states with it
        (``only_latest=True``) or append it to the history.
        """
        self._filter_means.append(state.get_mean())

        self._loglikelihood += state.get_loglikelihood()
        if only_latest:
            self._states = ModuleList(state)
        else:
            self._states.append(state)

        return self
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.