Skip to content

Commit

Permalink
test resize fns
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Apr 16, 2024
1 parent 330a0b2 commit e6b3ba6
Showing 1 changed file with 154 additions and 0 deletions.
154 changes: 154 additions & 0 deletions qcmanybody/models/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import pytest
import numpy as np

from qcelemental.testing import compare_values

import qcmanybody as qcmb

@pytest.fixture(scope="function")
def mbe_data():
henehh = Molecule(symbols=["He", "Ne", "H", "H"], fragments=[[0], [1], [2, 3]], geometry=[0, 0, 0, 2, 0, 0, 0, 1, 0, 0, -1, 0])
return henehh

f3grads = {
"full": np.array([
[ 0.598726, 0. , 0. ],
[-0.960726, 0. , 0. ],
[ 0.181 , -0.858448, 0. ],
[ 0.181 , 0.858448, 0. ],
]),
"b13": np.array([
[ 0.598726, 0. , 0. ],
#[-0.960726, 0. , 0. ],
[ 0.181 , -0.858448, 0. ],
[ 0.181 , 0.858448, 0. ],
]),
"b13_": np.array([
[ 0.598726, 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0.181 , -0.858448, 0. ],
[ 0.181 , 0.858448, 0. ],
]),
"b3": np.array([
#[ 0.598726, 0. , 0. ],
#[-0.960726, 0. , 0. ],
[ 0.181 , -0.858448, 0. ],
[ 0.181 , 0.858448, 0. ],
]),
"b3_": np.array([
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0.181 , -0.858448, 0. ],
[ 0.181 , 0.858448, 0. ],
]),
}

