From 3fb112cf684c8dab28fb4535e8cb1355189d4be6 Mon Sep 17 00:00:00 2001 From: Doomsk Date: Tue, 26 Nov 2024 15:15:52 +0100 Subject: [PATCH] [Feature] Add features to extend expressions functionalities (#36) This PR aimed to accomplish the following - [x] (**NO LONGER VALID**) extend arithmetic operations between `Expression` and numeric values, such as int, float, complex, to also external packages such as `numpy` and `torch` (currently only native python numeric types supported) - [x] remove `withrepr` decorator. It caused issues for not supporting kwargs, such as the case of `NativeDrive` with its `instruction_name` kwarg - [x] include `__matmul__` (`@`) operation as a kron operation on `Expression` --------- Signed-off-by: Doomsk Co-authored-by: RolandMacDoland <9250798+RolandMacDoland@users.noreply.github.com> Co-authored-by: Kaonan Micadei Co-authored-by: Pim Venderbosch --- pyproject.toml | 2 + qadence2_expressions/core/__init__.py | 2 + qadence2_expressions/core/constructors.py | 49 ++++-------------- qadence2_expressions/core/expression.py | 63 +++++++++++++---------- qadence2_expressions/core/utils.py | 6 +++ qadence2_expressions/functions.py | 11 ++-- tests/__init__.py | 0 tests/test_expression.py | 17 ++++-- 8 files changed, 75 insertions(+), 75 deletions(-) create mode 100644 qadence2_expressions/core/utils.py delete mode 100644 tests/__init__.py diff --git a/pyproject.toml b/pyproject.toml index cca1ad6..f8b65a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,8 @@ dependencies = [ "isort", "ruff", "pydocstringformatter", + "numpy", + "torch", ] [tool.hatch.envs.default.scripts] diff --git a/qadence2_expressions/core/__init__.py b/qadence2_expressions/core/__init__.py index 3fbb806..18df836 100644 --- a/qadence2_expressions/core/__init__.py +++ b/qadence2_expressions/core/__init__.py @@ -32,6 +32,7 @@ ) from .expression import Expression from .support import Support +from .utils import Numeric __all__ = [ "add_grid_options", @@ -62,4 +63,5 @@ "unitary_hermitian_operator", "value", "variable", + "Numeric", ] diff --git a/qadence2_expressions/core/constructors.py b/qadence2_expressions/core/constructors.py index a5770fd..bd6def0 100644 --- a/qadence2_expressions/core/constructors.py +++ b/qadence2_expressions/core/constructors.py @@ -1,48 +1,18 @@ from __future__ import annotations -from functools import wraps from typing import Any, Callable from .environment import Environment from .expression import Expression from .support import Support +from .utils import Numeric -def with_repr(repr_func: Callable) -> Callable: - """Decorator to give a dynamic __repr__ to a function based on its arguments.""" - - class CallableWithRepr: - def __init__( - self, - func: Callable, - *args: Any, - **kwargs: Any, - ) -> None: - self.func: Callable = func - self.args = args - self.kwargs = kwargs - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - # Call the original function with provided arguments - return self.func(*args, **kwargs) - - def __repr__(self) -> str: - # Generate a custom repr using the provided repr_func - return str(repr_func(self.func, *self.args, **self.kwargs)) - - @wraps(repr_func) - def decorator(func: Callable) -> Callable: - # Return a wrapped CallableWithRepr instance - return lambda *args, **kwargs: CallableWithRepr(func(*args, **kwargs), *args, **kwargs) - - return decorator - - -def value(x: complex | float | int) -> Expression: +def value(x: Numeric) -> Expression: """Create a numerical expression from the value `x`. Args: - x (complex | float | int): Any numerical value. + x (Numeric): Any numerical value. Returns: Expression: An expression of type value. @@ -52,18 +22,20 @@ def value(x: complex | float | int) -> Expression: """ if not isinstance(x, (complex, float, int)): raise TypeError( - "Input to 'value' constructor must be of type 'complex', 'float' or 'int'. " + "Input to 'value' constructor must be of type numeric, e.g.:'complex'," + " 'float', 'int', 'torch.Tensor', 'numpy.ndarray', etc. " f"Got {type(x)}." ) return Expression.value(x) -def promote(x: Expression | complex | float | int) -> Expression: +def promote(x: Expression | Numeric) -> Expression: """Type cast inputs as value type expressions. Args: - x (Expression | complex | float | int): A valid expression or numerical value. Numerical values are converted into `Value(x)` expressions. + x (Expression | Numeric): A valid expression or numerical value. + Numerical values are converted into `Value(x)` expressions. Returns: Expression: A value type or expression. @@ -162,9 +134,8 @@ def function(name: str, *args: Any) -> Expression: return Expression.function(name, *args) -@with_repr(lambda func, name: f"HermitianOperator(name='{name}')") def unitary_hermitian_operator(name: str) -> Callable: - """An unitary Hermitian operator. + """A unitary Hermitian operator. A Hermitian operator is a function that takes a list of indices (or a target and control tuples) and return an Expression with @@ -201,7 +172,6 @@ def core( return core -@with_repr(lambda func, base, index: f"Projector(base='{base}', index='{index}')") def projector(base: str, index: str) -> Callable: """A projector operator. @@ -243,7 +213,6 @@ def core( return core -@with_repr(lambda func, name: f"ParametricOperator(name='{name}')") def parametric_operator( name: str, *args: Any, join: Callable | None = None, **attributes: Any ) -> Callable: diff --git a/qadence2_expressions/core/expression.py b/qadence2_expressions/core/expression.py index fb9a6ed..1e3b600 100644 --- a/qadence2_expressions/core/expression.py +++ b/qadence2_expressions/core/expression.py @@ -1,11 +1,13 @@ from __future__ import annotations +import warnings from enum import Enum from functools import cached_property, reduce from re import sub from typing import Any from .support import Support +from .utils import Numeric class Expression: @@ -41,7 +43,7 @@ def __init__(self, head: Expression.Tag, *args: Any, **attributes: Any) -> None: # Constructors @classmethod - def value(cls, x: complex | float | int) -> Expression: + def value(cls, x: Numeric) -> Expression: """Promote a numerical value (complex, float, int) to an expression. Args: @@ -246,7 +248,7 @@ def subspace(self) -> Support | None: support: Support = self[1] return support - # Collecting only non-null term's subpaces. + # Collecting only non-null term's subspaces. subspaces = [] for arg in self.args: sp = arg.subspace @@ -285,7 +287,7 @@ def max_index(self) -> int: if self.is_quantum_operator: return self.subspace.max_index # type: ignore - # Retrun the maximum index among all the terms. + # Return the maximum index among all the terms. return max(map(lambda arg: arg.max_index, self.args)) # type: ignore # Helper functions. @@ -377,11 +379,11 @@ def __eq__(self, other: object) -> bool: # Algebraic operations def __add__(self, other: object) -> Expression: - if not isinstance(other, Expression | complex | float | int): + if not isinstance(other, Expression | Numeric): return NotImplemented # Promote numerial values to Expression. - if isinstance(other, complex | float | int): + if isinstance(other, Numeric): return self + Expression.value(other) # Addition identity: a + 0 = 0 + a = a @@ -405,24 +407,24 @@ def __add__(self, other: object) -> Expression: args = (self, other) # ⚠️ Warning: Ideally, this step should not perform the evaluation of the - # the expression. However, we want to provide a friendly interaction to - # the users, and the inacessibility of Python's evaluation (without writing - # our on REPL) forces to add the evaluation at this point. + # expression. However, we want to provide a friendly interaction to the users, + # and the inaccessibility of Python's evaluation (without writing our own REPL) + # forces to add the evaluation at this point. return evaluate_addition(Expression.add(*args)) def __radd__(self, other: object) -> Expression: # Promote numerical types to expression. - if isinstance(other, complex | float | int): + if isinstance(other, Numeric): return Expression.value(other) + self return NotImplemented def __mul__(self, other: object) -> Expression: - if not isinstance(other, Expression | complex | float | int): + if not isinstance(other, Expression | Numeric): return NotImplemented # Promote numerical values to Expression. - if isinstance(other, complex | float | int): + if isinstance(other, Numeric): return self * Expression.value(other) # Null multiplication shortcut. @@ -456,14 +458,14 @@ def __mul__(self, other: object) -> Expression: args = (self, other) # ⚠️ Warning: Ideally, this step should not perform the evaluation of the - # the expression. However, we want to provide a friendly intercation to - # the users, and the inacessibility of Python's evaluation (without writing - # our on REPL) forces to add the evaluation at this point. + # expression. However, we want to provide a friendly interaction to the users, + # and the inaccessibility of Python's evaluation (without writing our own REPL) + # forces to add the evaluation at this point. return evaluate_multiplication(Expression.mul(*args)) def __rmul__(self, other: object) -> Expression: # Promote numerical types to expression. - if isinstance(other, complex | float | int): + if isinstance(other, Numeric): return Expression.value(other) * self return NotImplemented @@ -471,10 +473,10 @@ def __rmul__(self, other: object) -> Expression: def __pow__(self, other: object) -> Expression: """Power involving quantum operators always promote expression to quantum operators.""" - if not isinstance(other, Expression | complex | float | int): + if not isinstance(other, Expression | Numeric): return NotImplemented - if isinstance(other, complex | float | int): + if isinstance(other, Numeric): return self ** Expression.value(other) # Numerical values are computed right away. @@ -489,6 +491,7 @@ def __pow__(self, other: object) -> Expression: if other.is_one: return self + # Power of power is a simple operation and can be evaluated here. if ( self.is_quantum_operator and self.get("is_hermitian") @@ -510,7 +513,7 @@ def __pow__(self, other: object) -> Expression: def __rpow__(self, other: object) -> Expression: # Promote numerical types to expression. - if isinstance(other, complex | float | int): + if isinstance(other, Numeric): return Expression.value(other) ** self return NotImplemented @@ -519,28 +522,28 @@ def __neg__(self) -> Expression: return -1 * self def __sub__(self, other: object) -> Expression: - if not isinstance(other, Expression | complex | float | int): + if not isinstance(other, Expression | Numeric): return NotImplemented return self + (-other) def __rsub__(self, other: object) -> Expression: - if not isinstance(other, Expression | complex | float | int): + if not isinstance(other, Expression | Numeric): return NotImplemented return (-self) + other def __truediv__(self, other: object) -> Expression: - if not isinstance(other, Expression | complex | float | int): + if not isinstance(other, Expression | Numeric): return NotImplemented return self * (other**-1) def __rtruediv__(self, other: object) -> Expression: - if not isinstance(other, complex | float | int): + if not isinstance(other, Numeric): return NotImplemented - return other * (self**-1) + return other * (self**-1) # type: ignore def __kron__(self, other: object) -> Expression: if not isinstance(other, Expression): @@ -569,11 +572,19 @@ def __kron__(self, other: object) -> Expression: return self # ⚠️ Warning: Ideally, this step should not perform the evaluation of the - # the expression. However, we want to provide a friendly intercation to - # the users, and the inacessibility of Python's evaluation (without writing - # our on REPL) forces to add the evaluation at this point. + # expression. However, we want to provide a friendly interaction to the users, + # and the inaccessibility of Python's evaluation (without writing our own REPL) + # forces to add the evaluation at this point. return evaluate_kron(Expression.kron(self, other)) + def __matmul__(self, other: object) -> Expression: + warnings.warn( + "The `@` (`__matmul__`) operator will be deprecated. Use `*` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.__kron__(other) + def evaluate_addition(expr: Expression) -> Expression: if not expr.is_addition: diff --git a/qadence2_expressions/core/utils.py b/qadence2_expressions/core/utils.py new file mode 100644 index 0000000..7a763ad --- /dev/null +++ b/qadence2_expressions/core/utils.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from typing import Union + + +Numeric = Union[complex | float | int] diff --git a/qadence2_expressions/functions.py b/qadence2_expressions/functions.py index 560a746..2aa44f0 100644 --- a/qadence2_expressions/functions.py +++ b/qadence2_expressions/functions.py @@ -4,28 +4,29 @@ Expression, function, promote, + Numeric, ) -def sin(x: Expression | complex | float | int) -> Expression: +def sin(x: Expression | Numeric) -> Expression: return function("sin", promote(x)) -def cos(x: Expression | complex | float | int) -> Expression: +def cos(x: Expression | Numeric) -> Expression: return function("cos", promote(x)) # Exponential function as power. -def exp(x: Expression | complex | float | int) -> Expression: +def exp(x: Expression | Numeric) -> Expression: return Expression.symbol("E") ** promote(x) -def log(x: Expression | complex | float | int) -> Expression: +def log(x: Expression | Numeric) -> Expression: expr = function("log", promote(x)) # Logarithms of operators are also operators and need to be arranged as such. return expr.as_quantum_operator() # Using square root as power makes symbolic simplifications easier. -def sqrt(x: Expression | complex | float | int) -> Expression: +def sqrt(x: Expression | Numeric) -> Expression: return promote(x) ** 0.5 diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_expression.py b/tests/test_expression.py index d95d102..94d713b 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from qadence2_expressions import ( Expression, Support, @@ -90,19 +92,26 @@ def test_division() -> None: def test_kron() -> None: X = unitary_hermitian_operator("X") term = Expression.kron(X(1), X(2), X(4)) + expected = Expression.kron(X(1), X(2), X(3), X(4)) # Push term from the right. - assert term.__kron__(X(3)) == Expression.kron(X(1), X(2), X(3), X(4)) + assert term.__kron__(X(3)) == expected + assert term @ X(3) == expected # Push term from the left. - assert X(3).__kron__(term) == Expression.kron(X(1), X(2), X(3), X(4)) + assert X(3).__kron__(term) == expected + assert X(3) @ term == expected # Join `kron` expressions. term1 = Expression.kron(X(1), X(4)) term2 = Expression.kron(X(2), X(3)) - assert term1.__kron__(term2) == Expression.kron(X(1), X(2), X(3), X(4)) - assert term2.__kron__(term1) == Expression.kron(X(1), X(2), X(3), X(4)) + assert term1.__kron__(term2) == expected + assert term2.__kron__(term1) == expected + assert term1 @ term2 == expected + assert term2 @ term1 == expected + with pytest.deprecated_call(): + term1 @ term2 def test_commutativity() -> None: