Skip to content

Commit

Permalink
Add non-negative least squares solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Dec 9, 2024
1 parent 3d8c391 commit ed83527
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 1 deletion.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ disable=R,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
invalid-name,


[REPORTS]
Expand Down
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Linear Algebra Operators
.. autosummary::
matrix_inverse_pth_root
power_iteration
nnls

Matrix inverse pth root
~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -59,6 +60,10 @@ Power iteration
~~~~~~~~~~~~~~~
.. autofunction:: power_iteration

Non-negative least squares
~~~~~~~~~~~~~~~
.. autofunction:: nnls


Second Order Optimization
-------------------------
Expand Down
2 changes: 2 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from optax._src.linear_algebra import global_norm
from optax._src.linear_algebra import matrix_inverse_pth_root
from optax._src.linear_algebra import power_iteration
from optax._src.linear_algebra import nnls
from optax._src.linesearch import scale_by_backtracking_linesearch
from optax._src.linesearch import scale_by_zoom_linesearch
from optax._src.linesearch import ScaleByBacktrackingLinesearchState
Expand Down Expand Up @@ -381,6 +382,7 @@
"MultiTransformState",
"nadam",
"nadamw",
"nnls",
"noisy_sgd",
"novograd",
"NonNegativeParamsState",
Expand Down
2 changes: 1 addition & 1 deletion optax/_src/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,7 +2482,7 @@ def lbfgs(
... )
... params = optax.apply_updates(params, updates)
... print('Objective function: ', f(params))
Objective function: 7.5166864
Objective function: 7.516686...
Objective function: 7.460699e-14
Objective function: 2.6505726e-28
Objective function: 0.0
Expand Down
97 changes: 97 additions & 0 deletions optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,100 @@ def _iter_body(state):
resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
return resultant_mat_h, error


def masked_argmax(x, mask):
y = jnp.where(mask, x, -jnp.inf)
assert isinstance(y, jax.Array)
return jnp.argmax(y)


def nnls(A: jax.Array, b: jax.Array, maxiter: int, tol: float) -> jax.Array:
r"""Solves the non-negative least squares problem.
Minimizes :math:`\|A x - b\|_2` subject to :math:`x \geq 0`.
Args:
A: A matrix.
b: A vector.
maxiter: The maximum number of iterations to run the algorithm for.
tol: The numerical tolerance to be used by the algorithm.
Returns:
The solution vector.
Examples:
>>> from jax import numpy as jnp
>>> import optax
>>> A = jnp.array([[1, 2], [3, 4]])
>>> b = jnp.array([5, 6])
>>> x = optax.nnls(A, b, 1000, 1e-3)
>>> print(x[0])
0.0
>>> print(x[1])
1.7
References:
Lawson and Hanson, `Solving Least Squares Problems
<https://doi.org/10.1137/1.9781611971217>`_, 1995
Bro and de Jong, `A fast non-negativity-constrained least squares algorithm
<https://analyticalsciencejournals.onlinelibrary.wiley.com/doi/10.1002/
(SICI)1099-128X(199709/10)11:5%3C393::AID-CEM483%3E3.0.CO;2-L>`_, 1999
"""
def out_cond_fn(carry):
x, p, w, it = carry
del x, it
return ~p.all() & (w > tol).any(where=~p)

def out_body_fn(carry):

def in_cond_fn(carry):
x, p, s, it = carry
del x
return (it < maxiter) & (s <= 0).any(where=p)

def in_body_fn(carry):
x, p, s, it = carry
it += 1
alpha = (x / (x - s)).min(where=p & (s <= 0), initial=jnp.inf)
x += alpha * (s - x)

p = jnp.where(x <= 0, False, p)
assert isinstance(p, jax.Array)

# s = jnp.linalg.lstsq(A * p, b)[0]
s = jnp.linalg.lstsq(AtA * p[:, None] * p[None, :], Atb * p)[0]

return x, p, s, it

x, p, w, it = carry

j = masked_argmax(w, ~p)
p = p.at[j].set(True)

# s = jnp.linalg.lstsq(A * p, b)[0]
s = jnp.linalg.lstsq(AtA * p[:, None] * p[None, :], Atb * p)[0]

x, p, s, it = lax.while_loop(in_cond_fn, in_body_fn, (x, p, s, it))

x = s
w = Atb - AtA @ x

return x, p, w, it

_, n = A.shape

if n == 0:
return jnp.zeros(0)

Atb = b @ A
AtA = A.T @ A

x = jnp.zeros(n)
p = jnp.zeros(n, bool)
w = Atb.astype(float)
it = 0

x, p, w, it = lax.while_loop(out_cond_fn, out_body_fn, (x, p, w, it))

return x
27 changes: 27 additions & 0 deletions optax/_src/linear_algebra_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,33 @@ def _gen_symmetrix_matrix(dim, condition_number):
# No guarantee of success after e >= 7
pass

@parameterized.parameters(
{'n': n, 'd': d}
for n in range(5)
for d in range(5)
for _ in range(5)
)
def test_nnls(self, n, d, atol=1e-4, tol=1e-5, maxiter=10**4):
"""Test non-negative least squares solver."""
A = np.random.normal(size=(n, d))
b = np.random.normal(size=(n,))

x = linear_algebra.nnls(A, b, maxiter=maxiter, tol=tol)

with self.subTest('x is non-negative'):
assert jnp.allclose(x.clip(max=0), 0, atol=atol)

xr, _ = scipy.optimize.nnls(A, b, maxiter=maxiter, atol=tol)

with self.subTest('xr is non-negative'):
assert jnp.allclose(xr.clip(max=0), 0, atol=atol)

d = jnp.square(A @ x - b).sum()
dr = jnp.square(A @ xr - b).sum()

with self.subTest('x is optimal'):
np.testing.assert_allclose(d, dr, atol=atol)


if __name__ == '__main__':
absltest.main()

0 comments on commit ed83527

Please sign in to comment.