Skip to content

Commit

Permalink
merge orbital rotations for ucj operators (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung authored Jun 3, 2024
1 parent ec10294 commit ec4fc6c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 10 deletions.
18 changes: 15 additions & 3 deletions python/ffsim/variational/ucj_spin_balanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,22 +436,34 @@ def _apply_unitary_(
return NotImplemented
if copy:
vec = vec.copy()
current_basis = np.eye(norb)
for (diag_coulomb_mat_aa, diag_coulomb_mat_ab), orbital_rotation in zip(
self.diag_coulomb_mats, self.orbital_rotations
):
vec = gates.apply_orbital_rotation(
vec,
orbital_rotation.T.conj() @ current_basis,
norb=norb,
nelec=nelec,
copy=False,
)
vec = gates.apply_diag_coulomb_evolution(
vec,
(diag_coulomb_mat_aa, diag_coulomb_mat_ab, diag_coulomb_mat_aa),
time=-1.0,
norb=norb,
nelec=nelec,
orbital_rotation=orbital_rotation,
copy=False,
)
if self.final_orbital_rotation is not None:
current_basis = orbital_rotation
if self.final_orbital_rotation is None:
vec = gates.apply_orbital_rotation(
vec, current_basis, norb=norb, nelec=nelec, copy=False
)
else:
vec = gates.apply_orbital_rotation(
vec,
mat=self.final_orbital_rotation,
self.final_orbital_rotation @ current_basis,
norb=norb,
nelec=nelec,
copy=False,
Expand Down
19 changes: 16 additions & 3 deletions python/ffsim/variational/ucj_spin_unbalanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,22 +565,35 @@ def _apply_unitary_(
return NotImplemented
if copy:
vec = vec.copy()
eye = np.eye(norb)
current_basis = np.stack([eye, eye])
for diag_coulomb_mat, orbital_rotation in zip(
self.diag_coulomb_mats, self.orbital_rotations
):
vec = gates.apply_orbital_rotation(
vec,
orbital_rotation.transpose(0, 2, 1).conj() @ current_basis,
norb=norb,
nelec=nelec,
copy=False,
)
vec = gates.apply_diag_coulomb_evolution(
vec,
diag_coulomb_mat,
time=-1.0,
norb=norb,
nelec=nelec,
orbital_rotation=orbital_rotation,
copy=False,
)
if self.final_orbital_rotation is not None:
current_basis = orbital_rotation
if self.final_orbital_rotation is None:
vec = gates.apply_orbital_rotation(
vec, current_basis, norb=norb, nelec=nelec, copy=False
)
else:
vec = gates.apply_orbital_rotation(
vec,
mat=self.final_orbital_rotation,
self.final_orbital_rotation @ current_basis,
norb=norb,
nelec=nelec,
copy=False,
Expand Down
27 changes: 23 additions & 4 deletions python/ffsim/variational/ucj_spinless.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,37 +363,56 @@ def _apply_unitary_(
) -> np.ndarray:
if copy:
vec = vec.copy()
current_basis = np.eye(norb)
if isinstance(nelec, int):
for diag_coulomb_mat, orbital_rotation in zip(
self.diag_coulomb_mats, self.orbital_rotations
):
vec = gates.apply_orbital_rotation(
vec,
orbital_rotation.T.conj() @ current_basis,
norb=norb,
nelec=nelec,
copy=False,
)
vec = gates.apply_diag_coulomb_evolution(
vec,
diag_coulomb_mat,
time=-1.0,
norb=norb,
nelec=nelec,
orbital_rotation=orbital_rotation,
copy=False,
)
current_basis = orbital_rotation
else:
zero = np.zeros((norb, norb))
for diag_coulomb_mat, orbital_rotation in zip(
self.diag_coulomb_mats, self.orbital_rotations
):
vec = gates.apply_orbital_rotation(
vec,
orbital_rotation.T.conj() @ current_basis,
norb=norb,
nelec=nelec,
copy=False,
)
vec = gates.apply_diag_coulomb_evolution(
vec,
(diag_coulomb_mat, zero, diag_coulomb_mat),
time=-1.0,
norb=norb,
nelec=nelec,
orbital_rotation=orbital_rotation,
copy=False,
)
if self.final_orbital_rotation is not None:
current_basis = orbital_rotation
if self.final_orbital_rotation is None:
vec = gates.apply_orbital_rotation(
vec, current_basis, norb=norb, nelec=nelec, copy=False
)
else:
vec = gates.apply_orbital_rotation(
vec,
mat=self.final_orbital_rotation,
self.final_orbital_rotation @ current_basis,
norb=norb,
nelec=nelec,
copy=False,
Expand Down
8 changes: 8 additions & 0 deletions tests/python/variational/ucj_spinless_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ def test_t_amplitudes_energy():
norb=norb,
nelec=nelec,
)
energy_alt, _ = ffsim.multireference_state(
mol_hamiltonian,
operator,
reference_occupations,
norb=norb,
nelec=nelec,
)
np.testing.assert_allclose(energy, energy_alt)
np.testing.assert_allclose(energy, -108.519714)


Expand Down

0 comments on commit ec4fc6c

Please sign in to comment.