Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix axis use for sum #543

Merged
merged 3 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor:
# Find where all the bras and kets are so they can be conjugated appropriately
conjugates = [i not in self.wires.ket.indices for i in range(len(self.wires.indices))]
quad_basis = math.sum(
[quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axis=[0]
[quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axis=0
)
return quad_basis

Expand Down
6 changes: 3 additions & 3 deletions mrmustard/lab_dev/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def visualize_2d(
shape = [max(min_shape, d) for d in self.auto_shape()]
state = self.to_fock(tuple(shape))
state = state.dm()
dm = math.sum(state.ansatz.array, axis=[0])
dm = math.sum(state.ansatz.array, axis=0)

x, prob_x = quadrature_distribution(dm)
p, prob_p = quadrature_distribution(dm, np.pi / 2)
Expand Down Expand Up @@ -502,7 +502,7 @@ def visualize_3d(
shape = [max(min_shape, d) for d in self.auto_shape()]
state = self.to_fock(tuple(shape))
state = state.dm()
dm = math.sum(state.ansatz.array, axis=[0])
dm = math.sum(state.ansatz.array, axis=0)

xvec = np.linspace(*xbounds, resolution)
pvec = np.linspace(*pbounds, resolution)
Expand Down Expand Up @@ -576,7 +576,7 @@ def visualize_dm(
raise ValueError("DM visualization not available for multi-mode states.")
state = self.to_fock(cutoff)
state = state.dm()
dm = math.sum(state.ansatz.array, axis=[0])
dm = math.sum(state.ansatz.array, axis=0)

fig = go.Figure(
data=go.Heatmap(z=abs(dm), colorscale="viridis", name="abs(ρ)", showscale=False)
Expand Down
6 changes: 3 additions & 3 deletions mrmustard/math/backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,17 +1102,17 @@ def sqrtm(self, tensor: Tensor, dtype=None) -> Tensor:
The square root of ``x``"""
return self._apply("sqrtm", (tensor, dtype))

def sum(self, array: Tensor, axis: Sequence[int] = None):
def sum(self, array: Tensor, axis: int | Sequence[int] | None = None):
r"""The sum of array.

Args:
array: The array to take the sum of
axes (tuple): The axis/axes to sum over
axis (int | Sequence[int] | None): The axis/axes to sum over

Returns:
The sum of array
"""
if axis is not None:
if axis is not None and not isinstance(axis, int):
neg = [a for a in axis if a < 0]
pos = [a for a in axis if a >= 0]
axis = tuple(sorted(neg) + sorted(pos)[::-1])
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/math/backend_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def sort(self, array: np.ndarray, axis: int = -1) -> np.ndarray:
def sqrt(self, x: np.ndarray, dtype=None) -> np.ndarray:
return np.sqrt(self.cast(x, dtype))

def sum(self, array: np.ndarray, axis: Sequence[int] = None):
def sum(self, array: np.ndarray, axis: int | tuple[int] | None = None):
return np.sum(array, axis=axis)

@Autocast()
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/math/backend_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def sort(self, array: tf.Tensor, axis: int = -1) -> tf.Tensor:
def sqrt(self, x: tf.Tensor, dtype=None) -> tf.Tensor:
return tf.sqrt(self.cast(x, dtype))

def sum(self, array: tf.Tensor, axis: Sequence[int] = None):
def sum(self, array: tf.Tensor, axis: int | tuple[int] | None = None):
return tf.reduce_sum(array, axis)

@Autocast()
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/physics/ansatz/array_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def sum_batch(self) -> ArrayAnsatz:
Returns:
The collapsed ArrayAnsatz object.
"""
return ArrayAnsatz(math.expand_dims(math.sum(self.array, axis=[0]), 0), batched=True)
return ArrayAnsatz(math.expand_dims(math.sum(self.array, axis=0), 0), batched=True)

def to_dict(self) -> dict[str, ArrayLike]:
return {"array": self.data}
Expand Down
6 changes: 3 additions & 3 deletions mrmustard/physics/ansatz/polyexp_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,12 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz:
self.A[..., :dim_alpha, :dim_alpha] * zz, axis=[-1, -2]
) # sum((b_arg,1,n,n) * (b_abc,n,n), [-1,-2]) ~ (b_arg,b_abc)
b_part = math.sum(
self.b[..., :dim_alpha] * z[..., None, :], axis=[-1]
self.b[..., :dim_alpha] * z[..., None, :], axis=-1
) # sum((b_arg,1,n) * (b_abc,n), [-1]) ~ (b_arg,b_abc)

exp_sum = math.exp(1 / 2 * A_part + b_part) # (b_arg, b_abc)
if dim_beta == 0:
val = math.sum(exp_sum * self.c, axis=[-1]) # (b_arg)
val = math.sum(exp_sum * self.c, axis=-1) # (b_arg)
else:
b_poly = math.astensor(
math.einsum(
Expand All @@ -441,7 +441,7 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz:
poly * self.c,
axis=math.arange(2, 2 + dim_beta, dtype=math.int32).tolist(),
),
axis=[-1],
axis=-1,
) # (b_arg)
return val

Expand Down
2 changes: 1 addition & 1 deletion mrmustard/physics/gaussian_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def complex_gaussian_integral_1(
inv_M = math.inv(M)
c_post = c * math.reshape(
math.sqrt(math.cast((-1) ** m / det_M, "complex128"))
* math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axis=[-1])),
* math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axis=-1)),
c.shape[:1] + (1,) * (len(c.shape) - 1),
)
A_post = R - math.einsum("bij,bjk,blk->bil", D, inv_M, D)
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/physics/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor
f"Expected Fock shape of length {num_vars}, got length {len(shape)}"
) from e
arrays = self.ansatz.reduce(shape).array
array = math.sum(arrays, axis=[0])
array = math.sum(arrays, axis=0)
arrays = math.expand_dims(array, 0) if batched else array
return arrays

Expand Down
2 changes: 1 addition & 1 deletion tests/test_lab_dev/test_transformations/test_cft.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_wigner_function(self):

state = Ket.random([0]) >> Dgate([0], x=1.0, y=0.1)

dm = math.sum(state.to_fock(100).dm().ansatz.array, axis=[0])
dm = math.sum(state.to_fock(100).dm().ansatz.array, axis=0)
vec = np.linspace(-5, 5, 100)
wigner, _, _ = wigner_discretized(dm, vec, vec)

Expand Down
Loading