Skip to content

Commit

Permalink
use a default_dtype function
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Dec 5, 2024
1 parent e69435a commit 7548885
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 19 deletions.
9 changes: 9 additions & 0 deletions horqrux/_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import jax
import jax.numpy as jnp
from jax._src.typing import DType

def default_complex_dtype() -> DType:
if jax.config.jax_enable_x64:
return jnp.complex128
else:
return jnp.complex64
25 changes: 14 additions & 11 deletions horqrux/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
import jax.numpy as jnp
from jax import config

from ._misc import default_complex_dtype

config.update("jax_enable_x64", True) # Quantum ML requires higher precision
default_dtype = default_complex_dtype()

_X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex128)
_Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex128)
_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex128)
_H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex128) * 1 / jnp.sqrt(2)
_S = jnp.array([[1, 0], [0, 1j]], dtype=jnp.complex128)
_T = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex128)
_I = jnp.asarray([[1, 0], [0, 1]], dtype=jnp.complex128)
_X = jnp.array([[0, 1], [1, 0]], dtype=default_dtype)
_Y = jnp.array([[0, -1j], [1j, 0]], dtype=default_dtype)
_Z = jnp.array([[1, 0], [0, -1]], dtype=default_dtype)
_H = jnp.array([[1, 1], [1, -1]], dtype=default_dtype) * 1 / jnp.sqrt(2)
_S = jnp.array([[1, 0], [0, 1j]], dtype=default_dtype)
_T = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=default_dtype)
_I = jnp.asarray([[1, 0], [0, 1]], dtype=default_dtype)

_SWAP = jnp.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=jnp.complex128)
_SWAP = jnp.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=default_dtype)


_SQSWAP = jnp.asarray(
Expand All @@ -23,11 +26,11 @@
[0, 0.5 * (1 - 1j), 0.5 * (1 + 1j), 0],
[0, 0, 0, 1],
],
dtype=jnp.complex128,
dtype=default_dtype,
)

_ISWAP = jnp.asarray(
[[1, 0, 0, 0], [0, 0, 1j, 0], [0, 1j, 0, 0], [0, 0, 0, 1]], dtype=jnp.complex128
[[1, 0, 0, 0], [0, 0, 1j, 0], [0, 1j, 0, 0], [0, 0, 0, 1]], dtype=default_dtype
)

_ISQSWAP = jnp.asarray(
Expand All @@ -37,7 +40,7 @@
[0, 1j / jnp.sqrt(2), 1 / jnp.sqrt(2), 0],
[0, 0, 0, 1],
],
dtype=jnp.complex128,
dtype=default_dtype,
)

OPERATIONS_DICT = {
Expand Down
7 changes: 5 additions & 2 deletions horqrux/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jax import Array
from jax.tree_util import register_pytree_node_class

from ._misc import default_complex_dtype
from .matrices import OPERATIONS_DICT
from .primitive import Primitive
from .utils import (
Expand All @@ -18,6 +19,8 @@
is_controlled,
)

default_dtype = default_complex_dtype()


@register_pytree_node_class
@dataclass
Expand Down Expand Up @@ -118,12 +121,12 @@ def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None,

class _PHASE(Parametric):
def unitary(self, values: dict[str, float] = dict()) -> Array:
u = jnp.eye(2, 2, dtype=jnp.complex128)
u = jnp.eye(2, 2, dtype=default_dtype)
u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values)))
return u

def jacobian(self, values: dict[str, float] = dict()) -> Array:
jac = jnp.zeros((2, 2), dtype=jnp.complex128)
jac = jnp.zeros((2, 2), dtype=default_dtype)
jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * self.parse_values(values)))
return jac

Expand Down
16 changes: 10 additions & 6 deletions horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from jax.typing import ArrayLike
from numpy import log2

from ._misc import default_complex_dtype

default_dtype = default_complex_dtype()

State = ArrayLike
QubitSupport = Tuple[Any, ...]
ControlQubits = Tuple[Union[None, Tuple[int, ...]], ...]
Expand Down Expand Up @@ -51,22 +55,22 @@ def _dagger(operator: Array) -> Array:

def _unitary(generator: Array, theta: float) -> Array:
return (
jnp.cos(theta / 2) * jnp.eye(2, dtype=jnp.complex128) - 1j * jnp.sin(theta / 2) * generator
jnp.cos(theta / 2) * jnp.eye(2, dtype=default_dtype) - 1j * jnp.sin(theta / 2) * generator
)


def _jacobian(generator: Array, theta: float) -> Array:
return (
-1
/ 2
* (jnp.sin(theta / 2) * jnp.eye(2, dtype=jnp.complex128) + 1j * jnp.cos(theta / 2))
* (jnp.sin(theta / 2) * jnp.eye(2, dtype=default_dtype) + 1j * jnp.cos(theta / 2))
* generator
)


def _controlled(operator: Array, n_control: int) -> Array:
n_qubits = int(log2(operator.shape[0]))
control = jnp.eye(2 ** (n_control + n_qubits), dtype=jnp.complex128)
control = jnp.eye(2 ** (n_control + n_qubits), dtype=default_dtype)
control = control.at[-(2**n_qubits) :, -(2**n_qubits) :].set(operator)
return control

Expand All @@ -81,7 +85,7 @@ def product_state(bitstring: str) -> Array:
A state corresponding to 'bitstring'.
"""
n_qubits = len(bitstring)
space = jnp.zeros(tuple(2 for _ in range(n_qubits)), dtype=jnp.complex128)
space = jnp.zeros(tuple(2 for _ in range(n_qubits)), dtype=default_dtype)
space = space.at[tuple(map(int, bitstring))].set(1.0)
return space

Expand Down Expand Up @@ -140,8 +144,8 @@ def overlap(state: Array, projection: Array) -> Array:
def uniform_state(
n_qubits: int,
) -> Array:
state = jnp.ones(2**n_qubits, dtype=jnp.complex128)
state = state / jnp.sqrt(jnp.array(2**n_qubits, dtype=jnp.complex128))
state = jnp.ones(2**n_qubits, dtype=default_dtype)
state = state / jnp.sqrt(jnp.array(2**n_qubits, dtype=default_dtype))
return state.reshape([2] * n_qubits)


Expand Down

0 comments on commit 7548885

Please sign in to comment.