Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dead points and Truncation #354

Merged
merged 12 commits into from
Nov 14, 2023
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
anesthetic: nested sampling post-processing
===========================================
:Authors: Will Handley and Lukas Hergt
:Version: 2.5.2
:Version: 2.6.0
:Homepage: https://github.com/handley-lab/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.5.2'
__version__ = '2.6.0'
87 changes: 79 additions & 8 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,8 +1171,38 @@ def logL_P(self, nsamples=None, beta=None):

logL_P.__doc__ += _logZ_function_shape

def contour(self, logL=None):
"""Convert contour from (index or None) to a float loglikelihood.

Convention is that live points are inclusive of the contour.

Helper function for:
- NestedSamples.live_points,
- NestedSamples.dead_points,
- NestedSamples.truncate.

Parameters
----------
logL : float or int, optional
Loglikelihood or iteration number
If not provided, return the contour containing the last set of
live points.

Returns
-------
logL : float
Loglikelihood of contour
"""
if logL is None:
logL = self.loc[self.logL > self.logL_birth.max()].logL.iloc[0]
elif isinstance(logL, float):
pass
else:
logL = float(self.logL[logL])
return logL

def live_points(self, logL=None):
"""Get the live points within logL.
"""Get the live points within a contour.

Parameters
----------
Expand All @@ -1188,16 +1218,57 @@ def live_points(self, logL=None):
- ith iteration (if input is integer)
- last set of live points if no argument provided
"""
if logL is None:
logL = self.logL_birth.max()
else:
try:
logL = float(self.logL[logL])
except KeyError:
pass
logL = self.contour(logL)
i = ((self.logL >= logL) & (self.logL_birth < logL)).to_numpy()
yallup marked this conversation as resolved.
Show resolved Hide resolved
return Samples(self[i]).set_weights(None)

def dead_points(self, logL=None):
"""Get the dead points at a given contour.

Convention is that dead points are exclusive of the contour.

Parameters
----------
logL : float or int, optional
Loglikelihood or iteration number to return dead points.
If not provided, return the last set of dead points.

Returns
-------
dead_points : Samples
Dead points at either:
- contour logL (if input is float)
- ith iteration (if input is integer)
- last set of dead points if no argument provided
"""
logL = self.contour(logL)
i = ((self.logL < logL)).to_numpy()
return Samples(self[i]).set_weights(None)

def truncate(self, logL=None):
"""Truncate the run at a given contour.

Returns the union of the live_points and dead_points.

Parameters
----------
logL : float or int, optional
Loglikelihood or iteration number to truncate run.
If not provided, truncate at the last set of dead points.

Returns
-------
truncated_run : NestedSamples
Run truncated at either:
- contour logL (if input is float)
- ith iteration (if input is integer)
- last set of dead points if no argument provided
"""
dead_points = self.dead_points(logL)
live_points = self.live_points(logL)
index = np.concatenate([dead_points.index, live_points.index])
return self.loc[index].recompute()

def posterior_points(self, beta=1):
"""Get equally weighted posterior points at temperature beta."""
return self.set_beta(beta).compress('equal')
Expand Down
62 changes: 62 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,68 @@ def test_live_points():
assert not live_points.isweighted()


def test_dead_points():
np.random.seed(4)
pc = read_chains("./tests/example_data/pc")

for i, logL in pc.logL.iloc[::49].items():
dead_points = pc.dead_points(logL)
assert len(dead_points) == int(len(pc[:i[0]]))

dead_points_from_int = pc.dead_points(i[0])
assert_array_equal(dead_points_from_int, dead_points)

dead_points_from_index = pc.dead_points(i)
assert_array_equal(dead_points_from_index, dead_points)

assert pc.dead_points(1).index[0] == 0

last_dead_points = pc.dead_points()
logL = pc.logL_birth.max()
assert (last_dead_points.logL <= logL).all()
assert len(last_dead_points) == len(pc) - pc.nlive.mode().to_numpy()[0]
assert not dead_points.isweighted()


def test_contour():
np.random.seed(4)
pc = read_chains("./tests/example_data/pc")

cut_float = 30.0
assert cut_float == pc.contour(cut_float)

cut_int = 0
assert pc.logL.min() == pc.contour(cut_int)

cut_none = None
nlive = pc.nlive.mode().to_numpy()[0]
assert sorted(pc.logL)[-nlive] == pc.contour(cut_none)


@pytest.mark.parametrize("cut", [200, 0.0, None])
def test_truncate(cut):
np.random.seed(4)
pc = read_chains("./tests/example_data/pc")
truncated_run = pc.truncate(cut)
assert not truncated_run.index.duplicated().any()
if cut is None:
assert_array_equal(pc, truncated_run)


def test_hist_range_1d():
"""Test to provide a solution to #89"""
np.random.seed(3)
ns = read_chains('./tests/example_data/pc')
ax = ns.plot_1d('x0', kind='hist_1d')
x1, x2 = ax['x0'].get_xlim()
assert x1 > -1
assert x2 < +1
ax = ns.plot_1d('x0', kind='hist_1d', bins=np.linspace(-1, 1, 11))
x1, x2 = ax['x0'].get_xlim()
assert x1 <= -1
assert x2 >= +1


def test_contour_plot_2d_nan():
"""Contour plots with nans arising from issue #96"""
np.random.seed(3)
Expand Down
Loading