f3hesses = {
"full": np.array([
[ 0.035028, 0. , 0. , -1.398945, 0. , -0. , 0.681958, 0.052467, -0. , 0.681958, -0.052467, 0. ],
[ 0. , 3.339155, 0. , 0. , 0.120345, 0. , -0.179019, -1.72975 , -0. , 0.179019, -1.72975 , -0. ],
[ 0. , 0. , -1.232377, 0. , 0. , 0.299363, 0. , -0. , 0.466507, -0. , -0. , 0.466507],
[-1.398945, 0. , 0. , 2.147922, 0. , 0. , -0.374488, 0.060761, 0. , -0.374488, -0.060761, -0. ],
[ 0. , 0.120345, 0. , 0. , -0.252472, 0. , 0.227891, 0.066064, -0. , -0.227891, 0.066064, -0. ],
[-0. , 0. , 0.299363, 0. , 0. , -0.480363, 0. , -0. , 0.0905 , 0. , -0. , 0.0905 ],
[ 0.681958, -0.179019, 0. , -0.374488, 0.227891, 0. , -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
[ 0.052467, -1.72975 , -0. , 0.060761, 0.066064, -0. , -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
[-0. , -0. , 0.466507, 0. , -0. , 0.0905 , -0. , 0. , -0.707728, 0. , -0. , 0.150721],
[ 0.681958, 0.179019, -0. , -0.374488, -0.227891, 0. , 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
[-0.052467, -1.72975 , -0. , -0.060761, 0.066064, -0. , 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
[ 0. , -0. , 0.466507, -0. , -0. , 0.0905 , -0. , -0. , 0.150721, 0. , 0. , -0.707728],
]),
"b13": np.array([
[ 0.035028, 0. , 0. , 0.681958, 0.052467, -0. , 0.681958, -0.052467, 0. ],
[ 0. , 3.339155, 0. , -0.179019, -1.72975 , -0. , 0.179019, -1.72975 , -0. ],
[ 0. , 0. , -1.232377, 0. , -0. , 0.466507, -0. , -0. , 0.466507],
#[-1.398945, 0. , 0. , -0.374488, 0.060761, 0. , -0.374488, -0.060761, -0. ],
#[ 0. , 0.120345, 0. , 0.227891, 0.066064, -0. , -0.227891, 0.066064, -0. ],
#[-0. , 0. , 0.299363, 0. , -0. , 0.0905 , 0. , -0. , 0.0905 ],
[ 0.681958, -0.179019, 0. , -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
[ 0.052467, -1.72975 , -0. , -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
[-0. , -0. , 0.466507, -0. , 0. , -0.707728, 0. , -0. , 0.150721],
[ 0.681958, 0.179019, -0. , 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
[-0.052467, -1.72975 , -0. , 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
[ 0. , -0. , 0.466507, -0. , -0. , 0.150721, 0. , 0. , -0.707728],
]),
"b13_": np.array([
[ 0.035028, 0. , 0. , 0. , 0. , -0. , 0.681958, 0.052467, -0. , 0.681958, -0.052467, 0. ],
[ 0. , 3.339155, 0. , 0. , 0. , 0. , -0.179019, -1.72975 , -0. , 0.179019, -1.72975 , -0. ],
[ 0. , 0. , -1.232377, 0. , 0. , 0. , 0. , -0. , 0.466507, -0. , -0. , 0.466507],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0.681958, -0.179019, 0. , 0. , 0. , 0. , -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
[ 0.052467, -1.72975 , -0. , 0. , 0. , -0. , -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
[-0. , -0. , 0.466507, 0. , -0. , 0. , -0. , 0. , -0.707728, 0. , -0. , 0.150721],
[ 0.681958, 0.179019, -0. , -0. , 0. , 0. , 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
[-0.052467, -1.72975 , -0. , -0. , 0. , -0. , 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
[ 0. , -0. , 0.466507, -0. , -0. , 0. , -0. , -0. , 0.150721, 0. , 0. , -0.707728],

]),
"b3": np.array([
#[ 0.681958, 0.052467, -0. , 0.681958, -0.052467, 0. ],
#[ -0.179019, -1.72975 , -0. , 0.179019, -1.72975 , -0. ],
#[ 0. , -0. , 0.466507, -0. , -0. , 0.466507],
#[ -0.374488, 0.060761, 0. , -0.374488, -0.060761, -0. ],
#[ 0.227891, 0.066064, -0. , -0.227891, 0.066064, -0. ],
#[ 0. , -0. , 0.0905 , 0. , -0. , 0.0905 ],
[ -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
[ -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
[ -0. , 0. , -0.707728, 0. , -0. , 0.150721],
[ 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
[ 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
[ -0. , -0. , 0.150721, 0. , 0. , -0.707728],

]),
"b3_": np.array([
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
[ 0. , -0. , -0. , 0. , 0. , -0. , -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
[-0. , -0. , 0. , 0. , -0. , 0. , -0. , 0. , -0.707728, 0. , -0. , 0.150721],
[ 0. , 0. , -0. , 0. , -0. , 0. , 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
[-0. , -0. , -0. , 0. , 0. , -0. , 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
[ 0. , -0. , 0. , -0. , -0. , 0. , -0. , -0. , 0.150721, 0. , 0. , -0.707728],
]),
}


@pytest.mark.parametrize("gin,bas,reverse,gans", [
pytest.param(f3grads["full"], [1, 2, 3], False, f3grads["full"]), # idempotent
pytest.param(f3grads["full"], [1, 2, 3], True, f3grads["full"]), # idempotent
pytest.param(f3grads["b13"], [1, 3], False, f3grads["b13_"]),
pytest.param(f3grads["b13_"], [1, 3], True, f3grads["b13"]),
pytest.param(f3grads["full"], [1, 3], True, f3grads["b13"]),
pytest.param(f3grads["b3"], [3], False, f3grads["b3_"]),
pytest.param(f3grads["b3_"], [3], True, f3grads["b3"]),
pytest.param(f3grads["full"], [3], True, f3grads["b3"]),
pytest.param(f3grads["full"], [], False, np.zeros((4, 3))), # zero
pytest.param(f3grads["full"], [], True, np.zeros((0, 3))), # collapse
])
def test_resize_gradient(gin, bas, reverse, gans):
gout = qcmb.resize_gradient(gin, bas, {1: 1, 2: 1, 3: 2}, {1: slice(0, 1), 2: slice(1, 2), 3: slice(2, 4)}, reverse=reverse)
assert compare_values(gans, gout, atol=1e-5, label="resize_gradient")


@pytest.mark.parametrize("hin,bas,reverse,hans", [
pytest.param(f3hesses["full"], [1, 2, 3], False, f3hesses["full"]), # idempotent
pytest.param(f3hesses["full"], [1, 2, 3], True, f3hesses["full"]), # idempotent
pytest.param(f3hesses["b13"], [1, 3], False, f3hesses["b13_"]),
pytest.param(f3hesses["b13_"], [1, 3], True, f3hesses["b13"]),
pytest.param(f3hesses["full"], [1, 3], True, f3hesses["b13"]),
pytest.param(f3hesses["b3"], [3], False, f3hesses["b3_"]),
pytest.param(f3hesses["b3_"], [3], True, f3hesses["b3"]),
pytest.param(f3hesses["full"], [3], True, f3hesses["b3"]),
pytest.param(f3hesses["full"], [], False, np.zeros((12, 12))), # zero
pytest.param(f3hesses["full"], [], True, np.zeros((0, 0))), # collapse
])
def test_resize_hessian(hin, bas, reverse, hans):
hout = qcmb.resize_hessian(hin, bas, {1: 1, 2: 1, 3: 2}, {1: slice(0, 1), 2: slice(1, 2), 3: slice(2, 4)}, reverse=reverse)
assert compare_values(hans, hout, atol=1e-5, label="resize_hessian")

0 comments on commit e6b3ba6

Please sign in to comment.