Skip to content

Commit

Permalink
Merge branch 'master' into logL_birth_inf
Browse files Browse the repository at this point in the history
  • Loading branch information
lukashergt authored Apr 8, 2024
2 parents 233777b + ef8a27e commit 6a05edf
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
11 changes: 10 additions & 1 deletion anesthetic/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from matplotlib.axes import Axes
import matplotlib.cbook as cbook
import matplotlib.lines as mlines
from matplotlib.ticker import MaxNLocator, AutoMinorLocator
from matplotlib.ticker import MaxNLocator, AutoMinorLocator, LogLocator
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.transforms import Affine2D
from anesthetic.utils import nest_level
Expand Down Expand Up @@ -388,6 +388,15 @@ def _set_scale(self):
if y in self._logy:
ax.set_yscale('log')

def _set_logticks(self):
for y, rows in self.iterrows():
for x, ax in rows.items():
if ax is not None:
if x in self._logx:
ax.xaxis.set_major_locator(LogLocator(numticks=3))
if y in self._logy:
ax.yaxis.set_major_locator(LogLocator(numticks=3))

@staticmethod
def _set_labels(axes, labels, **kwargs):
all_params = list(axes.columns) + list(axes.index)
Expand Down
2 changes: 2 additions & 0 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ def plot_2d(self, axes=None, *args, **kwargs):
else:
ax.plot([], [])

axes._set_logticks()

return axes

plot_2d_default_kinds = {
Expand Down
19 changes: 19 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,25 @@ def test_plot_logscale_2d(kind):
assert ax.twin.get_yscale() == 'linear'


def test_logscale_ticks():
np.random.seed(42)
ndim = 5
data = np.exp(10 * np.random.randn(200, ndim))
params = [f'a{i}' for i in range(ndim)]
fig, axes = make_2d_axes(params, logx=params, logy=params, upper=False)
samples = Samples(data, columns=params)
samples.plot_2d(axes)
for _, col in axes.iterrows():
for _, ax in col.items():
if ax is not None:
xlims = ax.get_xlim()
xticks = ax.get_xticks()
assert np.sum((xticks > xlims[0]) & (xticks < xlims[1])) > 1
ylims = ax.get_ylim()
yticks = ax.get_yticks()
assert np.sum((yticks > ylims[0]) & (yticks < ylims[1])) > 1


@pytest.mark.parametrize('k', ['hist_1d', 'hist'])
@pytest.mark.parametrize('b', ['scott', 10, np.logspace(-3, 0, 20)])
@pytest.mark.parametrize('r', [None, (1e-5, 1)])
Expand Down

0 comments on commit 6a05edf

Please sign in to comment.