Skip to content

Commit

Permalink
API hints for existing users (#246)
Browse files Browse the repository at this point in the history
* Removed root as a keyword argument

* Updates to demo which may be worth considering including warnings for

* Put in user warnings

* Added code for tex, ns_output and plot

* Missed commit

* Now covering ns_output

* Added hist check and plot_type/types

* Added D and d NotImplementedErrors

Co-authored-by: Lukas Hergt <[email protected]>
  • Loading branch information
williamjameshandley and lukashergt authored Jan 24, 2023
1 parent f349603 commit ba7d4b3
Show file tree
Hide file tree
Showing 13 changed files with 1,021 additions and 143 deletions.
15 changes: 15 additions & 0 deletions anesthetic/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,13 @@ def make_1d_axes(params, ncol=None, labels=None,
Pandas array of axes objects.
"""
# TODO: remove this in version >= 2.1
if 'tex' in fig_kw:
raise NotImplementedError(
"This is anesthetic 1.0 syntax. You need to update, e.g.\n"
"make_1d_axes(..., tex=tex) # anesthetic 1.0\n"
"make_1d_axes(..., labels=tex) # anesthetic 2.0"
)
fig = fig_kw.pop('fig') if 'fig' in fig_kw else plt.figure(**fig_kw)
axes = AxesSeries(index=np.atleast_1d(params),
fig=fig,
Expand Down Expand Up @@ -587,7 +594,15 @@ def make_2d_axes(params, labels=None, lower=True, diagonal=True, upper=True,
Pandas array of axes objects.
"""
# TODO: remove this in version >= 2.1
if 'tex' in fig_kw:
raise NotImplementedError(
"This is anesthetic 1.0 syntax. You need to update, e.g.\n"
"make_2d_axes(..., tex=tex) # anesthetic 1.0\n"
"make_2d_axes(..., labels=tex) # anesthetic 2.0"
)
fig = fig_kw.pop('fig') if 'fig' in fig_kw else plt.figure(**fig_kw)

if nest_level(params) == 2:
xparams, yparams = params
else:
Expand Down
20 changes: 20 additions & 0 deletions anesthetic/plotting/_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pandas.plotting import PlotAccessor as _PlotAccessor
from matplotlib.axes import Axes # TODO: remove this in version >= 2.1


def _process_docstring(doc):
Expand Down Expand Up @@ -54,3 +55,22 @@ def hist_2d(self, x, y, **kwargs):
def scatter_2d(self, x, y, **kwargs):
"""Scatter plot: See anesthetic.plot.scatter_plot_2d."""
return self(kind="scatter_2d", x=x, y=y, **kwargs)

# TODO: remove this in version >= 2.1
def __call__(self, *args, **kwargs):
# noqa: disable=D102
if len(args) > 0 and isinstance(args[0], Axes):
raise ValueError(
"This is anesthetic 1.0 syntax. anesthetic 2.0 now follows "
"pandas in its use of plot.\n"
"samples.plot(ax, x) # anesthetic 1.0\n"
"# anesthetic 2.0\n"
"samples.plot(x=x, ax=ax, kind='kde_1d')\n"
"samples.x.plot.kde_1d(ax=ax)\n"
"samples.plot.kde_1d(x=x, ax=ax)\n\n"
"samples.plot(ax, x, y) # anesthetic 1.0\n"
"# anesthetic 2.0\n"
"samples.plot(x=x, y=y, ax=ax, kind='kde_2d')\n"
"samples.plot.kde_2d(x=x, y=y, ax=ax)"
)
return super().__call__(*args, **kwargs)
4 changes: 3 additions & 1 deletion anesthetic/read/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def read_chains(root, *args, **kwargs):
errors = []
for read in [read_polychord, read_multinest, read_cobaya, read_getdist]:
try:
return read(root, *args, **kwargs)
samples = read(root, *args, **kwargs)
samples.root = root
return samples
except (FileNotFoundError, IOError) as e:
errors.append(str(read) + ": " + str(e))

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/read/cobaya.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def read_cobaya(root, *args, **kwargs):
weights, logP, data = np.split(data, [1, 2], axis=1)
mcmc = MCMCSamples(data=data, columns=columns,
weights=weights.flatten(), logL=logP,
root=root, labels=labels, *args, **kwargs)
labels=labels, *args, **kwargs)
mcmc['chain'] = int(i) if i else np.nan
samples.append(mcmc)

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/read/getdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def read_getdist(root, *args, **kwargs):
weights, minuslogL, data = np.split(data, [1, 2], axis=1)
mcmc = MCMCSamples(data=data, columns=columns,
weights=weights.flatten(), logL=-minuslogL,
labels=labels, root=root, *args, **kwargs)
labels=labels, *args, **kwargs)
mcmc['chain'] = int(i) if i else np.nan
samples.append(mcmc)

