Skip to content

Commit

Permalink
[Feature] Add features to extend expressions functionalities (#36)
Browse files Browse the repository at this point in the history
This PR aimed to accomplish the following

- [x] (**NO LONGER VALID**) extend arithmetic operations between
`Expression` and numeric values, such as int, float, complex, to also
external packages such as `numpy` and `torch` (currently only native
python numeric types supported)
- [x] remove `withrepr` decorator. It caused issues for not supporting
kwargs, such as the case of `NativeDrive` with its `instruction_name`
kwarg
- [x] include `__matmul__` (`@`) operation as a kron operation on
`Expression`

---------

Signed-off-by: Doomsk <[email protected]>
Co-authored-by: RolandMacDoland <[email protected]>
Co-authored-by: Kaonan Micadei <[email protected]>
Co-authored-by: Pim Venderbosch <[email protected]>
  • Loading branch information
4 people authored Nov 26, 2024
1 parent 286da76 commit 3fb112c
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 75 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ dependencies = [
"isort",
"ruff",
"pydocstringformatter",
"numpy",
"torch",
]

[tool.hatch.envs.default.scripts]
Expand Down
2 changes: 2 additions & 0 deletions qadence2_expressions/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from .expression import Expression
from .support import Support
from .utils import Numeric

__all__ = [
"add_grid_options",
Expand Down Expand Up @@ -62,4 +63,5 @@
"unitary_hermitian_operator",
"value",
"variable",
"Numeric",
]
49 changes: 9 additions & 40 deletions qadence2_expressions/core/constructors.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,18 @@
from __future__ import annotations

from functools import wraps
from typing import Any, Callable

from .environment import Environment
from .expression import Expression
from .support import Support
from .utils import Numeric


def with_repr(repr_func: Callable) -> Callable:
"""Decorator to give a dynamic __repr__ to a function based on its arguments."""

class CallableWithRepr:
def __init__(
self,
func: Callable,
*args: Any,
**kwargs: Any,
) -> None:
self.func: Callable = func
self.args = args
self.kwargs = kwargs

def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Call the original function with provided arguments
return self.func(*args, **kwargs)

def __repr__(self) -> str:
# Generate a custom repr using the provided repr_func
return str(repr_func(self.func, *self.args, **self.kwargs))

@wraps(repr_func)
def decorator(func: Callable) -> Callable:
# Return a wrapped CallableWithRepr instance
return lambda *args, **kwargs: CallableWithRepr(func(*args, **kwargs), *args, **kwargs)

return decorator


def value(x: complex | float | int) -> Expression:
def value(x: Numeric) -> Expression:
"""Create a numerical expression from the value `x`.
Args:
x (complex | float | int): Any numerical value.
x (Numeric): Any numerical value.
Returns:
Expression: An expression of type value.
Expand All @@ -52,18 +22,20 @@ def value(x: complex | float | int) -> Expression:
"""
if not isinstance(x, (complex, float, int)):
raise TypeError(
"Input to 'value' constructor must be of type 'complex', 'float' or 'int'. "
"Input to 'value' constructor must be of type numeric, e.g.:'complex',"
" 'float', 'int', 'torch.Tensor', 'numpy.ndarray', etc. "
f"Got {type(x)}."
)

return Expression.value(x)


def promote(x: Expression | complex | float | int) -> Expression:
def promote(x: Expression | Numeric) -> Expression:
"""Type cast inputs as value type expressions.
Args:
x (Expression | complex | float | int): A valid expression or numerical value. Numerical values are converted into `Value(x)` expressions.
x (Expression | Numeric): A valid expression or numerical value.
Numerical values are converted into `Value(x)` expressions.
Returns:
Expression: A value type or expression.
Expand Down Expand Up @@ -162,9 +134,8 @@ def function(name: str, *args: Any) -> Expression:
return Expression.function(name, *args)


@with_repr(lambda func, name: f"HermitianOperator(name='{name}')")
def unitary_hermitian_operator(name: str) -> Callable:
"""An unitary Hermitian operator.
"""A unitary Hermitian operator.
A Hermitian operator is a function that takes a list of indices
(or a target and control tuples) and return an Expression with
Expand Down Expand Up @@ -201,7 +172,6 @@ def core(
return core


@with_repr(lambda func, base, index: f"Projector(base='{base}', index='{index}')")
def projector(base: str, index: str) -> Callable:
"""A projector operator.
Expand Down Expand Up @@ -243,7 +213,6 @@ def core(
return core


@with_repr(lambda func, name: f"ParametricOperator(name='{name}')")
def parametric_operator(
name: str, *args: Any, join: Callable | None = None, **attributes: Any
) -> Callable:
Expand Down
63 changes: 37 additions & 26 deletions qadence2_expressions/core/expression.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import warnings
from enum import Enum
from functools import cached_property, reduce
from re import sub
from typing import Any

from .support import Support
from .utils import Numeric


class Expression:
Expand Down Expand Up @@ -41,7 +43,7 @@ def __init__(self, head: Expression.Tag, *args: Any, **attributes: Any) -> None:

# Constructors
@classmethod
def value(cls, x: complex | float | int) -> Expression:
def value(cls, x: Numeric) -> Expression:
"""Promote a numerical value (complex, float, int) to an expression.
Args:
Expand Down Expand Up @@ -246,7 +248,7 @@ def subspace(self) -> Support | None:
support: Support = self[1]
return support

# Collecting only non-null term's subpaces.
# Collecting only non-null term's subspaces.
subspaces = []
for arg in self.args:
sp = arg.subspace
Expand Down Expand Up @@ -285,7 +287,7 @@ def max_index(self) -> int:
if self.is_quantum_operator:
return self.subspace.max_index # type: ignore

# Retrun the maximum index among all the terms.
# Return the maximum index among all the terms.
return max(map(lambda arg: arg.max_index, self.args)) # type: ignore

# Helper functions.
Expand Down Expand Up @@ -377,11 +379,11 @@ def __eq__(self, other: object) -> bool:

# Algebraic operations
def __add__(self, other: object) -> Expression:
if not isinstance(other, Expression | complex | float | int):
if not isinstance(other, Expression | Numeric):
return NotImplemented

# Promote numerial values to Expression.
if isinstance(other, complex | float | int):
if isinstance(other, Numeric):
return self + Expression.value(other)

# Addition identity: a + 0 = 0 + a = a
Expand All @@ -405,24 +407,24 @@ def __add__(self, other: object) -> Expression:
args = (self, other)

# ⚠️ Warning: Ideally, this step should not perform the evaluation of the
# the expression. However, we want to provide a friendly interaction to
# the users, and the inacessibility of Python's evaluation (without writing
# our on REPL) forces to add the evaluation at this point.
# expression. However, we want to provide a friendly interaction to the users,
# and the inaccessibility of Python's evaluation (without writing our own REPL)
# forces to add the evaluation at this point.
return evaluate_addition(Expression.add(*args))

def __radd__(self, other: object) -> Expression:
# Promote numerical types to expression.
if isinstance(other, complex | float | int):
if isinstance(other, Numeric):
return Expression.value(other) + self

return NotImplemented

def __mul__(self, other: object) -> Expression:
if not isinstance(other, Expression | complex | float | int):
if not isinstance(other, Expression | Numeric):
return NotImplemented

# Promote numerical values to Expression.
if isinstance(other, complex | float | int):
if isinstance(other, Numeric):
return self * Expression.value(other)

# Null multiplication shortcut.
Expand Down Expand Up @@ -456,25 +458,25 @@ def __mul__(self, other: object) -> Expression:
args = (self, other)

# ⚠️ Warning: Ideally, this step should not perform the evaluation of the
# the expression. However, we want to provide a friendly intercation to
# the users, and the inacessibility of Python's evaluation (without writing
# our on REPL) forces to add the evaluation at this point.
# expression. However, we want to provide a friendly interaction to the users,
# and the inaccessibility of Python's evaluation (without writing our own REPL)
# forces to add the evaluation at this point.
return evaluate_multiplication(Expression.mul(*args))

def __rmul__(self, other: object) -> Expression:
# Promote numerical types to expression.
if isinstance(other, complex | float | int):
if isinstance(other, Numeric):
return Expression.value(other) * self

return NotImplemented

def __pow__(self, other: object) -> Expression:
"""Power involving quantum operators always promote expression to quantum operators."""

if not isinstance(other, Expression | complex | float | int):
if not isinstance(other, Expression | Numeric):
return NotImplemented

if isinstance(other, complex | float | int):
if isinstance(other, Numeric):
return self ** Expression.value(other)

# Numerical values are computed right away.
Expand All @@ -489,6 +491,7 @@ def __pow__(self, other: object) -> Expression:
if other.is_one:
return self

# Power of power is a simple operation and can be evaluated here.
if (
self.is_quantum_operator
and self.get("is_hermitian")
Expand All @@ -510,7 +513,7 @@ def __pow__(self, other: object) -> Expression:

def __rpow__(self, other: object) -> Expression:
# Promote numerical types to expression.
if isinstance(other, complex | float | int):
if isinstance(other, Numeric):
return Expression.value(other) ** self

return NotImplemented
Expand All @@ -519,28 +522,28 @@ def __neg__(self) -> Expression:
return -1 * self

def __sub__(self, other: object) -> Expression:
if not isinstance(other, Expression | complex | float | int):
if not isinstance(other, Expression | Numeric):
return NotImplemented

return self + (-other)

def __rsub__(self, other: object) -> Expression:
if not isinstance(other, Expression | complex | float | int):
if not isinstance(other, Expression | Numeric):
return NotImplemented

return (-self) + other

def __truediv__(self, other: object) -> Expression:
if not isinstance(other, Expression | complex | float | int):
if not isinstance(other, Expression | Numeric):
return NotImplemented

return self * (other**-1)

def __rtruediv__(self, other: object) -> Expression:
if not isinstance(other, complex | float | int):
if not isinstance(other, Numeric):
return NotImplemented

return other * (self**-1)
return other * (self**-1) # type: ignore

def __kron__(self, other: object) -> Expression:
if not isinstance(other, Expression):
Expand Down Expand Up @@ -569,11 +572,19 @@ def __kron__(self, other: object) -> Expression:
return self

# ⚠️ Warning: Ideally, this step should not perform the evaluation of the
# the expression. However, we want to provide a friendly intercation to
# the users, and the inacessibility of Python's evaluation (without writing
# our on REPL) forces to add the evaluation at this point.
# expression. However, we want to provide a friendly interaction to the users,
# and the inaccessibility of Python's evaluation (without writing our own REPL)
# forces to add the evaluation at this point.
return evaluate_kron(Expression.kron(self, other))

def __matmul__(self, other: object) -> Expression:
warnings.warn(
"The `@` (`__matmul__`) operator will be deprecated. Use `*` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.__kron__(other)


def evaluate_addition(expr: Expression) -> Expression:
if not expr.is_addition:
Expand Down
6 changes: 6 additions & 0 deletions qadence2_expressions/core/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from __future__ import annotations

from typing import Union


Numeric = Union[complex | float | int]
11 changes: 6 additions & 5 deletions qadence2_expressions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,29 @@
Expression,
function,
promote,
Numeric,
)


def sin(x: Expression | complex | float | int) -> Expression:
def sin(x: Expression | Numeric) -> Expression:
return function("sin", promote(x))


def cos(x: Expression | complex | float | int) -> Expression:
def cos(x: Expression | Numeric) -> Expression:
return function("cos", promote(x))


# Exponential function as power.
def exp(x: Expression | complex | float | int) -> Expression:
def exp(x: Expression | Numeric) -> Expression:
return Expression.symbol("E") ** promote(x)


def log(x: Expression | complex | float | int) -> Expression:
def log(x: Expression | Numeric) -> Expression:
expr = function("log", promote(x))
# Logarithms of operators are also operators and need to be arranged as such.
return expr.as_quantum_operator()


# Using square root as power makes symbolic simplifications easier.
def sqrt(x: Expression | complex | float | int) -> Expression:
def sqrt(x: Expression | Numeric) -> Expression:
return promote(x) ** 0.5
Empty file removed tests/__init__.py
Empty file.
17 changes: 13 additions & 4 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pytest

from qadence2_expressions import (
Expression,
Support,
Expand Down Expand Up @@ -90,19 +92,26 @@ def test_division() -> None:
def test_kron() -> None:
X = unitary_hermitian_operator("X")
term = Expression.kron(X(1), X(2), X(4))
expected = Expression.kron(X(1), X(2), X(3), X(4))

# Push term from the right.
assert term.__kron__(X(3)) == Expression.kron(X(1), X(2), X(3), X(4))
assert term.__kron__(X(3)) == expected
assert term @ X(3) == expected

# Push term from the left.
assert X(3).__kron__(term) == Expression.kron(X(1), X(2), X(3), X(4))
assert X(3).__kron__(term) == expected
assert X(3) @ term == expected

# Join `kron` expressions.
term1 = Expression.kron(X(1), X(4))
term2 = Expression.kron(X(2), X(3))

assert term1.__kron__(term2) == Expression.kron(X(1), X(2), X(3), X(4))
assert term2.__kron__(term1) == Expression.kron(X(1), X(2), X(3), X(4))
assert term1.__kron__(term2) == expected
assert term2.__kron__(term1) == expected
assert term1 @ term2 == expected
assert term2 @ term1 == expected
with pytest.deprecated_call():
term1 @ term2


def test_commutativity() -> None:
Expand Down

0 comments on commit 3fb112c

Please sign in to comment.