diff --git a/examples/analog/gradients.py b/examples/analog/gradients.py new file mode 100644 index 0000000..c549799 --- /dev/null +++ b/examples/analog/gradients.py @@ -0,0 +1,58 @@ +# This example shows how to calculate gradients +from __future__ import annotations + +import jax +import jax.numpy as jnp +from jax import random +from phaser import simulate +from phaser.models import RydbergHamiltonian +from phaser.utils import init_state + +key = random.PRNGKey(42) + +# Initializing Hamiltonian +n_qubits = 15 +dt, N = 1e-3, 3000 +laser_params = (1.0, 2.0) +U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2) +in_state = init_state(n_qubits) + + +def laser(laser_params, t): + (w_rabi, w_detune) = laser_params + return { + "rabi": 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t), + "detune": 15.0 * jnp.cos(2 * jnp.pi * w_detune * t), + } + + +hamiltonian = RydbergHamiltonian(n_qubits, U) +hamiltonian_params = hamiltonian.init( + key, + in_state, + laser(laser_params, 0), +) + + +# We take the gradient of some random state w.r.t the laser params and interaction_matrix +def forward(laser_params, hamiltonian_params): + out_state = simulate( + hamiltonian, + hamiltonian_params, + laser, + laser_params, + N, + dt, + in_state, + ) + return (jnp.abs(out_state) ** 2).flatten()[-1] + + +# Getting the gradient fn w.r.t. both the pulse and interaction matrix and printing the grads +# Note that we jit (compile) the function so the timing here includes compiling +# but this only needs to happen once +grad_fn = jax.jit(jax.grad(forward, argnums=[0, 1])) +laser_grads, interaction_grads = grad_fn(laser_params, hamiltonian_params) + +print(f"Gradients w.r.t laser params: \n {laser_grads}") +print(f"Gradients w.r.t interaction matrix: \n {interaction_grads}") diff --git a/examples/analog/introduction.py b/examples/analog/introduction.py new file mode 100644 index 0000000..2deaaf8 --- /dev/null +++ b/examples/analog/introduction.py @@ -0,0 +1,88 @@ +# This example shows how to build a model hamiltonian and simulate it. +from __future__ import annotations + +from time import time + +import flax.linen as nn +import jax.numpy as jnp +import numpy as np +from chex import Array +from jax import random +from phaser.hamiltonians import Interaction, Number, Pauli_x +from phaser.propagators import second_order_trotter +from phaser.simulate import simulate +from phaser.utils import init_state, kron_sum + +key = random.PRNGKey(42) + + +class RydbergHamiltonian(nn.Module): + n_qubits: int + U: Array + + def setup(self): + # Rabi terms + H_rabi = [Pauli_x((idx,), None) for idx in np.arange(self.n_qubits)] + + # Detuning terms + H_detune = [Number((idx,), None) for idx in np.arange(self.n_qubits)] + + # Interaction term + # We don't want to learn U here so it's just a matrix + self.U_params = self.U[np.triu_indices_from(self.U, k=1)] + H_interact = [Interaction(idx, None) for idx in zip(*np.triu_indices_from(self.U, k=1))] + + # Joining all terms + self.H = H_rabi + H_detune + H_interact + + def __call__(self, state, weights): + weights = jnp.concatenate([weights["rabi"] / 2, -weights["detune"], self.U_params]) + return kron_sum(self.H, state, weights) + + def evolve(self, state: Array, dt: float, weights: dict): + # Getting weights into same shape + weights = jnp.concatenate([weights["rabi"] / 2, -weights["detune"], self.U_params]) + return second_order_trotter(self.H, state, dt, weights) + + +# Initializing Hamiltonian +n_qubits = 15 +dt, N = 1e-3, 3000 +laser_params = (1.0, 2.0) +U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2) +in_state = init_state(n_qubits) + + +# We call it laser here but it's just a function which takes in 1) some parameters and 2) the time of the simulation +# and returns the parameter values of the hamiltonian. So it's really just a way to simulate time dependent hamiltonians. +def laser(laser_params, t): + (w_rabi, w_detune) = laser_params + return { + "rabi": jnp.full((n_qubits,), 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t)), + "detune": jnp.full((n_qubits,), 15.0 * jnp.cos(2 * jnp.pi * w_detune * t)), + } + + +hamiltonian = RydbergHamiltonian(n_qubits, U) +hamiltonian_params = hamiltonian.init( + key, + in_state, + laser(laser_params, 0), +) + + +# Timing +start = time() +_ = simulate( + hamiltonian, + hamiltonian_params, + laser, + laser_params, + N, + dt, + in_state, +).block_until_ready() +stop = time() + +print(f"Simulation time for {n_qubits} qubits and {N} steps: {stop - start}s") +print("Note that for clarity we didn't jit the final function, so compilation time is included.") diff --git a/examples/analog/making_efficient_models.py b/examples/analog/making_efficient_models.py new file mode 100644 index 0000000..fce35bf --- /dev/null +++ b/examples/analog/making_efficient_models.py @@ -0,0 +1,121 @@ +# This shows how to build an efficient model using diagonalization +from __future__ import annotations + +from time import time + +import flax.linen as nn +import jax.numpy as jnp +import numpy as np +from chex import Array +from jax import random +from phaser.diagonal import diagonal_onebody_hamiltonian, diagonal_twobody_hamiltonian +from phaser.hamiltonians import HamiltonianTerm, Pauli_x, n +from phaser.propagators import second_order_trotter +from phaser.simulate import simulate +from phaser.utils import init_state, kron_sum + +key = random.PRNGKey(42) + + +# Defining diagonal detuning +def diagonal_detune_H(idx, weights): + return diagonal_onebody_hamiltonian(n, weights, idx) + + +def diagonal_detune_expm(idx, weights): + return jnp.exp(-1j * diagonal_detune_H(idx, weights)) + + +DiagonalDetune = HamiltonianTerm.create(diagonal_detune_H, diagonal_detune_expm) + + +# Interaction +def diagonal_interaction_H(idx, weights): + return diagonal_twobody_hamiltonian((n, n), weights, idx) + + +def diagonal_interaction_expm(idx, weights): + return jnp.exp(-1j * diagonal_interaction_H(idx, weights)) + + +DiagonalInteraction = HamiltonianTerm.create(diagonal_interaction_H, diagonal_interaction_expm) + + +def generate_interaction(U): + U_params = jnp.stack(U[np.triu_indices_from(U, k=1)]) + idx = tuple(zip(*np.triu_indices_from(U, k=1))) + + return DiagonalInteraction(idx, lambda key: U_params) + + +class DiagonalRydbergHamiltonian(nn.Module): + n_qubits: int + U: Array + + def setup(self): + # Rabi terms + H_rabi = [Pauli_x((idx,), None) for idx in range(self.n_qubits)] + + # Detuning + H_detune = DiagonalDetune(range(self.n_qubits), None) + + # Interaction term + H_interact = generate_interaction(self.U) + + # Joining all terms + self.H = [*H_rabi, H_detune, H_interact] + + def __call__(self, state, weights): + return kron_sum(self.H, state, self.parse_weights(weights)) + + def evolve(self, state: Array, dt: float, weights: dict): + return second_order_trotter(self.H, state, dt, self.parse_weights(weights)) + + def parse_weights(self, weights): + # Parse the weights from tuple to correct shape and values + return [ + *jnp.full((self.n_qubits,), weights["rabi"] / 2), + jnp.full((self.n_qubits,), -weights["detune"]), + None, + ] + + +if __name__ == "__main__": + # Initializing Hamiltonian + n_qubits = 20 + dt, N = 1e-3, 3000 + laser_params = (1.0, 2.0) + U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2) + in_state = init_state(n_qubits) + + def laser(laser_params, t): + (w_rabi, w_detune) = laser_params + return { + "rabi": 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t), + "detune": 15.0 * jnp.cos(2 * jnp.pi * w_detune * t), + } + + hamiltonian = DiagonalRydbergHamiltonian(n_qubits, U) + hamiltonian_params = hamiltonian.init( + key, + in_state, + laser(laser_params, 0), + ) + + # Timing + start = time() + _ = simulate( + hamiltonian, + hamiltonian_params, + laser, + laser_params, + N, + dt, + in_state, + ).block_until_ready() + stop = time() + + print(f"Simulation time for {n_qubits} qubits and {N} steps: {stop - start}s") + print( + "Note that for clarity we didn't jit the final function, so compilation time is included." + ) diff --git a/horqrux/phaser/__init__.py b/horqrux/phaser/__init__.py new file mode 100644 index 0000000..19cdd61 --- /dev/null +++ b/horqrux/phaser/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from .simulate import simulate diff --git a/horqrux/phaser/diagonal.py b/horqrux/phaser/diagonal.py new file mode 100644 index 0000000..a944dd7 --- /dev/null +++ b/horqrux/phaser/diagonal.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from functools import reduce +from itertools import chain + +import jax.numpy as jnp +from chex import Array + +from .utils import diagonal_kronecker, kron_AI, kron_IA + + +def diagonal_onebody_hamiltonian(Hi: Array, weights: Array, idx: list[int]) -> Array: + # Generates diagonal of diagonal onebody hamiltonian terms. + # Not pretty but it works... + def diagonal_Hi(diagonal: Array, idx: int) -> Array: + return kron_IA(kron_AI(diagonal, 2 ** (n_qubits - idx - 1)), 2**idx) + + n_qubits = max(idx) + 1 # +1 cause of index + Hi_diag = jnp.diag(Hi) + return reduce( + lambda state, x: state + x[0] * diagonal_Hi(Hi_diag, x[1]), + zip(weights, idx), + jnp.zeros(2**n_qubits), + ) + + +def diagonal_twobody_hamiltonian( + HiHj: tuple[Array, Array], weights: Array, idx: list[tuple[int, int]] +) -> Array: + # Generates diagonal of diagonal two-body hamiltonian terms. + # Not pretty but it works... + def diagonal_Hi(diagonal: list[Array], idx_ij: tuple[int, int]) -> Array: + idx_i, idx_j = idx_ij + left = kron_IA(diagonal[0], 2 ** (idx_i)) + right = kron_IA(kron_AI(diagonal[1], 2 ** (n_qubits - idx_j - 1)), 2 ** (idx_j - idx_i - 1)) + return diagonal_kronecker(left, right) + + n_qubits = max(list(chain(*idx))) + 1 # +1 cause of index + HiHj_diag = [jnp.diag(H) for H in HiHj] + return reduce( + lambda state, x: state + x[0] * diagonal_Hi(HiHj_diag, x[1]), + zip(weights, idx), + jnp.zeros(2**n_qubits), + ) diff --git a/horqrux/phaser/hamiltonians.py b/horqrux/phaser/hamiltonians.py new file mode 100644 index 0000000..feeecc9 --- /dev/null +++ b/horqrux/phaser/hamiltonians.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from functools import partial +from typing import Callable, Optional + +import flax.linen as nn +import jax.numpy as jnp +from chex import Array, PRNGKey + +from .propagators import apply_unitary + + +class HamiltonianTerm(nn.Module): + idx: tuple[int] + weight_init_fn: Callable[[PRNGKey, tuple], Array] | None + expm: Callable + H: Callable + + def setup(self): + if self.weight_init_fn is not None: + self.weight = self.param("weight", self.weight_init_fn) + else: + self.weight = None + + def __call__(self, state, weight: Optional[Array] = None): + # returns H |psi> + if weight is None: + weight = self.weight + return apply_unitary(state, self.H(self.idx, weight), self.idx) + + def evolve(self, state, t: float, weight: Optional[Array] = None): + # return expm(-iHt)|psi> + if weight is None: + weight = self.weight + return apply_unitary(state, self.expm(self.idx, t * weight), self.idx) + + @classmethod + def create(cls, H_fn, expm_fn): + # Creates a specific hamiltonian. + return partial(cls, expm=expm_fn, H=H_fn) + + +# Useful matrices +sz = jnp.array([[1.0, 0.0], [0.0, -1.0]]) +sx = jnp.array([[0.0, 1.0], [1.0, 0.0]]) +n = (sz + jnp.eye(2)) / 2 + + +# Pauli z +def pauli_z_expm(idx, theta: float): + # Implements expm(-1j * theta * sz) + return jnp.cos(theta) * jnp.eye(2) - 1j * jnp.sin(theta) * sz + + +def pauli_z_H(idx, theta: float): + return theta * sz + + +Pauli_z = HamiltonianTerm.create(pauli_z_H, pauli_z_expm) + + +# Pauli_x +def pauli_x_expm(idx, theta: float): + # Implements expm(-1j * theta * sx) + return jnp.cos(theta) * jnp.eye(2) - 1j * jnp.sin(theta) * sx + + +def pauli_x_H(idx, theta: float): + return theta * sx + + +Pauli_x = HamiltonianTerm.create(pauli_x_H, pauli_x_expm) + + +# Number operator +def number_expm(idx, theta: float): + return jnp.diag(jnp.exp(-1j * theta * jnp.array([1.0, 0.0]))) + + +def number_H(idx, theta: float): + return theta * n + + +Number = HamiltonianTerm.create(number_H, number_expm) + + +# Interaction operator +def interaction_expm(idx, u_ij: float): + return jnp.diag(jnp.exp(-1j * u_ij * jnp.array([1.0, 0.0, 0.0, 0.0]))) + + +def interaction_H(idx, u_ij: float): + return u_ij * jnp.kron(n, n) + + +Interaction = HamiltonianTerm.create(interaction_H, interaction_expm) diff --git a/horqrux/phaser/models/__init__.py b/horqrux/phaser/models/__init__.py new file mode 100644 index 0000000..7e7fe32 --- /dev/null +++ b/horqrux/phaser/models/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from .rydberg import RydbergHamiltonian diff --git a/horqrux/phaser/models/rydberg.py b/horqrux/phaser/models/rydberg.py new file mode 100644 index 0000000..c08df21 --- /dev/null +++ b/horqrux/phaser/models/rydberg.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import flax.linen as nn +import jax.numpy as jnp +import numpy as np +from chex import Array +from jax import random +from horqrux.phaser.diagonal import diagonal_onebody_hamiltonian, diagonal_twobody_hamiltonian +from horqrux.phaser.hamiltonians import HamiltonianTerm, Pauli_x, n +from horqrux.phaser.propagators import second_order_trotter +from horqrux.phaser.utils import kron_sum + +key = random.PRNGKey(42) + + +# Defining diagonal detuning +def diagonal_detune_H(idx, weights): + return diagonal_onebody_hamiltonian(n, weights, idx) + + +def diagonal_detune_expm(idx, weights): + return jnp.exp(-1j * diagonal_detune_H(idx, weights)) + + +DiagonalDetune = HamiltonianTerm.create(diagonal_detune_H, diagonal_detune_expm) + + +# Interaction +def diagonal_interaction_H(idx, weights): + return diagonal_twobody_hamiltonian((n, n), weights, idx) + + +def diagonal_interaction_expm(idx, weights): + return jnp.exp(-1j * diagonal_interaction_H(idx, weights)) + + +DiagonalInteraction = HamiltonianTerm.create(diagonal_interaction_H, diagonal_interaction_expm) + + +def generate_interaction(U): + U_params = jnp.stack(U[np.triu_indices_from(U, k=1)]) + idx = tuple(zip(*np.triu_indices_from(U, k=1))) + + return DiagonalInteraction(idx, lambda key: U_params) + + +class RydbergHamiltonian(nn.Module): + n_qubits: int + U: Array + + def setup(self): + # Rabi terms + H_rabi = [Pauli_x((idx,), None) for idx in range(self.n_qubits)] + + # Detuning + H_detune = DiagonalDetune(range(self.n_qubits), None) + + # Interaction term + H_interact = generate_interaction(self.U) + + # Joining all terms + self.H = [*H_rabi, H_detune, H_interact] + + def __call__(self, state, weights): + return kron_sum(self.H, state, self.parse_weights(weights)) + + def evolve(self, state: Array, dt: float, weights: dict): + return second_order_trotter(self.H, state, dt, self.parse_weights(weights)) + + def parse_weights(self, weights): + # Parse the weights from tuple to correct shape and values + return [ + *jnp.full((self.n_qubits,), weights["rabi"] / 2), + jnp.full((self.n_qubits,), -weights["detune"]), + None, + ] diff --git a/horqrux/phaser/propagators.py b/horqrux/phaser/propagators.py new file mode 100644 index 0000000..4b6d5d0 --- /dev/null +++ b/horqrux/phaser/propagators.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from functools import reduce + +import jax.numpy as jnp +import numpy as np +from chex import Array + +from .utils import hilbert_reshape + + +def first_order_trotter(H: list, state, dt, weights=None): + # First order trotter + if weights is None: + weights = len(H) * [None] + + return reduce( + lambda state, x: x[0].evolve(state, dt, x[1]), + zip(H, weights), + state, + ) + + +def second_order_trotter(H: list, state, dt, weights=None): + # second order trotter + if weights is None: + weights = len(H) * [None] + + return reduce( + lambda state, x: x[0].evolve(state, dt / 2, x[1]), + zip([*H[::-1], *H], [*weights[::-1], *weights]), + state, + ) + + +def apply_unitary(state: Array, U: Array, target_idx: tuple) -> Array: + def _apply_diagonal_unitary(state: Array, U: Array, target_idx: tuple) -> Array: + return U.reshape(state.shape) * state + + def _apply_matrix_unitary(state: Array, U: Array, target_idx: tuple) -> Array: + if len(target_idx) > 1: + U = hilbert_reshape(U) + + # Move axis to front, operate, move back + state = jnp.moveaxis(state, target_idx, np.arange(len(target_idx))) + state = jnp.tensordot(U, state, axes=len(target_idx)) + return jnp.moveaxis(state, np.arange(len(target_idx)), target_idx) + + if U.ndim == 1: + return _apply_diagonal_unitary(state, U, target_idx) + elif U.ndim == 2: + return _apply_matrix_unitary(state, U, target_idx) + else: + raise NotImplementedError diff --git a/horqrux/phaser/simulate.py b/horqrux/phaser/simulate.py new file mode 100644 index 0000000..fd5bf0e --- /dev/null +++ b/horqrux/phaser/simulate.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from horqrux.phaser.solvers import forward_euler_solve + + +def simulate( + hamiltonian, hamiltonian_params, laser, laser_params, N, dt, in_state, **solver_kwargs +): + def propagate_fn(params, state, t, dt): + hamiltonian_params, laser_params = params + return hamiltonian.apply( + hamiltonian_params, + state, + dt, + laser(laser_params, t), + method=hamiltonian.evolve, + ) + + return forward_euler_solve( + in_state, propagate_fn, (hamiltonian_params, laser_params), N=N, dt=dt, **solver_kwargs + ) diff --git a/horqrux/phaser/solvers.py b/horqrux/phaser/solvers.py new file mode 100644 index 0000000..b3afcce --- /dev/null +++ b/horqrux/phaser/solvers.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from functools import partial +from typing import Any, Callable + +import jax +import jax.numpy as jnp +from chex import Array + + +@partial(jax.jit, static_argnames=("propagate_fn", "iterate_idx", "N")) +def forward_euler_solve( + state: Array, + propagate_fn: Callable, + params: Any, + N: int, + dt: float, + iterate_idx: bool = False, +) -> Array: + def update_fn(state: Array, t: float) -> Array: + return propagate_fn(params, state, t, dt), None + + # Iterate index gives step number instead of time + # useful for comparing to pulser. + + if iterate_idx is True: + t = jnp.arange(N) + else: + t = dt * jnp.arange(N) + + return jax.lax.scan(update_fn, state, t)[0] diff --git a/horqrux/phaser/test_utils.py b/horqrux/phaser/test_utils.py new file mode 100644 index 0000000..41c4c9e --- /dev/null +++ b/horqrux/phaser/test_utils.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import jax.numpy as jnp + + +def state_overlap(state_A, state_B): + return jnp.abs(jnp.dot(jnp.conjugate(state_A.flatten()), state_B.flatten())) + + +def state_norm(state): + return jnp.sum(jnp.abs(state) ** 2) diff --git a/horqrux/phaser/utils.py b/horqrux/phaser/utils.py new file mode 100644 index 0000000..2917836 --- /dev/null +++ b/horqrux/phaser/utils.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from functools import reduce +from typing import Any + +import jax.numpy as jnp +import numpy as np +from chex import Array + + +def hilbert_reshape(U: Array) -> Array: + # Reshapes O of shape [M, M] to array of shape [2, 2, ...]. Useful for working with controlled and multi-qubit gates. + n_axes = int(np.log2(U.size)) + return U.reshape(tuple(2 for _ in np.arange(n_axes))) + + +def kron_prod(*x: Array) -> Array: + # custom kronecker product which can multiply multiple matrices + # and filters out zeros because sometime we generate jnp.eye(0) + return reduce(jnp.kron, filter(lambda x: x.size != 0, x)) + + +def kron_sum(H: list, state: Array, weights: list = None) -> Array: + if weights is None: + weights = len(H) * [None] + # kronecker sum + return reduce( + lambda out_state, x: out_state + x[0](state, x[1]), + zip(H, weights), + jnp.zeros_like(state), + ) + + +def make_explicit(H: Array, params: list, n_qubits: int) -> Array: + def Hi(x: Any) -> Any: + term, param = x + idx = term.idx + _H = term.H(idx, param) + + # 1 body term + if len(idx) == 1: + (idx,) = idx + return kron_prod(jnp.eye(2**idx), _H, jnp.eye(2 ** (n_qubits - idx - 1))) + # 2 body term + elif len(idx) == 2: + idx_i, idx_j = idx + return kron_prod( + jnp.eye(2 ** (idx_i)), + _H[0], + jnp.eye(2 ** (idx_j - idx_i - 1)), + _H[1], + jnp.eye(2 ** (n_qubits - idx_j - 1)), + ) + + _H = jnp.zeros((2**n_qubits, 2**n_qubits)) + return reduce(lambda H_full, x: H_full + Hi(x), zip(H, params), _H) + + +def init_state(n_qubits: int) -> Array: + state = jnp.zeros(tuple(2 for _ in np.arange(n_qubits)), dtype=jnp.complex64) + state = state.at[tuple(-1 for _ in np.arange(n_qubits))].set(1.0) + return state + + +def diagonal_kronecker(A: Array, B: Array) -> Array: + """Given two diagonal A and B, calculates diagonal of kronecker product, + which is also diagonal.""" + return (A[:, None] * B[None, :]).reshape(A.size * B.size) + + +def kron_AI(A: Array, N: int) -> Array: + # Calculates A kron I, diagonal only. + return jnp.repeat(A, repeats=N, axis=0) + + +def kron_IA(A: Array, N: int) -> Array: + # Calculates I kron A, diagonal only. + return jnp.tile(A, reps=N) diff --git a/pyproject.toml b/pyproject.toml index 95681cc..40542d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ requires-python = ">=3.9,<3.12" license = {text = "Apache 2.0"} -version = "0.3.0" +version = "0.4.0" classifiers=[ "License :: Other/Proprietary License", @@ -35,7 +35,7 @@ dependencies = [ [project.optional-dependencies] -dev = ["black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff"] +dev = ["black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "pulser"] [tool.hatch.envs.tests] features = [ diff --git a/tests/analog/test_comparison_pulser.py b/tests/analog/test_comparison_pulser.py new file mode 100644 index 0000000..48b4c15 --- /dev/null +++ b/tests/analog/test_comparison_pulser.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import jax.numpy as jnp +import numpy as np +from jax import random +from pulser import Pulse, Register, Sequence +from pulser.devices import Chadoq2 +from pulser.waveforms import InterpolatedWaveform +from pulser_simulation.simulation import Simulation +from scipy.spatial.distance import cdist + +# from horqrux.phaser.models.rydberg import QUBOHamiltonian +from horqrux.phaser.simulate import simulate +from horqrux.phaser.test_utils import state_norm, state_overlap +from horqrux.phaser.utils import init_state + +key = random.PRNGKey(42) + + +def test_forward_euler_solver_constant_qubo(): + # Compare pulser to phaser for constant + + n_qubits = 5 + f_rabi = 4.5 + f_detune = 1.5 + + def pulser_result(): + # Setting up register + coords = random.normal(key, (n_qubits, 2)) + distances = cdist(coords, coords) + min_dist = np.min(distances[distances > 0]) + coords = 6 * coords / min_dist + reg = Register.from_coordinates(coords) + + param_seq = Sequence(reg, Chadoq2) + param_seq.declare_channel("ch0", "rydberg_global") + amplitudes = param_seq.declare_variable("amplitudes", size=2) + detunings = param_seq.declare_variable("detunings", size=2) + param_seq.add( + Pulse( + InterpolatedWaveform(100, amplitudes), + InterpolatedWaveform(100, detunings), + 0, + ), + "ch0", + ) + + # Detuning of constant 1.0, rabi of 2.0 + seq1 = param_seq.build(amplitudes=[f_rabi, f_rabi], detunings=[f_detune, f_detune]) + sim = Simulation(seq1) + res = sim.run() + + # Stuff we need to set up Phaser the same + dt = sim.evaluation_times[1] - sim.evaluation_times[0] + N = sim.evaluation_times.size + 1 + coords = np.stack(list(reg.qubits.values()), axis=0) + dists = np.maximum(cdist(coords, coords), 1.0) + C = Chadoq2.interaction_coeff + U_inter = np.triu(C / dists**6, k=1) + + return res.get_final_state().full(), (dt, N, U_inter) + + pulser_state, (dt, N, U) = pulser_result() + + # Phaser result + def phaser_result(dt, N, U): + # We define a laser function + # This one is constant but always a function + def laser(laser_params, t): + return { + "rabi": jnp.full((n_qubits,), f_rabi), + "detune": jnp.full((n_qubits,), f_detune), + } + + # Initializing Hamiltonian + in_state = init_state(n_qubits) + hamiltonian = QUBOHamiltonian(n_qubits, U) + hamiltonian_params = hamiltonian.init( + key, + in_state, + laser(None, 0.0), + ) + return simulate(hamiltonian, hamiltonian_params, laser, None, N, dt, in_state) + + phaser_state = phaser_result(dt, N, U) + assert jnp.allclose(state_norm(phaser_state), 1.0, atol=1e-4) + assert jnp.allclose(state_overlap(phaser_state, pulser_state), 1.0, atol=1e-2) + + +def test_forward_euler_time_varying_qubo(): + n_qubits = 5 + + def pulser_result(): + # Pulser result + coords = random.normal(key, (n_qubits, 2)) + distances = cdist(coords, coords) + min_dist = np.min(distances[distances > 0]) + coords = 6 * coords / min_dist + + reg = Register.from_coordinates(coords) + param_seq = Sequence(reg, Chadoq2) + param_seq.declare_channel("ch0", "rydberg_global") + amplitudes = param_seq.declare_variable("amplitudes", size=5) + detunings = param_seq.declare_variable("detunings", size=4) + param_seq.add( + Pulse( + InterpolatedWaveform(100, amplitudes), + InterpolatedWaveform(100, detunings), + 0, + ), + "ch0", + ) + + # Detuning of constant 1.0, rabi of 2.0 + seq1 = param_seq.build(amplitudes=[5, 15, 5, 10, 15], detunings=[-10, -20, -15, 0]) + sim = Simulation(seq1) + res = sim.run() + + # Stuff we need to set up Phaser the same + dt = sim.evaluation_times[1] - sim.evaluation_times[0] + N = sim.evaluation_times.size + 1 + coords = np.stack(list(reg.qubits.values()), axis=0) + dists = np.maximum(cdist(coords, coords), 1.0) + C = Chadoq2.interaction_coeff + U_inter = np.triu(C / dists**6, k=1) + f_rabi = jnp.array(sim.samples["Global"]["ground-rydberg"]["amp"]) + f_detune = jnp.array(sim.samples["Global"]["ground-rydberg"]["det"]) + + return res.get_final_state().full(), (dt, N, U_inter, f_rabi, f_detune) + + pulser_state, args = pulser_result() + + def phaser_result(dt, N, U, f_rabi, f_detune): + def laser(laser_params, t): + f_rabi, f_detune = laser_params + return { + "rabi": jnp.full((n_qubits,), f_rabi[t]), + "detune": jnp.full((n_qubits,), f_detune[t]), + } + + laser_params = (f_rabi, f_detune) + # Initializing Hamiltonian + in_state = init_state(n_qubits) + hamiltonian = QUBOHamiltonian(n_qubits, U) + hamiltonian_params = hamiltonian.init( + key, + in_state, + laser(laser_params, 0), + ) + return simulate( + hamiltonian, + hamiltonian_params, + laser, + laser_params, + N, + dt, + in_state, + iterate_idx=True, + ) + + phaser_state = phaser_result(*args) + assert jnp.allclose(state_norm(phaser_state), 1.0, atol=1e-4) + assert jnp.allclose(state_overlap(phaser_state, pulser_state), 1.0, atol=1e-2) + + +if __name__ == "__main__": + test_forward_euler_solver_constant_qubo() + test_forward_euler_time_varying_qubo() diff --git a/tests/analog/test_diagonal.py b/tests/analog/test_diagonal.py new file mode 100644 index 0000000..c6a647e --- /dev/null +++ b/tests/analog/test_diagonal.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import jax.numpy as jnp +import numpy as np +from jax import random +from pulser import Pulse, Register, Sequence +from pulser.devices import Chadoq2 +from pulser.waveforms import InterpolatedWaveform +from pulser_simulation.simulation import Simulation +from scipy.spatial.distance import cdist + +# from horqrux.phaser.models.rydberg import DiagonalQUBOHamiltonian +from horqrux.phaser.simulate import simulate +from horqrux.phaser.test_utils import state_norm, state_overlap +from horqrux.phaser.utils import init_state + +key = random.PRNGKey(42) + + +def test_forward_euler_solver_constant_qubo(): + # Compare pulser to phaser for constant + + n_qubits = 5 + f_rabi = 4.5 + f_detune = 1.5 + + def pulser_result(): + # Setting up register + coords = random.normal(key, (n_qubits, 2)) + distances = cdist(coords, coords) + min_dist = np.min(distances[distances > 0]) + coords = 6 * coords / min_dist + reg = Register.from_coordinates(coords) + + param_seq = Sequence(reg, Chadoq2) + param_seq.declare_channel("ch0", "rydberg_global") + amplitudes = param_seq.declare_variable("amplitudes", size=2) + detunings = param_seq.declare_variable("detunings", size=2) + param_seq.add( + Pulse( + InterpolatedWaveform(100, amplitudes), + InterpolatedWaveform(100, detunings), + 0, + ), + "ch0", + ) + + # Detuning of constant 1.0, rabi of 2.0 + seq1 = param_seq.build(amplitudes=[f_rabi, f_rabi], detunings=[f_detune, f_detune]) + sim = Simulation(seq1) + res = sim.run() + + # Stuff we need to set up Phaser the same + dt = sim.evaluation_times[1] - sim.evaluation_times[0] + N = sim.evaluation_times.size + 1 + coords = np.stack(list(reg.qubits.values()), axis=0) + dists = np.maximum(cdist(coords, coords), 1.0) + C = Chadoq2.interaction_coeff + U_inter = np.triu(C / dists**6, k=1) + + return res.get_final_state().full(), (dt, N, U_inter) + + pulser_state, (dt, N, U) = pulser_result() + + # Phaser result + def phaser_result(dt, N, U): + # We define a laser function + # This one is constant but always a function + def laser(laser_params, t): + return { + "rabi": jnp.full((n_qubits,), f_rabi), + "detune": jnp.full((n_qubits,), f_detune), + } + + # Initializing Hamiltonian + in_state = init_state(n_qubits) + hamiltonian = DiagonalQUBOHamiltonian(n_qubits, U) + hamiltonian_params = hamiltonian.init( + key, + in_state, + laser(None, 0.0), + ) + return simulate(hamiltonian, hamiltonian_params, laser, None, N, dt, in_state) + + phaser_state = phaser_result(dt, N, U) + assert jnp.allclose(state_norm(phaser_state), 1.0, atol=1e-4) + assert jnp.allclose(state_overlap(phaser_state, pulser_state), 1.0, atol=1e-2) + + +def test_forward_euler_time_varying_qubo(): + n_qubits = 5 + + def pulser_result(): + # Pulser result + coords = random.normal(key, (n_qubits, 2)) + distances = cdist(coords, coords) + min_dist = np.min(distances[distances > 0]) + coords = 6 * coords / min_dist + + reg = Register.from_coordinates(coords) + param_seq = Sequence(reg, Chadoq2) + param_seq.declare_channel("ch0", "rydberg_global") + amplitudes = param_seq.declare_variable("amplitudes", size=5) + detunings = param_seq.declare_variable("detunings", size=4) + param_seq.add( + Pulse( + InterpolatedWaveform(100, amplitudes), + InterpolatedWaveform(100, detunings), + 0, + ), + "ch0", + ) + + # Detuning of constant 1.0, rabi of 2.0 + seq1 = param_seq.build(amplitudes=[5, 15, 5, 10, 15], detunings=[-10, -20, -15, 0]) + sim = Simulation(seq1) + res = sim.run() + + # Stuff we need to set up Phaser the same + dt = sim.evaluation_times[1] - sim.evaluation_times[0] + N = sim.evaluation_times.size + 1 + coords = np.stack(list(reg.qubits.values()), axis=0) + dists = np.maximum(cdist(coords, coords), 1.0) + C = Chadoq2.interaction_coeff + U_inter = np.triu(C / dists**6, k=1) + f_rabi = jnp.array(sim.samples["Global"]["ground-rydberg"]["amp"]) + f_detune = jnp.array(sim.samples["Global"]["ground-rydberg"]["det"]) + + return res.get_final_state().full(), (dt, N, U_inter, f_rabi, f_detune) + + pulser_state, args = pulser_result() + + def phaser_result(dt, N, U, f_rabi, f_detune): + def laser(laser_params, t): + f_rabi, f_detune = laser_params + return { + "rabi": jnp.full((n_qubits,), f_rabi[t]), + "detune": jnp.full((n_qubits,), f_detune[t]), + } + + laser_params = (f_rabi, f_detune) + # Initializing Hamiltonian + in_state = init_state(n_qubits) + hamiltonian = DiagonalQUBOHamiltonian(n_qubits, U) + hamiltonian_params = hamiltonian.init( + key, + in_state, + laser(laser_params, 0), + ) + return simulate( + hamiltonian, + hamiltonian_params, + laser, + laser_params, + N, + dt, + in_state, + iterate_idx=True, + ) + + phaser_state = phaser_result(*args) + assert jnp.allclose(state_norm(phaser_state), 1.0, atol=1e-4) + assert jnp.allclose(state_overlap(phaser_state, pulser_state), 1.0, atol=1e-2) + + +if __name__ == "__main__": + test_forward_euler_solver_constant_qubo() + test_forward_euler_time_varying_qubo() diff --git a/tests/analog/test_diagonal_kronecker.py b/tests/analog/test_diagonal_kronecker.py new file mode 100644 index 0000000..d8af372 --- /dev/null +++ b/tests/analog/test_diagonal_kronecker.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import jax.numpy as jnp +import numpy as np +from jax import random + +from horqrux.phaser.diagonal import diagonal_onebody_hamiltonian, diagonal_twobody_hamiltonian +from horqrux.phaser.hamiltonians import Interaction, Number, n +from horqrux.phaser.utils import diagonal_kronecker, make_explicit + +key = random.PRNGKey(42) + + +def test_diagonal_kronecker(): + A = jnp.array([1, 3]) + B = jnp.array([2, 1]) + + assert jnp.allclose(diagonal_kronecker(A, B), jnp.diag(jnp.kron(jnp.diag(A), jnp.diag(B)))) + + +def test_onebody_hamiltonian(): + n_qubits = 6 + detune_weights = random.normal(key, (n_qubits,)) + H_detune = list(map(lambda x: Number((x,), None), range(n_qubits))) + H_detune = make_explicit(H_detune, detune_weights, n_qubits) + + H_detune_diagonal = diagonal_onebody_hamiltonian(n, detune_weights, np.arange(n_qubits)) + assert jnp.allclose(H_detune_diagonal, jnp.diag(H_detune)) + + +def test_twobody_hamiltonian(): + # This test currently fails as H_fn doesnt give back two H's. + n_qubits = 6 + + U = random.normal(key, (n_qubits, n_qubits)) + U = jnp.triu(U**2, k=1) + idx = list(zip(*np.triu_indices_from(U, k=1))) + H_interaction = list(map(lambda x: Interaction(x, None), idx)) + + weights = U[jnp.triu_indices_from(U, k=1)] + H_interaction = make_explicit(H_interaction, weights, n_qubits) + + H_interaction_diagonal = diagonal_twobody_hamiltonian([n, n], weights, idx) + + assert jnp.allclose(H_interaction_diagonal, jnp.diag(H_interaction)) + + +if __name__ == "__main__": + test_onebody_hamiltonian() diff --git a/tests/test_gates.py b/tests/digital/test_gates.py similarity index 100% rename from tests/test_gates.py rename to tests/digital/test_gates.py