Expand Down
3 changes: 1 addition & 2 deletions anesthetic/read/multinest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,4 @@ def read_multinest(root, *args, **kwargs):
data = samples

return NestedSamples(data=data, logL=logL, logL_birth=logL_birth,
root=root, columns=columns, labels=labels,
*args, **kwargs)
columns=columns, labels=labels, *args, **kwargs)
2 changes: 1 addition & 1 deletion anesthetic/read/polychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ def read_polychord(root, *args, **kwargs):

return NestedSamples(data=data, columns=columns,
logL=logL, logL_birth=logL_birth,
labels=labels, root=root, *args, **kwargs)
labels=labels, *args, **kwargs)
101 changes: 89 additions & 12 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,33 @@ def plot_1d(self, axes, *args, **kwargs):
Pandas array of axes objects
"""
# TODO: remove this in version >= 2.1
if 'plot_type' in kwargs:
raise ValueError(
"You are using the anesthetic 1.0 kwarg \'plot_type\' instead "
"of anesthetic 2.0 \'kind\'. Please update your code."
)

if not isinstance(axes, AxesSeries):
_, axes = make_1d_axes(axes, labels=self.get_labels_map())

kwargs['kind'] = kwargs.get('kind', 'kde_1d')
kwargs['label'] = kwargs.get('label', self.label)

# TODO: remove this in version >= 2.1
if kwargs['kind'] == 'kde':
warnings.warn(
"You are using \'kde\' as a plot kind. "
"\'kde_1d\' is the appropriate keyword for anesthetic. "
"Your plots may look odd if you use this argument."
)
elif kwargs['kind'] == 'hist':
warnings.warn(
"You are using \'hist\' as a plot kind. "
"\'hist_1d\' is the appropriate keyword for anesthetic. "
"Your plots may look odd if you use this argument."
)

for x, ax in axes.items():
if x in self and kwargs['kind'] is not None:
xlabel = self.get_label(x)
Expand Down Expand Up @@ -268,8 +289,16 @@ def plot_2d(self, axes, *args, **kwargs):
Pandas array of axes objects
"""
# TODO: remove this in version >= 2.1
if 'types' in kwargs:
raise ValueError(
"You are using the anesthetic 1.0 kwarg \'types\' instead of "
"anesthetic 2.0 \'kind' or \'kinds\' (synonyms). "
"Please update your code."
)
kind = kwargs.pop('kind', 'default')
kind = kwargs.pop('kinds', kind)

