diff --git a/pyfilter/__init__.py b/pyfilter/__init__.py
index 954f8a03..f046216a 100644
--- a/pyfilter/__init__.py
+++ b/pyfilter/__init__.py
@@ -1 +1 @@
-__version__ = '0.6.8'
\ No newline at end of file
+__version__ = '0.6.9'
\ No newline at end of file
diff --git a/pyfilter/inference/kernels/base.py b/pyfilter/inference/kernels/base.py
index b2974144..a28078a9 100644
--- a/pyfilter/inference/kernels/base.py
+++ b/pyfilter/inference/kernels/base.py
@@ -95,11 +95,11 @@ def record_stats(self, parameters, weights):
         :rtype: BaseKernel
         """

-        values, _ = stacker(parameters, lambda u: u.t_values)
+        stacked = stacker(parameters, lambda u: u.t_values)
         weights = weights.unsqueeze(-1)

-        mean = (values * weights).sum(0)
-        scale = ((values - mean) ** 2 * weights).sum(0).sqrt()
+        mean = (stacked.concated * weights).sum(0)
+        scale = ((stacked.concated - mean) ** 2 * weights).sum(0).sqrt()

         self._recorded_stats['mean'] += (mean,)
         self._recorded_stats['scale'] += (scale,)
diff --git a/pyfilter/inference/kernels/mh.py b/pyfilter/inference/kernels/mh.py
index 58b700df..be456389 100644
--- a/pyfilter/inference/kernels/mh.py
+++ b/pyfilter/inference/kernels/mh.py
@@ -75,11 +75,11 @@ def _before_resampling(self, filter_, stacked):
     def _update(self, parameters, filter_, weights):
         for i in range(self._nsteps):
             # ===== Construct distribution ===== #
-            stacked, mask = stacker(parameters, lambda u: u.t_values)
-            dist = self.define_pdf(stacked, weights)
+            stacked = stacker(parameters, lambda u: u.t_values)
+            dist = self.define_pdf(stacked.concated, weights)

             # ===== Perform necessary operation prior to resampling ===== #
-            self._before_resampling(filter_, stacked)
+            self._before_resampling(filter_, stacked.concated)

             # ===== Resample among parameters ===== #
             inds = self._resampler(weights, normalized=True)
@@ -88,7 +88,7 @@ def _update(self, parameters, filter_, weights):
             # ===== Define new filters and move via MCMC ===== #
             t_filt = filter_.copy()
             t_filt.viewify_params((*filter_._n_parallel, 1))
-            _mcmc_move(t_filt.ssm.theta_dists, dist, mask, stacked.shape[0])
+            _mcmc_move(t_filt.ssm.theta_dists, dist, stacked, stacked.concated.shape[0])

             # ===== Calculate difference in loglikelihood ===== #
             quotient = self._calc_diff_logl(t_filt, filter_)
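Review note on `record_stats`: `stacked.concated` is simply the tensor that `stacker` previously returned directly, so the weighted moments are computed exactly as before. A minimal, self-contained sketch of that computation in plain torch (sizes and names below are illustrative, not pyfilter API):

    import torch

    n = 1000                                    # number of particles (illustrative)
    stacked = torch.randn(n, 3)                 # stand-in for stacked.concated
    weights = torch.softmax(torch.randn(n), 0)  # normalized particle weights

    # Weighted mean and standard deviation over the particle dimension,
    # mirroring the two changed lines in record_stats above.
    w = weights.unsqueeze(-1)
    mean = (stacked * w).sum(0)
    scale = ((stacked - mean) ** 2 * w).sum(0).sqrt()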
diff --git a/pyfilter/inference/kernels/online.py b/pyfilter/inference/kernels/online.py
index 5bb2918e..d9b0907e 100644
--- a/pyfilter/inference/kernels/online.py
+++ b/pyfilter/inference/kernels/online.py
@@ -43,15 +43,15 @@ def _resample(self, filter_, weights):
     def _update(self, parameters, filter_, weights):
         # ===== Perform shrinkage ===== #
-        stacked, mask = stacker(parameters, lambda u: u.t_values)
-        kde = self._kde.fit(stacked, weights)
+        stacked = stacker(parameters, lambda u: u.t_values)
+        kde = self._kde.fit(stacked.concated, weights)

         inds = self._resample(filter_, weights)
         jittered = kde.sample(inds=inds)

         # ===== Mutate parameters ===== #
-        for msk, p in zip(mask, parameters):
-            p.t_values = unflattify(jittered[:, msk], p.c_shape)
+        for p, msk, ps in zip(parameters, stacked.mask, stacked.prev_shape):
+            p.t_values = unflattify(jittered[:, msk], ps)

         return self._resampled
@@ -74,13 +74,13 @@ def __init__(self, eps=5e-5, **kwargs):
     def _update(self, parameters, filter_, weights):
         # ===== Define stacks ===== #
-        stacked, mask = stacker(parameters, lambda u: u.t_values)
+        stacked = stacker(parameters, lambda u: u.t_values)

         # ===== Check "convergence" ====== #
-        w = add_dimensions(weights, stacked.dim())
+        w = add_dimensions(weights, stacked.concated.dim())

-        mean = (w * stacked).sum(0)
-        var = (w * (stacked - mean) ** 2).sum(0)
+        mean = (w * stacked.concated).sum(0)
+        var = (w * (stacked.concated - mean) ** 2).sum(0)

         if self._switched is None:
             self._switched = torch.zeros_like(mean).bool()
@@ -97,19 +97,19 @@ def _update(self, parameters, filter_, weights):
         inds = self._resample(filter_, weights)

         # ===== Perform shrinkage ===== #
-        jittered = torch.empty_like(stacked)
+        jittered = torch.empty_like(stacked.concated)

         if (~self._switched).any():
-            shrink_kde = self._shrink_kde.fit(stacked[:, ~self._switched], weights)
+            shrink_kde = self._shrink_kde.fit(stacked.concated[:, ~self._switched], weights)
             jittered[:, ~self._switched] = shrink_kde.sample(inds=inds)

         if self._switched.any():
-            non_shrink = self._non_shrink.fit(stacked[:, self._switched], weights)
+            non_shrink = self._non_shrink.fit(stacked.concated[:, self._switched], weights)
             jittered[:, self._switched] = non_shrink.sample(inds=inds)

         # ===== Set new values ===== #
-        for p, msk in zip(parameters, mask):
-            p.t_values = unflattify(jittered[:, msk], p.c_shape)
+        for p, msk, ps in zip(parameters, stacked.mask, stacked.prev_shape):
+            p.t_values = unflattify(jittered[:, msk], ps)

         return self._resampled
@@ -125,10 +125,10 @@ def __init__(self, kde=None, **kwargs):
         self._kde = kde or MultivariateGaussian()

     def _update(self, parameters, filter_, weights):
-        values, mask = stacker(parameters, lambda u: u.t_values)
+        stacked = stacker(parameters, lambda u: u.t_values)

         # ===== Calculate covariance ===== #
-        kde = self._kde.fit(values, weights)
+        kde = self._kde.fit(stacked.concated, weights)

         # ===== Resample ===== #
         inds = self._resampler(weights, normalized=True)
@@ -136,7 +136,7 @@ def _update(self, parameters, filter_, weights):
         # ===== Sample params ===== #
         samples = kde.sample(inds=inds)

-        for p, msk in zip(parameters, mask):
-            p.t_values = unflattify(samples[:, msk], p.c_shape)
+        for p, msk, ps in zip(parameters, stacked.mask, stacked.prev_shape):
+            p.t_values = unflattify(samples[:, msk], ps)

         return True
\ No newline at end of file
diff --git a/pyfilter/inference/smc2.py b/pyfilter/inference/smc2.py
index 85fa417d..cf262901 100644
--- a/pyfilter/inference/smc2.py
+++ b/pyfilter/inference/smc2.py
@@ -139,7 +139,7 @@ def _update(self, y):
             self._iterator.set_description(str(self))

         # ===== Check if to propagate ===== #
-        force_rejuv = self._logged_ess[-1] < 0.1 * self._particles or (~torch.isfinite(self._w_rec)).any()
+        force_rejuv = self._logged_ess[-1] < 0.1 * self._particles[0] or (~torch.isfinite(self._w_rec)).any()
         if self._num_iters % self._bl == 0 or force_rejuv:
             self._kernel.update(self.filter.ssm.theta_dists, self.filter, self._w_rec)
             self._num_iters = 0
@@ -150,7 +150,7 @@ def _update(self, y):
         self._num_iters += 1

         # ===== Calculate efficient number of samples ===== #
-        self._logged_ess += (get_ess(self._w_rec),)
+        self._logged_ess.append(get_ess(self._w_rec))

         return self
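Review note on the smc2.py changes: `_particles` is now treated as a one-element size tuple, hence the `[0]` indexing, and `_logged_ess` as a list, hence `.append`. A self-contained sketch of the rejuvenation trigger, assuming `get_ess` computes the standard effective sample size from unnormalized log-weights (all names below are illustrative stand-ins):

    import torch

    def ess_from_logweights(logw):
        # ESS = (sum w)^2 / sum w^2, evaluated stably in log-space
        return torch.exp(2 * torch.logsumexp(logw, 0) - torch.logsumexp(2 * logw, 0))

    particles = (400,)                # a one-element tuple, as _particles is assumed to be
    logged_ess = []                   # a list, so .append(...) works
    logw = torch.randn(particles[0])  # stand-in for the recursive weights _w_rec

    logged_ess.append(ess_from_logweights(logw))
    force_rejuv = logged_ess[-1] < 0.1 * particles[0] or (~torch.isfinite(logw)).any()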
diff --git a/pyfilter/inference/utils.py b/pyfilter/inference/utils.py
index 6d8a3a1c..a7dcb562 100644
--- a/pyfilter/inference/utils.py
+++ b/pyfilter/inference/utils.py
@@ -3,31 +3,50 @@
 from ..utils import unflattify


-def stacker(parameters, selector=lambda u: u.values):
+class StackedObject(object):
+    def __init__(self, concated, mask, prev_shape):
+        """
+        Helper object
+        """
+
+        self.concated = concated
+        self.mask = mask
+        self.prev_shape = prev_shape
+
+
+def stacker(parameters, selector=lambda u: u.values, dim=1):
     """
     Stacks the parameters and returns a n-tuple containing the mask for each parameter.
     :param parameters: The parameters
     :type parameters: tuple[Parameter]|list[Parameter]
     :param selector: The selector
-    :rtype: torch.Tensor, tuple[slice]
+    :param dim: The dimension to start flattening from
+    :type dim: int
+    :rtype: StackedObject
     """

     to_conc = tuple()
     mask = tuple()
+    prev_shape = tuple()

     i = 0
+    # TODO: Currently only supports one sampling dimension...
     for p in parameters:
-        if p.c_numel() < 2:
-            to_conc += (selector(p).unsqueeze(-1),)
+        s = selector(p)
+        flat = s if s.dim() <= dim else s.flatten(dim)
+
+        if flat.dim() == dim:
+            to_conc += (flat.unsqueeze(-1),)
             slc = i
         else:
-            to_conc += (selector(p).flatten(1),)
-            slc = slice(i, i + p.c_numel())
+            to_conc += (flat,)
+            slc = slice(i, i + flat.shape[-1])

         mask += (slc,)
-        i += p.c_numel()
+        i += to_conc[-1].shape[-1]
+        prev_shape += (s.shape[dim:],)

-    return torch.cat(to_conc, dim=-1), mask
+    return StackedObject(torch.cat(to_conc, dim=-1), mask, prev_shape)


 def _construct_mvn(x, w):
@@ -47,15 +66,15 @@ def _construct_mvn(x, w):
     return MultivariateNormal(mean, scale_tril=torch.cholesky(cov))


-def _mcmc_move(params, dist, mask, shape):
+def _mcmc_move(params, dist, stacked, shape):
     """
     Performs an MCMC move to rejuvenate parameters.
     :param params: The parameters to use for defining the distribution
     :type params: tuple[Parameter]
     :param dist: The distribution to use for sampling
     :type dist: MultivariateNormal
-    :param mask: The mask to apply for parameters
-    :type mask: tuple[slice]
+    :param stacked: The stacked parameters and mask to apply
+    :type stacked: StackedObject
     :param shape: The shape to sample
     :type shape: int
     :return: Samples from a multivariate normal distribution
     :rtype: torch.Tensor
     """

     rvs = dist.sample((shape,))

-    for p, msk in zip(params, mask):
-        p.t_values = unflattify(rvs[:, msk], p.c_shape)
+    for p, msk, ps in zip(params, stacked.mask, stacked.prev_shape):
+        p.t_values = unflattify(rvs[:, msk], ps)

     return True
@@ -83,7 +102,7 @@ def _eval_kernel(params, dist, n_params):
     :rtype: torch.Tensor
     """

-    p_vals, _ = stacker(params, lambda u: u.t_values)
-    n_p_vals, _ = stacker(n_params, lambda u: u.t_values)
+    p_vals = stacker(params, lambda u: u.t_values)
+    n_p_vals = stacker(n_params, lambda u: u.t_values)

-    return dist.log_prob(p_vals) - dist.log_prob(n_p_vals)
\ No newline at end of file
+    return dist.log_prob(p_vals.concated) - dist.log_prob(n_p_vals.concated)
\ No newline at end of file
diff --git a/pyfilter/inference/varapprox/meanfield.py b/pyfilter/inference/varapprox/meanfield.py
index 5781e487..b96755c3 100644
--- a/pyfilter/inference/varapprox/meanfield.py
+++ b/pyfilter/inference/varapprox/meanfield.py
@@ -44,9 +44,9 @@ def initialize(self, parameters, *args):
         self._mean = torch.zeros(sum(p.c_numel() for p in parameters))
         self._std = torch.ones_like(self._mean)

-        _, mask = stacker(parameters)
+        stacked = stacker(parameters)

-        for p, msk in zip(parameters, mask):
+        for p, msk in zip(parameters, stacked.mask):
             try:
                 self._mean[msk] = p.bijection.inv(p.distr.mean)
             except NotImplementedError:
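To make the new contract concrete, here is a standalone sketch of what `stacker` now does, using plain tensors in place of `Parameter` objects (`stack_tensors` is a hypothetical helper; only torch is assumed):

    import torch

    def stack_tensors(tensors, dim=1):
        # Flatten everything after `dim`, concatenate along the last axis, and
        # remember each tensor's mask and original trailing shape so the
        # round-trip (cf. unflattify) can restore it.
        to_conc, mask, prev_shape = (), (), ()
        i = 0
        for s in tensors:
            flat = s if s.dim() <= dim else s.flatten(dim)
            if flat.dim() == dim:                  # scalar parameter: one singleton column
                to_conc += (flat.unsqueeze(-1),)
                slc = i
            else:                                  # vector/matrix parameter: keep its columns
                to_conc += (flat,)
                slc = slice(i, i + flat.shape[-1])
            mask += (slc,)
            i += to_conc[-1].shape[-1]
            prev_shape += (s.shape[dim:],)
        return torch.cat(to_conc, dim=-1), mask, prev_shape

    params = (torch.randn(1000), torch.randn(1000, 2), torch.randn(1000, 3, 3))
    concated, mask, prev_shape = stack_tensors(params)
    assert concated.shape == (1000, 1 + 2 + 9)

Keeping `mask` and `prev_shape` next to `concated` is what lets the kernels above restore each parameter via `unflattify` without consulting `p.c_shape`.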
diff --git a/pyfilter/inference/vb.py b/pyfilter/inference/vb.py
index 15f9acd4..da6f07c8 100644
--- a/pyfilter/inference/vb.py
+++ b/pyfilter/inference/vb.py
@@ -9,6 +9,7 @@
 from ..utils import EPS, unflattify


+# TODO: Shape not working correctly when transformed != untransformed
 class VariationalBayes(BatchAlgorithm):
     def __init__(self, model, num_samples=4, approx=None, optimizer=optim.Adam, maxiters=30e3, optkwargs=None):
         """
@@ -34,7 +35,7 @@ def __init__(self, model, num_samples=4, approx=None, optimizer=optim.Adam, maxi
         self._s_approx = approx or StateMeanField()
         self._p_approx = ParameterMeanField()

-        self._p_mask = None
+        self._mask = None

         self._is_ssm = isinstance(model, StateSpaceModel)

@@ -84,7 +85,7 @@ def sample_params(self):
     def _initialize(self, y):
         # ===== Sample model in place for a primitive version of initialization ===== #
         self._model.sample_params(self._numsamples)
-        _, self._mask = stacker(self._model.theta_dists)  # NB: We create a mask once
+        self._mask = stacker(self._model.theta_dists).mask  # NB: We create a mask once

         # ===== Setup the parameter approximation ===== #
         self._p_approx = self._p_approx.initialize(self._model.theta_dists)
diff --git a/pyfilter/module.py b/pyfilter/module.py
index 7f36c2e8..4ccfdbe6 100644
--- a/pyfilter/module.py
+++ b/pyfilter/module.py
@@ -86,13 +86,16 @@ def __bool__(self):

     @property
     def tensors(self):
-        return tuple(self.values())
+        return flatten(self.values())

     def items(self):
         return self._dict.items()

     def values(self):
-        return self._dict.values()
+        if all(isinstance(v, TensorContainerBase) for v in self._dict.values()):
+            return tuple(d.values() for d in self._dict.values())
+
+        return tuple(self._dict.values())


 def _find_types(x, type_):
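Review note on `module.py`: when every entry of the dict is itself a `TensorContainerBase`, `values()` now returns a tuple of the inner value-tuples, so `tensors` needs a flatten step. A sketch of that unrolling with a hypothetical helper (pyfilter's own `flatten` is assumed to behave similarly):

    import torch

    def flatten_nested(iterable):
        # Recursively unroll nested tuples/lists into a single flat tuple.
        out = ()
        for item in iterable:
            if isinstance(item, (tuple, list)):
                out += flatten_nested(item)
            else:
                out += (item,)
        return out

    nested = ((torch.tensor(1.), torch.tensor(2.)), (torch.tensor(3.),))
    assert len(flatten_nested(nested)) == 3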
diff --git a/pyfilter/timeseries/base.py b/pyfilter/timeseries/base.py
index 28ff2ef7..4daa110c 100644
--- a/pyfilter/timeseries/base.py
+++ b/pyfilter/timeseries/base.py
@@ -164,6 +164,10 @@ def copy(self):
         return deepcopy(self)


+def _view_helper(p, shape):
+    return p.view(*shape, *p._prior.event_shape) if len(shape) > 0 else p.view(p.shape)
+
+
 class StochasticProcess(StochasticProcessBase, ABC):
     def __init__(self, theta, initial_dist, increment_dist):
         """
@@ -198,20 +202,30 @@ def __init__(self, theta, initial_dist, increment_dist):
         self._inputdim = self.ndim
         self._event_dim = 0 if self.ndim < 2 else 1

-        # ===== Parameters ===== #
+        # ===== Distributional parameters ===== #
         self._dist_theta = TensorContainerDict()
-        # TODO: Make sure same keys are same reference
+        self._org_dist = TensorContainerDict()
+
         for n in [self.initial_dist, self.increment_dist]:
             if n is None:
                 continue

-            for k, v in n.__dict__.items():
+            parameters = TensorContainerDict()
+            statics = TensorContainerDict()
+            for k, v in vars(n).items():
                 if k.startswith('_'):
                     continue

                 if isinstance(v, Parameter) and n is self.increment_dist:
-                    self._dist_theta[k] = v
+                    parameters[k] = v
+                elif isinstance(v, torch.Tensor):
+                    statics[k] = v
+
+            if not not parameters:
+                self._dist_theta[n] = parameters
+                self._org_dist[n] = statics

+        # ===== Regular parameters ====== #
         self.theta = TensorContainer(Parameter(th) if not isinstance(th, Parameter) else th for th in theta)

         # ===== Check dimensions ===== #
@@ -232,7 +246,7 @@ def distributional_theta(self):
         """
         Returns the parameters of the distribution to re-initialize the distribution with. Mainly a helper for
         when the user passes distributions parameterized by priors.
-        :rtype: dict[str, Parameter]
+        :rtype: TensorContainerDict
         """

         return self._dist_theta
@@ -253,7 +267,7 @@ def theta(self, x):

     @property
     def theta_dists(self):
-        return tuple(p for p in self.theta if p.trainable) + tuple(self.distributional_theta.values())
+        return tuple(p for p in self.theta if p.trainable) + self.distributional_theta.tensors

     @property
     def theta_vals(self):
@@ -288,7 +302,7 @@ def viewify_params(self, shape):
         params = tuple()
         for param in self.theta:
             if param.trainable:
-                var = param.view(*shape, *param._prior.event_shape) if len(shape) > 0 else param.view(param.shape)
+                var = _view_helper(param, shape)
             else:
                 var = param

@@ -297,13 +311,14 @@ def viewify_params(self, shape):
         self._theta_vals = TensorContainer(*params)

         # ===== Distributional parameters ===== #
-        pdict = dict()
-        for k, v in self.distributional_theta.items():
-            pdict[k] = v.view(*shape, *v._prior.event_shape) if len(shape) > 0 else v.view(v.shape)
+        for d, dists in self.distributional_theta.items():
+            temp = dict()
+            temp.update(self._org_dist[d]._dict)
+
+            for k, v in dists.items():
+                temp[k] = _view_helper(v, shape)

-        if len(pdict) > 0:
-            self.initial_dist.__init__(**pdict)
-            self.increment_dist.__init__(**pdict)
+            d.__init__(**temp)

         return self
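The `viewify_params` rework keys everything by the distribution object itself: the static kwargs captured at construction (`_org_dist`) are merged with the freshly viewed trainable ones, and the distribution is re-initialized in place. A self-contained sketch of that re-initialization step with a plain torch `Normal` (sizes and names are illustrative):

    import torch
    from torch.distributions import Normal

    dist = Normal(loc=0., scale=1.)

    statics = {'loc': torch.tensor(0.)}      # parameters kept from construction
    sampled = {'scale': torch.rand(400, 1)}  # prior-sampled values, viewed to (400, 1)

    temp = dict(statics)
    temp.update(sampled)
    dist.__init__(**temp)                    # re-init in place, as d.__init__(**temp) does above

    assert dist.sample().shape == (400, 1)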
diff --git a/test/timeseries.py b/test/timeseries.py
index 43f9d44e..9adc29ef 100644
--- a/test/timeseries.py
+++ b/test/timeseries.py
@@ -252,7 +252,7 @@ def test_SDE(self):
         self.assertEqual(samps.shape, path.shape)

     def test_Poisson(self):
-        shape = 1000, 100
+        shape = 10, 100

         a = 1e-2 * torch.ones((shape[0], 1))
         dt = 1e-2
@@ -273,6 +273,34 @@ def test_Poisson(self):
         samps = torch.stack(samps)
         self.assertEqual(samps.size(), torch.Size([num + 1, *shape]))

+        # ===== Sample path ===== #
+        path = sde.sample_path(num + 1, shape)
+        self.assertEqual(samps.shape, path.shape)
+
+    def test_ParameterInDistribution(self):
+        shape = 10, 100
+
+        a = 1e-2 * torch.ones((shape[0], 1))
+        dt = 1e-2
+        dist = Normal(loc=0., scale=Parameter(Exponential(10.)))
+
+        init = Normal(a, 1.)
+        sde = EulerMaruyama((f_sde, g_sde), (a, 0.15), init, dist, dt=dt, num_steps=10)
+
+        sde.sample_params(shape)
+
+        # ===== Initialize ===== #
+        x = sde.i_sample(shape)
+
+        # ===== Propagate ===== #
+        num = 1000
+        samps = [x]
+        for t in range(num):
+            samps.append(sde.propagate(samps[-1]))
+
+        samps = torch.stack(samps)
+        self.assertEqual(samps.size(), torch.Size([num + 1, *shape]))
+
         # ===== Sample path ===== #
         path = sde.sample_path(num + 1, shape)
         self.assertEqual(samps.shape, path.shape)
\ No newline at end of file
diff --git a/test/utils.py b/test/utils.py
index 515c9393..4f2539d2 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -6,6 +6,8 @@
 from pyfilter.filters import SISR, UKF
 from pyfilter.module import Module, TensorContainer
 from pyfilter.utils import concater
+from pyfilter.inference.utils import stacker
+from pyfilter.timeseries import Parameter


 class Help2(Module):
@@ -148,4 +150,27 @@ def test_StateDict(self):
         ukf = UKF(model).initialize()
         sd = ukf.state_dict()

-        assert '_model' in sd and '_model' not in sd['_ut']
\ No newline at end of file
+        assert '_model' in sd and '_model' not in sd['_ut']
+
+    def test_Stacker(self):
+        # ===== Define a mix of parameters ====== #
+        zerod = Parameter(Normal(0., 1.)).sample_((1000,))
+        oned_luring = Parameter(Normal(torch.tensor([0.]), torch.tensor([1.]))).sample_(zerod.shape)
+        oned = Parameter(MultivariateNormal(torch.zeros(2), torch.eye(2))).sample_(zerod.shape)
+
+        mu = torch.zeros((3, 3))
+        norm = Independent(Normal(mu, torch.ones_like(mu)), 2)
+        twod = Parameter(norm).sample_(zerod.shape)
+
+        # ===== Stack ===== #
+        params = (zerod, oned, oned_luring, twod)
+        stacked = stacker(params, lambda u: u.t_values, dim=1)
+
+        # ===== Verify it's recreated correctly ====== #
+        for p, m, ps in zip(params, stacked.mask, stacked.prev_shape):
+            v = stacked.concated[..., m]
+
+            if len(p.c_shape) != 0:
+                v = v.reshape(*v.shape[:-1], *ps)
+
+            assert (p.t_values == v).all()
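The test above exercises the full `Parameter` machinery; the invariant it checks can also be stated in plain torch. A self-contained version with the mask and trailing shapes written out by hand (a standalone check for illustration, not part of the patch):

    import torch

    scalar = torch.randn(1000)        # a "zero-dimensional" parameter
    matrix = torch.randn(1000, 3, 3)  # a parameter with event shape (3, 3)

    # Stack: scalars become a singleton column, higher-dimensional
    # parameters are flattened from dim 1 onward.
    concated = torch.cat((scalar.unsqueeze(-1), matrix.flatten(1)), dim=-1)
    mask = (0, slice(1, 10))
    prev_shape = (torch.Size([]), torch.Size([3, 3]))

    # Round-trip: slicing with the mask and reshaping with prev_shape
    # must reproduce the original tensors.
    assert (concated[..., mask[0]] == scalar).all()
    restored = concated[..., mask[1]].reshape(-1, *prev_shape[1])
    assert (restored == matrix).all()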