diff --git a/python/ffsim/variational/ucj_spin_balanced.py b/python/ffsim/variational/ucj_spin_balanced.py index ea2547db9..4bd320181 100644 --- a/python/ffsim/variational/ucj_spin_balanced.py +++ b/python/ffsim/variational/ucj_spin_balanced.py @@ -436,22 +436,34 @@ def _apply_unitary_( return NotImplemented if copy: vec = vec.copy() + current_basis = np.eye(norb) for (diag_coulomb_mat_aa, diag_coulomb_mat_ab), orbital_rotation in zip( self.diag_coulomb_mats, self.orbital_rotations ): + vec = gates.apply_orbital_rotation( + vec, + orbital_rotation.T.conj() @ current_basis, + norb=norb, + nelec=nelec, + copy=False, + ) vec = gates.apply_diag_coulomb_evolution( vec, (diag_coulomb_mat_aa, diag_coulomb_mat_ab, diag_coulomb_mat_aa), time=-1.0, norb=norb, nelec=nelec, - orbital_rotation=orbital_rotation, copy=False, ) - if self.final_orbital_rotation is not None: + current_basis = orbital_rotation + if self.final_orbital_rotation is None: + vec = gates.apply_orbital_rotation( + vec, current_basis, norb=norb, nelec=nelec, copy=False + ) + else: vec = gates.apply_orbital_rotation( vec, - mat=self.final_orbital_rotation, + self.final_orbital_rotation @ current_basis, norb=norb, nelec=nelec, copy=False, diff --git a/python/ffsim/variational/ucj_spin_unbalanced.py b/python/ffsim/variational/ucj_spin_unbalanced.py index 54df27179..898dcfd74 100644 --- a/python/ffsim/variational/ucj_spin_unbalanced.py +++ b/python/ffsim/variational/ucj_spin_unbalanced.py @@ -565,22 +565,35 @@ def _apply_unitary_( return NotImplemented if copy: vec = vec.copy() + eye = np.eye(norb) + current_basis = np.stack([eye, eye]) for diag_coulomb_mat, orbital_rotation in zip( self.diag_coulomb_mats, self.orbital_rotations ): + vec = gates.apply_orbital_rotation( + vec, + orbital_rotation.transpose(0, 2, 1).conj() @ current_basis, + norb=norb, + nelec=nelec, + copy=False, + ) vec = gates.apply_diag_coulomb_evolution( vec, diag_coulomb_mat, time=-1.0, norb=norb, nelec=nelec, - orbital_rotation=orbital_rotation, copy=False, ) - if self.final_orbital_rotation is not None: + current_basis = orbital_rotation + if self.final_orbital_rotation is None: + vec = gates.apply_orbital_rotation( + vec, current_basis, norb=norb, nelec=nelec, copy=False + ) + else: vec = gates.apply_orbital_rotation( vec, - mat=self.final_orbital_rotation, + self.final_orbital_rotation @ current_basis, norb=norb, nelec=nelec, copy=False, diff --git a/python/ffsim/variational/ucj_spinless.py b/python/ffsim/variational/ucj_spinless.py index 76b2e59db..d1926453f 100644 --- a/python/ffsim/variational/ucj_spinless.py +++ b/python/ffsim/variational/ucj_spinless.py @@ -363,37 +363,56 @@ def _apply_unitary_( ) -> np.ndarray: if copy: vec = vec.copy() + current_basis = np.eye(norb) if isinstance(nelec, int): for diag_coulomb_mat, orbital_rotation in zip( self.diag_coulomb_mats, self.orbital_rotations ): + vec = gates.apply_orbital_rotation( + vec, + orbital_rotation.T.conj() @ current_basis, + norb=norb, + nelec=nelec, + copy=False, + ) vec = gates.apply_diag_coulomb_evolution( vec, diag_coulomb_mat, time=-1.0, norb=norb, nelec=nelec, - orbital_rotation=orbital_rotation, copy=False, ) + current_basis = orbital_rotation else: zero = np.zeros((norb, norb)) for diag_coulomb_mat, orbital_rotation in zip( self.diag_coulomb_mats, self.orbital_rotations ): + vec = gates.apply_orbital_rotation( + vec, + orbital_rotation.T.conj() @ current_basis, + norb=norb, + nelec=nelec, + copy=False, + ) vec = gates.apply_diag_coulomb_evolution( vec, (diag_coulomb_mat, zero, diag_coulomb_mat), time=-1.0, norb=norb, nelec=nelec, - orbital_rotation=orbital_rotation, copy=False, ) - if self.final_orbital_rotation is not None: + current_basis = orbital_rotation + if self.final_orbital_rotation is None: + vec = gates.apply_orbital_rotation( + vec, current_basis, norb=norb, nelec=nelec, copy=False + ) + else: vec = gates.apply_orbital_rotation( vec, - mat=self.final_orbital_rotation, + self.final_orbital_rotation @ current_basis, norb=norb, nelec=nelec, copy=False, diff --git a/tests/python/variational/ucj_spinless_test.py b/tests/python/variational/ucj_spinless_test.py index 18d9c8bbe..1bb04e27f 100644 --- a/tests/python/variational/ucj_spinless_test.py +++ b/tests/python/variational/ucj_spinless_test.py @@ -135,6 +135,14 @@ def test_t_amplitudes_energy(): norb=norb, nelec=nelec, ) + energy_alt, _ = ffsim.multireference_state( + mol_hamiltonian, + operator, + reference_occupations, + norb=norb, + nelec=nelec, + ) + np.testing.assert_allclose(energy, energy_alt) np.testing.assert_allclose(energy, -108.519714)