Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Proto] Initial Phaser port #5

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions examples/analog/gradients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# This example shows how to calculate gradients
from __future__ import annotations

import jax
import jax.numpy as jnp
from jax import random
from phaser import simulate
from phaser.models import RydbergHamiltonian
from phaser.utils import init_state

key = random.PRNGKey(42)

# Initializing Hamiltonian
n_qubits = 15
dt, N = 1e-3, 3000
laser_params = (1.0, 2.0)
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2)
in_state = init_state(n_qubits)


def laser(laser_params, t):
(w_rabi, w_detune) = laser_params
return {
"rabi": 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t),
"detune": 15.0 * jnp.cos(2 * jnp.pi * w_detune * t),
}


hamiltonian = RydbergHamiltonian(n_qubits, U)
hamiltonian_params = hamiltonian.init(
key,
in_state,
laser(laser_params, 0),
)


# We take the gradient of some random state w.r.t the laser params and interaction_matrix
def forward(laser_params, hamiltonian_params):
out_state = simulate(
hamiltonian,
hamiltonian_params,
laser,
laser_params,
N,
dt,
in_state,
)
return (jnp.abs(out_state) ** 2).flatten()[-1]


# Getting the gradient fn w.r.t. both the pulse and interaction matrix and printing the grads
# Note that we jit (compile) the function so the timing here includes compiling
# but this only needs to happen once
grad_fn = jax.jit(jax.grad(forward, argnums=[0, 1]))
laser_grads, interaction_grads = grad_fn(laser_params, hamiltonian_params)

print(f"Gradients w.r.t laser params: \n {laser_grads}")
print(f"Gradients w.r.t interaction matrix: \n {interaction_grads}")
88 changes: 88 additions & 0 deletions examples/analog/introduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# This example shows how to build a model hamiltonian and simulate it.
from __future__ import annotations

from time import time

import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from chex import Array
from jax import random
from phaser.hamiltonians import Interaction, Number, Pauli_x
from phaser.propagators import second_order_trotter
from phaser.simulate import simulate
from phaser.utils import init_state, kron_sum

key = random.PRNGKey(42)


class RydbergHamiltonian(nn.Module):
n_qubits: int
U: Array

def setup(self):
# Rabi terms
H_rabi = [Pauli_x((idx,), None) for idx in np.arange(self.n_qubits)]

# Detuning terms
H_detune = [Number((idx,), None) for idx in np.arange(self.n_qubits)]

# Interaction term
# We don't want to learn U here so it's just a matrix
self.U_params = self.U[np.triu_indices_from(self.U, k=1)]
H_interact = [Interaction(idx, None) for idx in zip(*np.triu_indices_from(self.U, k=1))]

# Joining all terms
self.H = H_rabi + H_detune + H_interact

def __call__(self, state, weights):
weights = jnp.concatenate([weights["rabi"] / 2, -weights["detune"], self.U_params])
return kron_sum(self.H, state, weights)

def evolve(self, state: Array, dt: float, weights: dict):
# Getting weights into same shape
weights = jnp.concatenate([weights["rabi"] / 2, -weights["detune"], self.U_params])
return second_order_trotter(self.H, state, dt, weights)


# Initializing Hamiltonian
n_qubits = 15
dt, N = 1e-3, 3000
laser_params = (1.0, 2.0)
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2)
in_state = init_state(n_qubits)


# We call it laser here but it's just a function which takes in 1) some parameters and 2) the time of the simulation
# and returns the parameter values of the hamiltonian. So it's really just a way to simulate time dependent hamiltonians.
def laser(laser_params, t):
(w_rabi, w_detune) = laser_params
return {
"rabi": jnp.full((n_qubits,), 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t)),
"detune": jnp.full((n_qubits,), 15.0 * jnp.cos(2 * jnp.pi * w_detune * t)),
}


hamiltonian = RydbergHamiltonian(n_qubits, U)
hamiltonian_params = hamiltonian.init(
key,
in_state,
laser(laser_params, 0),
)


# Timing
start = time()
_ = simulate(
hamiltonian,
hamiltonian_params,
laser,
laser_params,
N,
dt,
in_state,
).block_until_ready()
stop = time()

print(f"Simulation time for {n_qubits} qubits and {N} steps: {stop - start}s")
print("Note that for clarity we didn't jit the final function, so compilation time is included.")
121 changes: 121 additions & 0 deletions examples/analog/making_efficient_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# This shows how to build an efficient model using diagonalization
from __future__ import annotations

from time import time

import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from chex import Array
from jax import random
from phaser.diagonal import diagonal_onebody_hamiltonian, diagonal_twobody_hamiltonian
from phaser.hamiltonians import HamiltonianTerm, Pauli_x, n
from phaser.propagators import second_order_trotter
from phaser.simulate import simulate
from phaser.utils import init_state, kron_sum

key = random.PRNGKey(42)


