From 7548885239ebf1f93cf30027dcbb91ccdb86d1da Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Thu, 5 Dec 2024 15:26:01 +0100 Subject: [PATCH] use a default_dtype function --- horqrux/_misc.py | 9 +++++++++ horqrux/matrices.py | 25 ++++++++++++++----------- horqrux/parametric.py | 7 +++++-- horqrux/utils.py | 16 ++++++++++------ 4 files changed, 38 insertions(+), 19 deletions(-) create mode 100644 horqrux/_misc.py diff --git a/horqrux/_misc.py b/horqrux/_misc.py new file mode 100644 index 0000000..894c3c3 --- /dev/null +++ b/horqrux/_misc.py @@ -0,0 +1,9 @@ +import jax +import jax.numpy as jnp +from jax._src.typing import DType + +def default_complex_dtype() -> DType: + if jax.config.jax_enable_x64: + return jnp.complex128 + else: + return jnp.complex64 \ No newline at end of file diff --git a/horqrux/matrices.py b/horqrux/matrices.py index 7079e7b..6e50393 100644 --- a/horqrux/matrices.py +++ b/horqrux/matrices.py @@ -3,17 +3,20 @@ import jax.numpy as jnp from jax import config +from ._misc import default_complex_dtype + config.update("jax_enable_x64", True) # Quantum ML requires higher precision +default_dtype = default_complex_dtype() -_X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex128) -_Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex128) -_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex128) -_H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex128) * 1 / jnp.sqrt(2) -_S = jnp.array([[1, 0], [0, 1j]], dtype=jnp.complex128) -_T = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex128) -_I = jnp.asarray([[1, 0], [0, 1]], dtype=jnp.complex128) +_X = jnp.array([[0, 1], [1, 0]], dtype=default_dtype) +_Y = jnp.array([[0, -1j], [1j, 0]], dtype=default_dtype) +_Z = jnp.array([[1, 0], [0, -1]], dtype=default_dtype) +_H = jnp.array([[1, 1], [1, -1]], dtype=default_dtype) * 1 / jnp.sqrt(2) +_S = jnp.array([[1, 0], [0, 1j]], dtype=default_dtype) +_T = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=default_dtype) +_I = jnp.asarray([[1, 0], [0, 1]], dtype=default_dtype) -_SWAP = jnp.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=jnp.complex128) +_SWAP = jnp.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=default_dtype) _SQSWAP = jnp.asarray( @@ -23,11 +26,11 @@ [0, 0.5 * (1 - 1j), 0.5 * (1 + 1j), 0], [0, 0, 0, 1], ], - dtype=jnp.complex128, + dtype=default_dtype, ) _ISWAP = jnp.asarray( - [[1, 0, 0, 0], [0, 0, 1j, 0], [0, 1j, 0, 0], [0, 0, 0, 1]], dtype=jnp.complex128 + [[1, 0, 0, 0], [0, 0, 1j, 0], [0, 1j, 0, 0], [0, 0, 0, 1]], dtype=default_dtype ) _ISQSWAP = jnp.asarray( @@ -37,7 +40,7 @@ [0, 1j / jnp.sqrt(2), 1 / jnp.sqrt(2), 0], [0, 0, 0, 1], ], - dtype=jnp.complex128, + dtype=default_dtype, ) OPERATIONS_DICT = { diff --git a/horqrux/parametric.py b/horqrux/parametric.py index bd5d488..49320d8 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -7,6 +7,7 @@ from jax import Array from jax.tree_util import register_pytree_node_class +from ._misc import default_complex_dtype from .matrices import OPERATIONS_DICT from .primitive import Primitive from .utils import ( @@ -18,6 +19,8 @@ is_controlled, ) +default_dtype = default_complex_dtype() + @register_pytree_node_class @dataclass @@ -118,12 +121,12 @@ def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None, class _PHASE(Parametric): def unitary(self, values: dict[str, float] = dict()) -> Array: - u = jnp.eye(2, 2, dtype=jnp.complex128) + u = jnp.eye(2, 2, dtype=default_dtype) u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values))) return u def jacobian(self, values: dict[str, float] = dict()) -> Array: - jac = jnp.zeros((2, 2), dtype=jnp.complex128) + jac = jnp.zeros((2, 2), dtype=default_dtype) jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * self.parse_values(values))) return jac diff --git a/horqrux/utils.py b/horqrux/utils.py index 2641530..6587cdd 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -10,6 +10,10 @@ from jax.typing import ArrayLike from numpy import log2 +from ._misc import default_complex_dtype + +default_dtype = default_complex_dtype() + State = ArrayLike QubitSupport = Tuple[Any, ...] ControlQubits = Tuple[Union[None, Tuple[int, ...]], ...] @@ -51,7 +55,7 @@ def _dagger(operator: Array) -> Array: def _unitary(generator: Array, theta: float) -> Array: return ( - jnp.cos(theta / 2) * jnp.eye(2, dtype=jnp.complex128) - 1j * jnp.sin(theta / 2) * generator + jnp.cos(theta / 2) * jnp.eye(2, dtype=default_dtype) - 1j * jnp.sin(theta / 2) * generator ) @@ -59,14 +63,14 @@ def _jacobian(generator: Array, theta: float) -> Array: return ( -1 / 2 - * (jnp.sin(theta / 2) * jnp.eye(2, dtype=jnp.complex128) + 1j * jnp.cos(theta / 2)) + * (jnp.sin(theta / 2) * jnp.eye(2, dtype=default_dtype) + 1j * jnp.cos(theta / 2)) * generator ) def _controlled(operator: Array, n_control: int) -> Array: n_qubits = int(log2(operator.shape[0])) - control = jnp.eye(2 ** (n_control + n_qubits), dtype=jnp.complex128) + control = jnp.eye(2 ** (n_control + n_qubits), dtype=default_dtype) control = control.at[-(2**n_qubits) :, -(2**n_qubits) :].set(operator) return control @@ -81,7 +85,7 @@ def product_state(bitstring: str) -> Array: A state corresponding to 'bitstring'. """ n_qubits = len(bitstring) - space = jnp.zeros(tuple(2 for _ in range(n_qubits)), dtype=jnp.complex128) + space = jnp.zeros(tuple(2 for _ in range(n_qubits)), dtype=default_dtype) space = space.at[tuple(map(int, bitstring))].set(1.0) return space @@ -140,8 +144,8 @@ def overlap(state: Array, projection: Array) -> Array: def uniform_state( n_qubits: int, ) -> Array: - state = jnp.ones(2**n_qubits, dtype=jnp.complex128) - state = state / jnp.sqrt(jnp.array(2**n_qubits, dtype=jnp.complex128)) + state = jnp.ones(2**n_qubits, dtype=default_dtype) + state = state / jnp.sqrt(jnp.array(2**n_qubits, dtype=default_dtype)) return state.reshape([2] * n_qubits)