Skip to content

Commit

Permalink
Update UQ tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mjwen committed Jul 30, 2023
1 parent f17be91 commit 2052da2
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 56 deletions.
2 changes: 0 additions & 2 deletions tests/descriptors/test_bispectrum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

import numpy as np

from kliff.dataset import Configuration
Expand Down
1 change: 0 additions & 1 deletion tests/descriptors/test_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
import os

import numpy as np

Expand Down
1 change: 0 additions & 1 deletion tests/descriptors/test_symmetry_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""Test symmetry functions values."""
import itertools
from collections import OrderedDict

Expand Down
2 changes: 1 addition & 1 deletion tests/test_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_neigh(test_data_dir):

def test_1D():
"""
Simple test of a dimer, nonperiodic.
Simple test of a dimer, non-periodic.
"""

cell = np.asarray([[200.0, 0.0, 0.0], [0.0, 200.0, 0.0], [0.0, 0.0, 200.0]])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def func(x, y, z=1):
return x + y + z


def test_main():
def test_parmap():
X = range(3)
Y = range(3)
Xp2 = [x + 2 for x in X]
Expand Down
2 changes: 0 additions & 2 deletions tests/test_scipy_optimize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

import numpy as np

from kliff.calculators import Calculator
Expand Down
2 changes: 1 addition & 1 deletion tests/uq/test_bootstrap_empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# training set
FILE_DIR = Path(__file__).absolute().parent # Directory of test file
path = FILE_DIR.parent.joinpath("configs_extxyz/Si_4")
path = FILE_DIR.parent.joinpath("test_data/configs/Si_4")
data = Dataset(path)
configs = data.get_configs()

Expand Down
2 changes: 1 addition & 1 deletion tests/uq/test_bootstrap_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

# training set
FILE_DIR = Path(__file__).absolute().parent # Directory of test file
path = FILE_DIR.parent.joinpath("configs_extxyz/Si_4")
path = FILE_DIR.parent.joinpath("test_data/configs/Si_4")
data = Dataset(path)
configs = data.get_configs()

Expand Down
2 changes: 1 addition & 1 deletion tests/uq/test_bootstrap_nn_separate_species.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

# training set
FILE_DIR = Path(__file__).absolute().parent # Directory of test file
path = FILE_DIR.parent.joinpath("configs_extxyz/SiC_4")
path = FILE_DIR.parent.joinpath("test_data/configs/SiC_4")
data = Dataset(path)
configs = data.get_configs()

Expand Down
88 changes: 43 additions & 45 deletions tests/uq/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
from kliff.uq.mcmc import PtemceeSampler

ptemcee_avail = True
except ModuleNotFoundError:
except ImportError:
ptemcee_avail = False

try:
from kliff.uq.mcmc import EmceeSampler

emcee_avail = True
except ModuleNotFoundError:
except ImportError:
emcee_avail = False


Expand Down Expand Up @@ -95,71 +95,69 @@ def test_T0():
"""
# Using internal function
T0_internal = get_T0(loss)

# Compute manually
xopt = calc.get_opt_params()
T0_manual = 2 * loss._get_loss(xopt) / len(xopt)
assert T0_internal == T0_manual, "Internal function to compute T0 doesn't work"


def test_MCMC_wrapper():
@pytest.mark.skipif(not ptemcee_avail, reason="ptemcee is not found")
def test_MCMC_wrapper1():
"""Test if the MCMC wrapper class returns the correct sampler instance."""
if ptemcee_avail:
assert (
type(ptsampler) == PtemceeSampler
), "MCMC should return ``PtemceeSampler`` instance"
if emcee_avail:
assert (
type(sampler) == EmceeSampler
), "MCMC should return ``EmceeSampler`` instance"
assert (
type(ptsampler) == PtemceeSampler
), "MCMC should return ``PtemceeSampler`` instance"


@pytest.mark.skipif(not emcee_avail, reason="emcee is not found")
def test_MCMC_wrapper2():
assert type(sampler) == EmceeSampler, "MCMC should return ``EmceeSampler`` instance"


def test_dimensionality():
@pytest.mark.skipif(not ptemcee_avail, reason="ptemcee is not found")
def test_dimensionality1():
"""Test the number of temperatures, walkers, steps, and parameters. This is done by
comparing the shape of the resulting MCMC chains and the variables used to set these
dimensions.
"""

# Test for ptemcee wrapper
if ptemcee_avail:
p0 = np.random.uniform(0, 10, (ntemps, nwalkers, ndim))
ptsampler.run_mcmc(p0=p0, iterations=nsteps)
assert ptsampler.chain.shape == (
ntemps,
nwalkers,
nsteps,
ndim,
), "Dimensionality from the ptemcee wrapper is not right"
else:
print("Skip testing ptemcee; ptemcee is not found")

p0 = np.random.uniform(0, 10, (ntemps, nwalkers, ndim))
ptsampler.run_mcmc(p0=p0, iterations=nsteps)
assert ptsampler.chain.shape == (
ntemps,
nwalkers,
nsteps,
ndim,
), "Dimensionality from the ptemcee wrapper is not right"


@pytest.mark.skipif(not emcee_avail, reason="emcee is not found")
def test_dimensionality2():
# Test for emcee wrapper
if emcee_avail:
p0 = np.random.uniform(0, 10, (nwalkers, ndim))
sampler.run_mcmc(initial_state=p0, nsteps=nsteps)
assert sampler.get_chain().shape == (
nsteps,
nwalkers,
ndim,
), "Dimensionality from the emcee wrapper is not right"
else:
print("Skip testing emcee; emcee is not found")
p0 = np.random.uniform(0, 10, (nwalkers, ndim))
sampler.run_mcmc(initial_state=p0, nsteps=nsteps)
assert sampler.get_chain().shape == (
nsteps,
nwalkers,
ndim,
), "Dimensionality from the emcee wrapper is not right"


@pytest.mark.skipif(not ptemcee_avail, reason="ptemcee is not found")
def test_pool_exception():
"""Test if an exception is raised when declaring the pool prior to instantiating
``kliff.uq.MCMC``.
"""
if ptemcee_avail:
with pytest.raises(ValueError):
_ = MCMC(
loss,
ntemps=ntemps,
nwalkers=nwalkers,
logprior_args=(prior_bounds,),
pool=Pool(1),
)
else:
print("Skip the test; ptemcee is not found")
with pytest.raises(ValueError):
_ = MCMC(
loss,
ntemps=ntemps,
nwalkers=nwalkers,
logprior_args=(prior_bounds,),
pool=Pool(1),
)


def test_sampler_exception():
Expand Down

0 comments on commit 2052da2

Please sign in to comment.