# Defining diagonal detuning
def diagonal_detune_H(idx, weights):
return diagonal_onebody_hamiltonian(n, weights, idx)


def diagonal_detune_expm(idx, weights):
return jnp.exp(-1j * diagonal_detune_H(idx, weights))


DiagonalDetune = HamiltonianTerm.create(diagonal_detune_H, diagonal_detune_expm)


# Interaction
def diagonal_interaction_H(idx, weights):
return diagonal_twobody_hamiltonian((n, n), weights, idx)


def diagonal_interaction_expm(idx, weights):
return jnp.exp(-1j * diagonal_interaction_H(idx, weights))


DiagonalInteraction = HamiltonianTerm.create(diagonal_interaction_H, diagonal_interaction_expm)


def generate_interaction(U):
U_params = jnp.stack(U[np.triu_indices_from(U, k=1)])
idx = tuple(zip(*np.triu_indices_from(U, k=1)))

return DiagonalInteraction(idx, lambda key: U_params)


class DiagonalRydbergHamiltonian(nn.Module):
n_qubits: int
U: Array

def setup(self):
# Rabi terms
H_rabi = [Pauli_x((idx,), None) for idx in range(self.n_qubits)]

# Detuning
H_detune = DiagonalDetune(range(self.n_qubits), None)

# Interaction term
H_interact = generate_interaction(self.U)

# Joining all terms
self.H = [*H_rabi, H_detune, H_interact]

def __call__(self, state, weights):
return kron_sum(self.H, state, self.parse_weights(weights))

def evolve(self, state: Array, dt: float, weights: dict):
return second_order_trotter(self.H, state, dt, self.parse_weights(weights))

def parse_weights(self, weights):
# Parse the weights from tuple to correct shape and values
return [
*jnp.full((self.n_qubits,), weights["rabi"] / 2),
jnp.full((self.n_qubits,), -weights["detune"]),
None,
]


if __name__ == "__main__":
# Initializing Hamiltonian
n_qubits = 20
dt, N = 1e-3, 3000
laser_params = (1.0, 2.0)
U = jnp.triu(random.normal(key, (n_qubits, n_qubits)) ** 2)
in_state = init_state(n_qubits)

def laser(laser_params, t):
(w_rabi, w_detune) = laser_params
return {
"rabi": 20.0 * jnp.cos(2 * jnp.pi * w_rabi * t),
"detune": 15.0 * jnp.cos(2 * jnp.pi * w_detune * t),
}

hamiltonian = DiagonalRydbergHamiltonian(n_qubits, U)
hamiltonian_params = hamiltonian.init(
key,
in_state,
laser(laser_params, 0),
)

# Timing
start = time()
_ = simulate(
hamiltonian,
hamiltonian_params,
laser,
laser_params,
N,
dt,
in_state,
).block_until_ready()
stop = time()

print(f"Simulation time for {n_qubits} qubits and {N} steps: {stop - start}s")
print(
"Note that for clarity we didn't jit the final function, so compilation time is included."
)
3 changes: 3 additions & 0 deletions horqrux/phaser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import annotations

from .simulate import simulate
44 changes: 44 additions & 0 deletions horqrux/phaser/diagonal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from functools import reduce
from itertools import chain

import jax.numpy as jnp
from chex import Array

from .utils import diagonal_kronecker, kron_AI, kron_IA


def diagonal_onebody_hamiltonian(Hi: Array, weights: Array, idx: list[int]) -> Array:
# Generates diagonal of diagonal onebody hamiltonian terms.
# Not pretty but it works...
def diagonal_Hi(diagonal: Array, idx: int) -> Array:
return kron_IA(kron_AI(diagonal, 2 ** (n_qubits - idx - 1)), 2**idx)

n_qubits = max(idx) + 1 # +1 cause of index
Hi_diag = jnp.diag(Hi)
return reduce(
lambda state, x: state + x[0] * diagonal_Hi(Hi_diag, x[1]),
zip(weights, idx),
jnp.zeros(2**n_qubits),
)


def diagonal_twobody_hamiltonian(
HiHj: tuple[Array, Array], weights: Array, idx: list[tuple[int, int]]
) -> Array:
# Generates diagonal of diagonal two-body hamiltonian terms.
# Not pretty but it works...
def diagonal_Hi(diagonal: list[Array], idx_ij: tuple[int, int]) -> Array:
idx_i, idx_j = idx_ij
left = kron_IA(diagonal[0], 2 ** (idx_i))
right = kron_IA(kron_AI(diagonal[1], 2 ** (n_qubits - idx_j - 1)), 2 ** (idx_j - idx_i - 1))
return diagonal_kronecker(left, right)

n_qubits = max(list(chain(*idx))) + 1 # +1 cause of index
HiHj_diag = [jnp.diag(H) for H in HiHj]
return reduce(
lambda state, x: state + x[0] * diagonal_Hi(HiHj_diag, x[1]),
zip(weights, idx),
jnp.zeros(2**n_qubits),
)
Loading
Loading