if isinstance(kind, str) and kind in self.plot_2d_default_kinds:
kind = self.plot_2d_default_kinds.get(kind)
if (not isinstance(kind, dict) or
Expand Down Expand Up @@ -300,6 +329,21 @@ def plot_2d(self, axes, *args, **kwargs):
pos = ax.position
lkwargs = local_kwargs.get(pos, {})
lkwargs['kind'] = kind.get(pos, None)
# TODO: remove this in version >= 2.1
if lkwargs['kind'] == 'kde':
warnings.warn(
"You are using \'kde\' as a plot kind. "
"\'kde_1d\' and \'kde_2d\' are the appropriate "
"keywords for anesthetic. Your plots may look "
"odd if you use this argument."
)
elif lkwargs['kind'] == 'hist':
warnings.warn(
"You are using \'hist\' as a plot kind. "
"\'hist_1d\' and \'hist_2d\' are the appropriate "
"keywords for anesthetic. Your plots may look "
"odd if you use this argument."
)
if x in self and y in self and lkwargs['kind'] is not None:
xlabel = self.get_label(x)
ylabel = self.get_label(y)
Expand Down Expand Up @@ -387,6 +431,18 @@ def importance_sample(self, logL_new, action='add', inplace=False):
else:
return samples.__finalize__(self, "importance_sample")

# TODO: remove this in version >= 2.1
@property
def tex(self):
# noqa: disable=D102
raise NotImplementedError(
"This is anesthetic 1.0 syntax. You need to update, e.g.\n"
"samples.tex[label] = tex # anesthetic 1.0\n"
"samples.set_label(label, tex) # anesthetic 2.0\n\n"
"tex = samples.tex[label] # anesthetic 1.0\n"
"tex = samples.get_label(label) # anesthetic 2.0"
)


class MCMCSamples(Samples):
"""Storage and plotting tools for MCMC samples.
Expand All @@ -396,9 +452,6 @@ class MCMCSamples(Samples):
Parameters
----------
root: str, optional
root for reading chains from file. Overrides all other arguments.
data: np.array
Coordinates of samples. shape = (nsamples, ndims).
Expand All @@ -425,11 +478,6 @@ class MCMCSamples(Samples):

_metadata = Samples._metadata + ['root']

def __init__(self, *args, **kwargs):
root = kwargs.pop('root', None)
super().__init__(*args, **kwargs)
self.root = root

@property
def _constructor(self):
return MCMCSamples
Expand All @@ -455,9 +503,6 @@ class NestedSamples(Samples):
Parameters
----------
root: str, optional
root for reading chains from file. Overrides all other arguments.
data: np.array
Coordinates of samples. shape = (nsamples, ndims).
Expand Down Expand Up @@ -491,7 +536,6 @@ class NestedSamples(Samples):
_metadata = Samples._metadata + ['root', '_beta']

def __init__(self, *args, **kwargs):
self.root = kwargs.pop('root', None)
logzero = kwargs.pop('logzero', -1e30)
self._beta = kwargs.pop('beta', 1.)
logL_birth = kwargs.pop('logL_birth', None)
Expand Down Expand Up @@ -548,6 +592,17 @@ def prior(self, inplace=False):
"""Re-weight samples at infinite temperature to get prior samples."""
return self.set_beta(beta=0, inplace=inplace)

# TODO: remove this in version >= 2.1
def ns_output(self, *args, **kwargs):
# noqa: disable=D102
raise NotImplementedError(
"This is anesthetic 1.0 syntax. You need to update, e.g.\n"
"samples.ns_output(1000) # anesthetic 1.0\n"
"samples.stats(1000) # anesthetic 2.0\n\n"
"Check out the new temperature functionality: help(samples.stats),"
" as well as average loglikelihoods: help(samples.logL_P)"
)

def stats(self, nsamples=None, beta=None):
"""Compute Nested Sampling statistics.
Expand Down Expand Up @@ -795,6 +850,17 @@ def logZ(self, nsamples=None, beta=None):

_logZ_function_shape = '\n' + '\n'.join(logZ.__doc__.split('\n')[1:])

# TODO: remove this in version >= 2.1
def D(self, nsamples=None):
# noqa: disable=D102
raise NotImplementedError(
"This is anesthetic 1.0 syntax. You need to update, e.g.\n"
"samples.D(1000) # anesthetic 1.0\n"
"samples.D_KL(1000) # anesthetic 2.0\n\n"
"Check out the new temperature functionality: help(samples.D_KL), "
"as well as average loglikelihoods: help(samples.logL_P)"
)

def D_KL(self, nsamples=None, beta=None):
"""Kullback-Leibler divergence."""
logw = self.logw(nsamples, beta)
Expand All @@ -811,6 +877,17 @@ def D_KL(self, nsamples=None, beta=None):

D_KL.__doc__ += _logZ_function_shape

# TODO: remove this in version >= 2.1
def d(self, nsamples=None):
# noqa: disable=D102
raise NotImplementedError(
"This is anesthetic 1.0 syntax. You need to update, e.g.\n"
"samples.d(1000) # anesthetic 1.0\n"
"samples.d_G(1000) # anesthetic 2.0\n\n"
"Check out the new temperature functionality: help(samples.d_G), "
"as well as average loglikelihoods: help(samples.logL_P)"
)

def d_G(self, nsamples=None, beta=None):
"""Bayesian model dimensionality."""
logw = self.logw(nsamples, beta)
Expand Down
4 changes: 2 additions & 2 deletions bin/plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from anesthetic import NestedSamples
from anesthetic import read_chains
import numpy as np
ns = NestedSamples(root='./tests/example_data/pc')
ns = read_chains('./tests/example_data/pc')
fig, axes = ns.plot_2d(['x0', 'x1', 'x2', 'x3', 'x4'])

sigma0, sigma1 = 0.1, 0.1
Expand Down
898 changes: 807 additions & 91 deletions demo.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit ba7d4b3

Please sign in to comment.