
Commit

Different sized transform (#71)
* Better stacker

* Uses new stacker logic

* TODO

* Version

* Bug fix

* Bug fix

* TODO

* TODO

* Minor improvements

* Comments

* Minor helper

* Adds test for distributions

Co-authored-by: Victor <[email protected]>
tingiskhan authored Mar 16, 2020
1 parent 87dd336 commit 0d14168
Showing 12 changed files with 155 additions and 64 deletions.
2 changes: 1 addition & 1 deletion pyfilter/__init__.py
@@ -1 +1 @@
-__version__ = '0.6.8'
+__version__ = '0.6.9'
6 changes: 3 additions & 3 deletions pyfilter/inference/kernels/base.py
@@ -95,11 +95,11 @@ def record_stats(self, parameters, weights):
         :rtype: BaseKernel
         """

-        values, _ = stacker(parameters, lambda u: u.t_values)
+        stacked = stacker(parameters, lambda u: u.t_values)
         weights = weights.unsqueeze(-1)

-        mean = (values * weights).sum(0)
-        scale = ((values - mean) ** 2 * weights).sum(0).sqrt()
+        mean = (stacked.concated * weights).sum(0)
+        scale = ((stacked.concated - mean) ** 2 * weights).sum(0).sqrt()

         self._recorded_stats['mean'] += (mean,)
         self._recorded_stats['scale'] += (scale,)
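For orientation, the recorded statistics are just weighted column-wise moments of the stacked values; a self-contained toy sketch (shapes made up, not pyfilter API):

import torch

concated = torch.randn(1000, 3)                    # stands in for stacked.concated: (particles, stacked columns)
weights = torch.softmax(torch.randn(1000), dim=0)  # normalized particle weights
weights = weights.unsqueeze(-1)                    # (particles, 1), broadcasts over columns

mean = (concated * weights).sum(0)                        # weighted mean per column
scale = ((concated - mean) ** 2 * weights).sum(0).sqrt()  # weighted std per column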
8 changes: 4 additions & 4 deletions pyfilter/inference/kernels/mh.py
@@ -75,11 +75,11 @@ def _before_resampling(self, filter_, stacked):
     def _update(self, parameters, filter_, weights):
         for i in range(self._nsteps):
             # ===== Construct distribution ===== #
-            stacked, mask = stacker(parameters, lambda u: u.t_values)
-            dist = self.define_pdf(stacked, weights)
+            stacked = stacker(parameters, lambda u: u.t_values)
+            dist = self.define_pdf(stacked.concated, weights)

             # ===== Perform necessary operation prior to resampling ===== #
-            self._before_resampling(filter_, stacked)
+            self._before_resampling(filter_, stacked.concated)

             # ===== Resample among parameters ===== #
             inds = self._resampler(weights, normalized=True)
@@ -88,7 +88,7 @@ def _update(self, parameters, filter_, weights):
             # ===== Define new filters and move via MCMC ===== #
             t_filt = filter_.copy()
             t_filt.viewify_params((*filter_._n_parallel, 1))
-            _mcmc_move(t_filt.ssm.theta_dists, dist, mask, stacked.shape[0])
+            _mcmc_move(t_filt.ssm.theta_dists, dist, stacked, stacked.concated.shape[0])

             # ===== Calculate difference in loglikelihood ===== #
             quotient = self._calc_diff_logl(t_filt, filter_)
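The proposal fit and per-particle draw, as a self-contained sketch with toy shapes (`_construct_mvn` in pyfilter/inference/utils.py performs essentially this weighted fit; the exact shapes here are assumptions):

import torch
from torch.distributions import MultivariateNormal

values = torch.randn(1000, 3)                # stands in for stacked.concated
w = torch.softmax(torch.randn(1000), dim=0)  # normalized particle weights

mean = (w.unsqueeze(-1) * values).sum(0)
centered = values - mean
cov = (w.unsqueeze(-1) * centered).t() @ centered  # weighted covariance
dist = MultivariateNormal(mean, scale_tril=torch.cholesky(cov))

rvs = dist.sample((values.shape[0],))        # one proposal per particle: (1000, 3)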
34 changes: 17 additions & 17 deletions pyfilter/inference/kernels/online.py
@@ -43,15 +43,15 @@ def _resample(self, filter_, weights):

     def _update(self, parameters, filter_, weights):
         # ===== Perform shrinkage ===== #
-        stacked, mask = stacker(parameters, lambda u: u.t_values)
-        kde = self._kde.fit(stacked, weights)
+        stacked = stacker(parameters, lambda u: u.t_values)
+        kde = self._kde.fit(stacked.concated, weights)

         inds = self._resample(filter_, weights)
         jittered = kde.sample(inds=inds)

         # ===== Mutate parameters ===== #
-        for msk, p in zip(mask, parameters):
-            p.t_values = unflattify(jittered[:, msk], p.c_shape)
+        for p, msk, ps in zip(parameters, stacked.mask, stacked.prev_shape):
+            p.t_values = unflattify(jittered[:, msk], ps)

         return self._resampled
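The write-back relies on the mask and prev_shape carried by the new StackedObject; the shapes involved, on toy numbers (the final reshape mirrors what unflattify is assumed to do):

import torch

jittered = torch.randn(1000, 5)                   # KDE samples in the stacked space
mask = (0, slice(1, 5))                           # as produced by stacker for a scalar and a 2x2 parameter
prev_shape = (torch.Size([]), torch.Size([2, 2]))

jittered[:, mask[0]].shape                        # torch.Size([1000]): scalar parameter
jittered[:, mask[1]].shape                        # torch.Size([1000, 4]): flattened 2x2 block
jittered[:, mask[1]].reshape(-1, *prev_shape[1])  # (1000, 2, 2): what unflattify restores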

@@ -74,13 +74,13 @@ def __init__(self, eps=5e-5, **kwargs):

     def _update(self, parameters, filter_, weights):
         # ===== Define stacks ===== #
-        stacked, mask = stacker(parameters, lambda u: u.t_values)
+        stacked = stacker(parameters, lambda u: u.t_values)

         # ===== Check "convergence" ====== #
-        w = add_dimensions(weights, stacked.dim())
+        w = add_dimensions(weights, stacked.concated.dim())

-        mean = (w * stacked).sum(0)
-        var = (w * (stacked - mean) ** 2).sum(0)
+        mean = (w * stacked.concated).sum(0)
+        var = (w * (stacked.concated - mean) ** 2).sum(0)

         if self._switched is None:
             self._switched = torch.zeros_like(mean).bool()
@@ -97,19 +97,19 @@ def _update(self, parameters, filter_, weights):
         inds = self._resample(filter_, weights)

         # ===== Perform shrinkage ===== #
-        jittered = torch.empty_like(stacked)
+        jittered = torch.empty_like(stacked.concated)

         if (~self._switched).any():
-            shrink_kde = self._shrink_kde.fit(stacked[:, ~self._switched], weights)
+            shrink_kde = self._shrink_kde.fit(stacked.concated[:, ~self._switched], weights)
             jittered[:, ~self._switched] = shrink_kde.sample(inds=inds)

         if self._switched.any():
-            non_shrink = self._non_shrink.fit(stacked[:, self._switched], weights)
+            non_shrink = self._non_shrink.fit(stacked.concated[:, self._switched], weights)
             jittered[:, self._switched] = non_shrink.sample(inds=inds)

         # ===== Set new values ===== #
-        for p, msk in zip(parameters, mask):
-            p.t_values = unflattify(jittered[:, msk], p.c_shape)
+        for p, msk, ps in zip(parameters, stacked.mask, stacked.prev_shape):
+            p.t_values = unflattify(jittered[:, msk], ps)

         return self._resampled
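The eps-based switch works column-wise in the stacked space; a minimal illustration with toy shapes:

import torch

switched = torch.tensor([False, True, False])  # per-column "convergence" flags
concated = torch.randn(100, 3)

concated[:, ~switched].shape  # torch.Size([100, 2]): columns still shrunk towards the mean
concated[:, switched].shape   # torch.Size([100, 1]): columns jittered without shrinkage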

@@ -125,18 +125,18 @@ def __init__(self, kde=None, **kwargs):
         self._kde = kde or MultivariateGaussian()

     def _update(self, parameters, filter_, weights):
-        values, mask = stacker(parameters, lambda u: u.t_values)
+        stacked = stacker(parameters, lambda u: u.t_values)

         # ===== Calculate covariance ===== #
-        kde = self._kde.fit(values, weights)
+        kde = self._kde.fit(stacked.concated, weights)

         # ===== Resample ===== #
         inds = self._resampler(weights, normalized=True)
         filter_.resample(inds, entire_history=False)

         # ===== Sample params ===== #
         samples = kde.sample(inds=inds)
-        for p, msk in zip(parameters, mask):
-            p.t_values = unflattify(samples[:, msk], p.c_shape)
+        for p, msk, ps in zip(parameters, stacked.mask, stacked.prev_shape):
+            p.t_values = unflattify(samples[:, msk], ps)

         return True
4 changes: 2 additions & 2 deletions pyfilter/inference/smc2.py
@@ -139,7 +139,7 @@ def _update(self, y):
             self._iterator.set_description(str(self))

         # ===== Check if to propagate ===== #
-        force_rejuv = self._logged_ess[-1] < 0.1 * self._particles or (~torch.isfinite(self._w_rec)).any()
+        force_rejuv = self._logged_ess[-1] < 0.1 * self._particles[0] or (~torch.isfinite(self._w_rec)).any()
         if self._num_iters % self._bl == 0 or force_rejuv:
             self._kernel.update(self.filter.ssm.theta_dists, self.filter, self._w_rec)
             self._num_iters = 0
@@ -150,7 +150,7 @@ def _update(self, y):
         self._num_iters += 1

         # ===== Calculate efficient number of samples ===== #
-        self._logged_ess += (get_ess(self._w_rec),)
+        self._logged_ess.append(get_ess(self._w_rec))

         return self
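The logged ESS driving force_rejuv is the usual inverse sum of squared normalized weights; a standard formulation, assuming log-weights (not necessarily pyfilter's exact get_ess):

import torch

def ess(log_weights):
    w = torch.softmax(log_weights, dim=0)  # normalize in one step, stable in log-space
    return 1.0 / (w ** 2).sum()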

51 changes: 35 additions & 16 deletions pyfilter/inference/utils.py
@@ -3,31 +3,50 @@
 from ..utils import unflattify


-def stacker(parameters, selector=lambda u: u.values):
+class StackedObject(object):
+    def __init__(self, concated, mask, prev_shape):
+        """
+        Helper object
+        """
+
+        self.concated = concated
+        self.mask = mask
+        self.prev_shape = prev_shape
+
+
+def stacker(parameters, selector=lambda u: u.values, dim=1):
     """
     Stacks the parameters and returns a n-tuple containing the mask for each parameter.
     :param parameters: The parameters
     :type parameters: tuple[Parameter]|list[Parameter]
     :param selector: The selector
-    :rtype: torch.Tensor, tuple[slice]
+    :param dim: The dimension to start flattening from
+    :type dim: int
+    :rtype: StackedObject
     """

     to_conc = tuple()
     mask = tuple()
+    prev_shape = tuple()

     i = 0
+    # TODO: Currently only supports one sampling dimension...
     for p in parameters:
-        if p.c_numel() < 2:
-            to_conc += (selector(p).unsqueeze(-1),)
+        s = selector(p)
+        flat = s if s.dim() <= dim else s.flatten(dim)
+
+        if flat.dim() == dim:
+            to_conc += (flat.unsqueeze(-1),)
             slc = i
         else:
-            to_conc += (selector(p).flatten(1),)
-            slc = slice(i, i + p.c_numel())
+            to_conc += (flat,)
+            slc = slice(i, i + flat.shape[-1])

         mask += (slc,)
-        i += p.c_numel()
+        i += to_conc[-1].shape[-1]
+        prev_shape += (s.shape[dim:],)

-    return torch.cat(to_conc, dim=-1), mask
+    return StackedObject(torch.cat(to_conc, dim=-1), mask, prev_shape)


 def _construct_mvn(x, w):
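The heart of the commit: parameters with different event shapes now stack and unstack losslessly. A sketch with the new stacker in scope and a hypothetical Parameter stand-in:

import torch

class Toy:  # stand-in exposing .t_values like pyfilter's Parameter
    def __init__(self, values):
        self.t_values = values  # transformed values, first dim = particles

params = (
    Toy(torch.randn(100)),        # scalar parameter, 100 particles
    Toy(torch.randn(100, 2, 2)),  # matrix-valued parameter
)

stacked = stacker(params, lambda u: u.t_values)

stacked.concated.shape  # torch.Size([100, 5]): 1 column + 4 flattened columns
stacked.mask            # (0, slice(1, 5, None))
stacked.prev_shape      # (torch.Size([]), torch.Size([2, 2]))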
@@ -47,15 +66,15 @@ def _construct_mvn(x, w):
     return MultivariateNormal(mean, scale_tril=torch.cholesky(cov))


-def _mcmc_move(params, dist, mask, shape):
+def _mcmc_move(params, dist, stacked, shape):
     """
     Performs an MCMC move to rejuvenate parameters.
     :param params: The parameters to use for defining the distribution
     :type params: tuple[Parameter]
     :param dist: The distribution to use for sampling
     :type dist: MultivariateNormal
-    :param mask: The mask to apply for parameters
-    :type mask: tuple[slice]
+    :param stacked: The mask to apply for parameters
+    :type stacked: StackedObject
     :param shape: The shape to sample
     :type shape: int
     :return: Samples from a multivariate normal distribution
@@ -64,8 +83,8 @@ def _mcmc_move(params, dist, mask, shape):

     rvs = dist.sample((shape,))

-    for p, msk in zip(params, mask):
-        p.t_values = unflattify(rvs[:, msk], p.c_shape)
+    for p, msk, ps in zip(params, stacked.mask, stacked.prev_shape):
+        p.t_values = unflattify(rvs[:, msk], ps)

     return True

@@ -83,7 +102,7 @@ def _eval_kernel(params, dist, n_params):
     :rtype: torch.Tensor
     """

-    p_vals, _ = stacker(params, lambda u: u.t_values)
-    n_p_vals, _ = stacker(n_params, lambda u: u.t_values)
+    p_vals = stacker(params, lambda u: u.t_values)
+    n_p_vals = stacker(n_params, lambda u: u.t_values)

-    return dist.log_prob(p_vals) - dist.log_prob(n_p_vals)
+    return dist.log_prob(p_vals.concated) - dist.log_prob(n_p_vals.concated)
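In the MH kernel above, this term enters the acceptance ratio as the proposal correction; schematically, with hypothetical variable names rather than pyfilter's literal accept step:

# quotient:    difference in log-likelihood between proposed and current filters
# dist:        the proposal fitted to the current particle cloud
# curr_params / prop_params: current and proposed parameter tuples (illustrative names)
log_alpha = quotient + _eval_kernel(curr_params, dist, prop_params)
accept = torch.log(torch.rand_like(log_alpha)) < log_alpha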
4 changes: 2 additions & 2 deletions pyfilter/inference/varapprox/meanfield.py
@@ -44,9 +44,9 @@ def initialize(self, parameters, *args):
         self._mean = torch.zeros(sum(p.c_numel() for p in parameters))
         self._std = torch.ones_like(self._mean)

-        _, mask = stacker(parameters)
+        stacked = stacker(parameters)

-        for p, msk in zip(parameters, mask):
+        for p, msk in zip(parameters, stacked.mask):
             try:
                 self._mean[msk] = p.bijection.inv(p.distr.mean)
             except NotImplementedError:
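Since the mask entries are plain ints and slices, they index straight into the flat mean/std vectors; a toy illustration with assumed shapes:

import torch

mean = torch.zeros(5)
mask = (0, slice(1, 5))                # as produced by stacker for a scalar and a 2x2 parameter

mean[mask[0]] = 0.3                    # scalar parameter's location
mean[mask[1]] = torch.full((4,), 0.1)  # flattened 2x2 parameter's locations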
5 changes: 3 additions & 2 deletions pyfilter/inference/vb.py
@@ -9,6 +9,7 @@
 from ..utils import EPS, unflattify


+# TODO: Shape not working correctly when transformed != untransformed
 class VariationalBayes(BatchAlgorithm):
     def __init__(self, model, num_samples=4, approx=None, optimizer=optim.Adam, maxiters=30e3, optkwargs=None):
         """
@@ -34,7 +35,7 @@ def __init__(self, model, num_samples=4, approx=None, optimizer=optim.Adam, maxi
         self._s_approx = approx or StateMeanField()
         self._p_approx = ParameterMeanField()

-        self._p_mask = None
+        self._mask = None

         self._is_ssm = isinstance(model, StateSpaceModel)

@@ -84,7 +85,7 @@ def sample_params(self):
     def _initialize(self, y):
         # ===== Sample model in place for a primitive version of initialization ===== #
         self._model.sample_params(self._numsamples)
-        _, self._mask = stacker(self._model.theta_dists)  # NB: We create a mask once
+        self._mask = stacker(self._model.theta_dists).mask  # NB: We create a mask once

         # ===== Setup the parameter approximation ===== #
         self._p_approx = self._p_approx.initialize(self._model.theta_dists)
7 changes: 5 additions & 2 deletions pyfilter/module.py
@@ -86,13 +86,16 @@ def __bool__(self):

     @property
     def tensors(self):
-        return tuple(self.values())
+        return flatten(self.values())

     def items(self):
         return self._dict.items()

     def values(self):
-        return self._dict.values()
+        if all(isinstance(v, TensorContainerBase) for v in self._dict.values()):
+            return tuple(d.values() for d in self._dict.values())
+
+        return tuple(self._dict.values())


 def _find_types(x, type_):
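A plain-dict analogue of what the new values()/tensors pair is after (TensorContainerBase and flatten are pyfilter internals; this mirrors, rather than calls, their assumed behavior):

from itertools import chain

nested = {"a": ("t1", "t2"), "b": ("t3",)}    # stand-ins for containers of tensors
values = tuple(v for v in nested.values())    # (("t1", "t2"), ("t3",)): one level unpacked
tensors = tuple(chain.from_iterable(values))  # ("t1", "t2", "t3"): what flatten achieves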
(Diff truncated: 3 of the 12 changed files were not loaded.)