-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_utils.py
207 lines (191 loc) · 15.5 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import numpy as np
import pytest
from qcelemental.testing import compare_recursive, compare_values
import qcmanybody as qcmb
@pytest.fixture(scope="function")
def mbe_data():
    """Return a He / Ne / H2 test molecule split into three fragments.

    Fragments (0-based atom indices): [0] is He, [1] is Ne, [2, 3] are the two
    H atoms.  ``geometry`` is a flat list of x, y, z triples: He at the origin,
    Ne at (2, 0, 0), and the H atoms at (0, 1, 0) and (0, -1, 0).  Length units
    are whatever qcelemental's ``Molecule`` assumes by default.
    """
    # Imported locally because the module's top-level imports do not bring
    # ``Molecule`` into scope; without this, requesting the fixture raises
    # NameError.  qcelemental is already a dependency of this module.
    from qcelemental.models import Molecule

    henehh = Molecule(
        symbols=["He", "Ne", "H", "H"],
        fragments=[[0], [1], [2, 3]],
        geometry=[0, 0, 0, 2, 0, 0, 0, 1, 0, 0, -1, 0],
    )
    return henehh
# Reference gradients (one row per atom, three columns) for a 4-atom,
# 3-fragment system, used by test_resize_gradient.  Fragment layout, as given
# by the slice map passed to qcmb.resize_gradient in this module:
# fragment 1 -> atom 0, fragment 2 -> atom 1, fragment 3 -> atoms 2-3.
#
# Key naming:
#   "full"        -- rows for all four atoms.
#   "b13", "b3"   -- compact arrays holding only the rows of the fragments
#                    named in the key (fragments 1 and 3, or fragment 3
#                    alone); the dropped rows are kept as commented-out lines.
#   "b13_", "b3_" -- the same rows placed back into the full 4-row shape,
#                    with the rows of the absent fragments set to zero.
f3grads = {
    "full": np.array([
        [ 0.598726, 0. , 0. ],
        [-0.960726, 0. , 0. ],
        [ 0.181 , -0.858448, 0. ],
        [ 0.181 , 0.858448, 0. ],
    ]),
    "b13": np.array([
        [ 0.598726, 0. , 0. ],
        #[-0.960726, 0. , 0. ],
        [ 0.181 , -0.858448, 0. ],
        [ 0.181 , 0.858448, 0. ],
    ]),
    "b13_": np.array([
        [ 0.598726, 0. , 0. ],
        [ 0. , 0. , 0. ],
        [ 0.181 , -0.858448, 0. ],
        [ 0.181 , 0.858448, 0. ],
    ]),
    "b3": np.array([
        #[ 0.598726, 0. , 0. ],
        #[-0.960726, 0. , 0. ],
        [ 0.181 , -0.858448, 0. ],
        [ 0.181 , 0.858448, 0. ],
    ]),
    "b3_": np.array([
        [ 0. , 0. , 0. ],
        [ 0. , 0. , 0. ],
        [ 0.181 , -0.858448, 0. ],
        [ 0.181 , 0.858448, 0. ],
    ]),
}
# Reference Hessians for the same 4-atom, 3-fragment system, used by
# test_resize_hessian.  The full matrix is 12 x 12 (three rows/columns per
# atom): fragment 1 owns rows/columns 0-2, fragment 2 owns 3-5, and
# fragment 3 owns 6-11.
#
# Key naming follows f3grads:
#   "full"        -- the complete 12 x 12 matrix.
#   "b13", "b3"   -- compact matrices keeping only the rows AND columns of the
#                    fragments named in the key (9 x 9 and 6 x 6).  The
#                    commented-out lines show the dropped rows, already
#                    restricted to the kept columns.
#   "b13_", "b3_" -- the same blocks placed back into the 12 x 12 shape, with
#                    the rows/columns of the absent fragments set to zero.
f3hesses = {
    "full": np.array([
        [ 0.035028, 0. , 0. , -1.398945, 0. , -0. , 0.681958, 0.052467, -0. , 0.681958, -0.052467, 0. ],
        [ 0. , 3.339155, 0. , 0. , 0.120345, 0. , -0.179019, -1.72975 , -0. , 0.179019, -1.72975 , -0. ],
        [ 0. , 0. , -1.232377, 0. , 0. , 0.299363, 0. , -0. , 0.466507, -0. , -0. , 0.466507],
        [-1.398945, 0. , 0. , 2.147922, 0. , 0. , -0.374488, 0.060761, 0. , -0.374488, -0.060761, -0. ],
        [ 0. , 0.120345, 0. , 0. , -0.252472, 0. , 0.227891, 0.066064, -0. , -0.227891, 0.066064, -0. ],
        [-0. , 0. , 0.299363, 0. , 0. , -0.480363, 0. , -0. , 0.0905 , 0. , -0. , 0.0905 ],
        [ 0.681958, -0.179019, 0. , -0.374488, 0.227891, 0. , -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
        [ 0.052467, -1.72975 , -0. , 0.060761, 0.066064, -0. , -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
        [-0. , -0. , 0.466507, 0. , -0. , 0.0905 , -0. , 0. , -0.707728, 0. , -0. , 0.150721],
        [ 0.681958, 0.179019, -0. , -0.374488, -0.227891, 0. , 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
        [-0.052467, -1.72975 , -0. , -0.060761, 0.066064, -0. , 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
        [ 0. , -0. , 0.466507, -0. , -0. , 0.0905 , -0. , -0. , 0.150721, 0. , 0. , -0.707728],
    ]),
    "b13": np.array([
        [ 0.035028, 0. , 0. , 0.681958, 0.052467, -0. , 0.681958, -0.052467, 0. ],
        [ 0. , 3.339155, 0. , -0.179019, -1.72975 , -0. , 0.179019, -1.72975 , -0. ],
        [ 0. , 0. , -1.232377, 0. , -0. , 0.466507, -0. , -0. , 0.466507],
        #[-1.398945, 0. , 0. , -0.374488, 0.060761, 0. , -0.374488, -0.060761, -0. ],
        #[ 0. , 0.120345, 0. , 0.227891, 0.066064, -0. , -0.227891, 0.066064, -0. ],
        #[-0. , 0. , 0.299363, 0. , -0. , 0.0905 , 0. , -0. , 0.0905 ],
        [ 0.681958, -0.179019, 0. , -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
        [ 0.052467, -1.72975 , -0. , -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
        [-0. , -0. , 0.466507, -0. , 0. , -0.707728, 0. , -0. , 0.150721],
        [ 0.681958, 0.179019, -0. , 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
        [-0.052467, -1.72975 , -0. , 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
        [ 0. , -0. , 0.466507, -0. , -0. , 0.150721, 0. , 0. , -0.707728],
    ]),
    "b13_": np.array([
        [ 0.035028, 0. , 0. , 0. , 0. , -0. , 0.681958, 0.052467, -0. , 0.681958, -0.052467, 0. ],
        [ 0. , 3.339155, 0. , 0. , 0. , 0. , -0.179019, -1.72975 , -0. , 0.179019, -1.72975 , -0. ],
        [ 0. , 0. , -1.232377, 0. , 0. , 0. , 0. , -0. , 0.466507, -0. , -0. , 0.466507],
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0.681958, -0.179019, 0. , 0. , 0. , 0. , -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
        [ 0.052467, -1.72975 , -0. , 0. , 0. , -0. , -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
        [-0. , -0. , 0.466507, 0. , -0. , 0. , -0. , 0. , -0.707728, 0. , -0. , 0.150721],
        [ 0.681958, 0.179019, -0. , -0. , 0. , 0. , 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
        [-0.052467, -1.72975 , -0. , -0. , 0. , -0. , 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
        [ 0. , -0. , 0.466507, -0. , -0. , 0. , -0. , -0. , 0.150721, 0. , 0. , -0.707728],
    ]),
    "b3": np.array([
        #[ 0.681958, 0.052467, -0. , 0.681958, -0.052467, 0. ],
        #[ -0.179019, -1.72975 , -0. , 0.179019, -1.72975 , -0. ],
        #[ 0. , -0. , 0.466507, -0. , -0. , 0.466507],
        #[ -0.374488, 0.060761, 0. , -0.374488, -0.060761, -0. ],
        #[ 0.227891, 0.066064, -0. , -0.227891, 0.066064, -0. ],
        #[ 0. , -0. , 0.0905 , 0. , -0. , 0.0905 ],
        [ -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
        [ -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
        [ -0. , 0. , -0.707728, 0. , -0. , 0.150721],
        [ 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
        [ 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
        [ -0. , -0. , 0.150721, 0. , 0. , -0.707728],
    ]),
    "b3_": np.array([
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
        [ 0. , 0. , 0. , 0. , 0. , 0. , -0.355068, -0.08105 , -0. , 0.047598, 0.032178, -0. ],
        [ 0. , -0. , -0. , 0. , 0. , -0. , -0.08105 , 2.276729, 0. , -0.032178, -0.613043, -0. ],
        [-0. , -0. , 0. , 0. , -0. , 0. , -0. , 0. , -0.707728, 0. , -0. , 0.150721],
        [ 0. , 0. , -0. , 0. , -0. , 0. , 0.047598, -0.032178, 0. , -0.355068, 0.08105 , 0. ],
        [-0. , -0. , -0. , 0. , 0. , -0. , 0.032178, -0.613043, -0. , 0.08105 , 2.276729, 0. ],
        [ 0. , -0. , 0. , -0. , -0. , 0. , -0. , -0. , 0.150721, 0. , 0. , -0.707728],
    ]),
}
@pytest.mark.parametrize(
    ("gin", "bas", "reverse", "gans"),
    [
        pytest.param(f3grads["full"], [1, 2, 3], False, f3grads["full"]),  # idempotent
        pytest.param(f3grads["full"], [1, 2, 3], True, f3grads["full"]),  # idempotent
        pytest.param(f3grads["b13"], [1, 3], False, f3grads["b13_"]),
        pytest.param(f3grads["b13_"], [1, 3], True, f3grads["b13"]),
        pytest.param(f3grads["full"], [1, 3], True, f3grads["b13"]),
        pytest.param(f3grads["b3"], [3], False, f3grads["b3_"]),
        pytest.param(f3grads["b3_"], [3], True, f3grads["b3"]),
        pytest.param(f3grads["full"], [3], True, f3grads["b3"]),
        pytest.param(f3grads["full"], [], False, np.zeros((4, 3))),  # zero
        pytest.param(f3grads["full"], [], True, np.zeros((0, 3))),  # collapse
    ],
)
def test_resize_gradient(gin, bas, reverse, gans):
    """Compare ``qcmb.resize_gradient`` against hand-built reference arrays.

    With ``reverse=False`` the cases expect the rows of the fragments in
    ``bas`` to come back in the full 4-atom shape, with other fragments
    zeroed.  With ``reverse=True`` they expect only the rows of the fragments
    in ``bas``.  An empty ``bas`` yields an all-zero (4, 3) array or an empty
    (0, 3) array, respectively.
    """
    # Atoms per fragment; each value equals the width of the matching slice.
    frag_natoms = {1: 1, 2: 1, 3: 2}
    # Atom-row span each fragment occupies in the full gradient.
    frag_slices = {1: slice(0, 1), 2: slice(1, 2), 3: slice(2, 4)}

    gout = qcmb.resize_gradient(gin, bas, frag_natoms, frag_slices, reverse=reverse)
    assert compare_values(gans, gout, atol=1e-5, label="resize_gradient")
@pytest.mark.parametrize(
    ("hin", "bas", "reverse", "hans"),
    [
        pytest.param(f3hesses["full"], [1, 2, 3], False, f3hesses["full"]),  # idempotent
        pytest.param(f3hesses["full"], [1, 2, 3], True, f3hesses["full"]),  # idempotent
        pytest.param(f3hesses["b13"], [1, 3], False, f3hesses["b13_"]),
        pytest.param(f3hesses["b13_"], [1, 3], True, f3hesses["b13"]),
        pytest.param(f3hesses["full"], [1, 3], True, f3hesses["b13"]),
        pytest.param(f3hesses["b3"], [3], False, f3hesses["b3_"]),
        pytest.param(f3hesses["b3_"], [3], True, f3hesses["b3"]),
        pytest.param(f3hesses["full"], [3], True, f3hesses["b3"]),
        pytest.param(f3hesses["full"], [], False, np.zeros((12, 12))),  # zero
        pytest.param(f3hesses["full"], [], True, np.zeros((0, 0))),  # collapse
    ],
)
def test_resize_hessian(hin, bas, reverse, hans):
    """Compare ``qcmb.resize_hessian`` against hand-built reference matrices.

    Mirrors test_resize_gradient.  ``reverse=False`` cases expect the
    fragments in ``bas`` to be placed into the full 12 x 12 matrix with other
    fragments zeroed.  ``reverse=True`` cases expect only the rows/columns of
    the fragments in ``bas``.  An empty ``bas`` yields a zero (12, 12) or an
    empty (0, 0) matrix.
    """
    atoms_in_frag = {1: 1, 2: 1, 3: 2}
    atom_span_of_frag = {1: slice(0, 1), 2: slice(1, 2), 3: slice(2, 4)}

    hout = qcmb.resize_hessian(
        hin,
        bas,
        atoms_in_frag,
        atom_span_of_frag,
        reverse=reverse,
    )
    assert compare_values(hans, hout, atol=1e-5, label="resize_hessian")
@pytest.mark.parametrize(
    ("mc", "frag", "bas", "opq", "ans"),
    [
        # opaque=True: list-style label; mc is rendered as a string and
        # scalar/tuple fragment and basis arguments become lists.
        pytest.param("mp2", 1, (1, 2), True, '["mp2", [1], [1, 2]]'),
        pytest.param("mp2", (1,), (1, 2), True, '["mp2", [1], [1, 2]]'),
        pytest.param("mp2", [1], [1, 2], True, '["mp2", [1], [1, 2]]'),
        pytest.param("mp2", (1, 2), (1, 2), True, '["mp2", [1, 2], [1, 2]]'),
        pytest.param("1", 2, 2, True, '["1", [2], [2]]'),
        pytest.param(1, 2, 2, True, '["1", [2], [2]]'),
        # opaque=False: "§<mc>_(<frag>)@(<bas>)" text label.
        pytest.param("mp2", 1, (1, 2), False, '§mp2_(1)@(1, 2)'),
        pytest.param("mp2", (1,), (1, 2), False, '§mp2_(1)@(1, 2)'),
        pytest.param("mp2", [1], [1, 2], False, '§mp2_(1)@(1, 2)'),
        pytest.param("mp2", (1, 2), (1, 2), False, '§mp2_(1, 2)@(1, 2)'),
        pytest.param("1", 2, 2, False, '§1_(2)@(2)'),
        pytest.param(1, 2, 22, False, '§1_(2)@(22)'),
        # mc=None drops the "§<mc>_" prefix entirely.
        pytest.param(None, (1, 2), (1, 2), False, '(1, 2)@(1, 2)'),  # !opaque-only
    ],
)
def test_labeler(mc, frag, bas, opq, ans):
    """Check that ``qcmb.labeler`` renders the expected label string."""
    produced = qcmb.labeler(mc, frag, bas, opaque=opq)
    assert produced == ans, f"{produced} != {ans}"
@pytest.mark.parametrize(
    ("lbl", "mc_ans", "frag_ans", "bas_ans"),
    [
        # The "# xN" notes appear to count how many test_labeler cases produce
        # the same label string.
        pytest.param('["mp2", [1], [1, 2]]', "mp2", [1], [1, 2]),  # x3
        pytest.param('["mp2", [1, 2], [1, 2]]', "mp2", [1, 2], [1, 2]),
        pytest.param('["1", [2], [2]]', "1", [2], [2]),  # x2
        pytest.param('["mp2", 1, [1, 2]]', 'mp2', 1, [1, 2]),  # not usual
        pytest.param('§mp2_(1)@(1, 2)', "mp2", [1], [1, 2]),  # x3
        pytest.param('§mp2_(1, 2)@(1, 2)', "mp2", [1, 2], [1, 2]),
        pytest.param('§1_(2)@(2)', "1", [2], [2]),  # x2
        pytest.param('(1, 2)@(1, 2)', None, [1, 2], [1, 2]),  # !opaque-only
    ],
)
def test_delabeler(lbl, mc_ans, frag_ans, bas_ans):
    """Check that ``qcmb.delabeler`` splits a label into (mc, frag, bas).

    Covers both label styles checked by test_labeler, plus a list-style label
    whose fragment entry is a bare int rather than a list.
    """
    parts = qcmb.delabeler(lbl)
    mc, frag, bas = parts
    # Checked one component at a time, in order, so a failure names the first
    # mismatching piece.
    for got, want in ((mc, mc_ans), (frag, frag_ans), (bas, bas_ans)):
        assert got == want, f"{got} != {want}"
@pytest.mark.parametrize(
    ("nbmc", "ans"),
    [
        pytest.param(
            {"hi": [2, 1], "lo": [3]},
            {1: ("hi", "§A", "§12"), 2: ("hi", "§A", "§12"), 3: ("lo", "§B", "§3")},
        ),
        pytest.param(
            {"lo": [1, 2, 3], "hi": [5, 6], "md": [4]},
            {1: ('lo', '§A', '§123'), 2: ('lo', '§A', '§123'), 3: ('lo', '§A', '§123'), 4: ("md", "§B", "§4"), 5: ('hi', '§C', '§56'), 6: ('hi', '§C', '§56')},
        ),
        pytest.param(
            {"md": [4], "hi": [5, 6], "lo": [1, 2, 3]},
            {1: ('lo', '§A', '§123'), 2: ('lo', '§A', '§123'), 3: ('lo', '§A', '§123'), 4: ("md", "§B", "§4"), 5: ('hi', '§C', '§56'), 6: ('hi', '§C', '§56')},
        ),
        pytest.param(
            {"hi": [5, 6], "md": [4], "lo": [1, 2, 3]},
            {1: ('lo', '§A', '§123'), 2: ('lo', '§A', '§123'), 3: ('lo', '§A', '§123'), 4: ("md", "§B", "§4"), 5: ('hi', '§C', '§56'), 6: ('hi', '§C', '§56')},
        ),
        pytest.param(
            {"md": [3, 4], "hi": [1, 2, 5], "lo": [6, 7, 8, 9, 10, 11]},
            {1: ("hi", "§A", "§125"), 2: ("hi", "§A", "§125"), 3: ("md", "§B", "§34"), 4: ("md", "§B", "§34"), 5: ("hi", "§A", "§125"), 6: ("lo", "§C", "§<11"), 7: ("lo", "§C", "§<11"), 8: ("lo", "§C", "§<11"), 9: ("lo", "§C", "§<11"), 10: ("lo", "§C", "§<11"), 11: ("lo", "§C", "§<11")},
        ),
        pytest.param(
            {"md": [3, 4], "hi": [5, 2, 1], "lo": [10, 11, 9, 6, 7, 8]},
            {1: ("hi", "§A", "§125"), 2: ("hi", "§A", "§125"), 3: ("md", "§B", "§34"), 4: ("md", "§B", "§34"), 5: ("hi", "§A", "§125"), 6: ("lo", "§C", "§<11"), 7: ("lo", "§C", "§<11"), 8: ("lo", "§C", "§<11"), 9: ("lo", "§C", "§<11"), 10: ("lo", "§C", "§<11"), 11: ("lo", "§C", "§<11")},
        ),
        pytest.param(
            {'hf/6-31g': ['supersystem'], 'ccsd/6-31g': [1], 'mp2/6-31g': [2]},
            {1: ("ccsd/6-31g", "§A", "§1"), 2: ("mp2/6-31g", "§B", "§2"), "supersystem": ('hf/6-31g', '§C', '§SS')},
        ),
        pytest.param(
            {'p4-ccsd': [1], 'p4-mp2': [2, 3], 'p4-hf': [4]},
            {1: ('p4-ccsd', '§A', '§1'), 2: ('p4-mp2', '§B', '§23'), 3: ('p4-mp2', '§B', '§23'), 4: ('p4-hf', '§C', '§4')},
        ),
        pytest.param(
            {'hi': [1, 2, 3], 'md': [4], 'md2': [5, 6, 7, 8, 9, 10], 'lo': ['supersystem']},
            {1: ('hi', '§A', '§123'), 2: ('hi', '§A', '§123'), 3: ('hi', '§A', '§123'), 4: ('md', '§B', '§4'), 5: ('md2', '§C', '§<10'), 6: ('md2', '§C', '§<10'), 7: ('md2', '§C', '§<10'), 8: ('md2', '§C', '§<10'), 9: ('md2', '§C', '§<10'), 10: ('md2', '§C', '§<10'), "supersystem": ('lo', '§D', '§SS')},
        ),
    ],
)
def test_modelchem_labels(nbmc, ans):
    """Check ``qcmb.utils.modelchem_labels`` against expected per-fragment labels.

    ``nbmc`` maps a model-chemistry name to the fragments it is applied to
    (ints, or the string "supersystem").  The expected result maps each
    fragment to a tuple of (model chemistry, "§"-letter label, "§"-fragment
    label).  In the cases below:

    * letters follow the lowest fragment number of each group, with
      "supersystem" last, regardless of dict or list order;
    * fragment labels concatenate the group's fragments ("§125"), use "§SS"
      for the supersystem, and appear abbreviated to "§<max" for the longer
      groups.
    """
    res = qcmb.utils.modelchem_labels(nbmc)
    # compare_recursive takes (expected, computed); keep that order, as the
    # compare_values calls in this module do, so failure reports label the
    # two sides correctly.
    assert compare_recursive(ans, res, label="modelchem_labels")