Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add non-negative least squares solver. #1155

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Linear Algebra Operators
.. autosummary::
matrix_inverse_pth_root
power_iteration
nnls

Matrix inverse pth root
~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -59,6 +60,10 @@ Power iteration
~~~~~~~~~~~~~~~
.. autofunction:: power_iteration

Non-negative least squares
~~~~~~~~~~~~~~~
.. autofunction:: nnls


Second Order Optimization
-------------------------
Expand Down
2 changes: 2 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from optax._src.linear_algebra import global_norm
from optax._src.linear_algebra import matrix_inverse_pth_root
from optax._src.linear_algebra import power_iteration
from optax._src.linear_algebra import nnls
from optax._src.linesearch import scale_by_backtracking_linesearch
from optax._src.linesearch import scale_by_zoom_linesearch
from optax._src.linesearch import ScaleByBacktrackingLinesearchState
Expand Down Expand Up @@ -378,6 +379,7 @@
"MultiTransformState",
"nadam",
"nadamw",
"nnls",
"noisy_sgd",
"novograd",
"NonNegativeParamsState",
Expand Down
95 changes: 95 additions & 0 deletions optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,98 @@ def _iter_body(state):
resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
return resultant_mat_h, error


def masked_argmax(x, mask):
y = jnp.where(mask, x, -jnp.inf)
assert isinstance(y, jax.Array)
return jnp.argmax(y)


def nnls(A: jax.Array, b: jax.Array, maxiter: int, tol: float) -> jax.Array:
r"""Solves the non-negative least squares problem.

Minimizes :math:`\|A x - b\|_2` subject to :math:`x \geq 0`.

Args:
A: A matrix.
b: A vector.
maxiter: The maximum number of iterations to run the algorithm for.
tol: The numerical tolerance to be used by the algorithm.

Returns:
The solution vector.

Examples:
>>> from jax import numpy as jnp
>>> import optax
>>> A = jnp.array([[1, 2], [3, 4]])
>>> b = jnp.array([5, 6])
>>> x = optax.nnls(A, b, 1000, 1e-3)
>>> print(f"{x[0]:.2f}")
0.00
>>> print(f"{x[1]:.2f}")
1.70

References:
Lawson and Hanson, `Solving Least Squares Problems
<https://doi.org/10.1137/1.9781611971217>`_, 1995
Bro and de Jong, `A fast non-negativity-constrained least squares algorithm
<https://analyticalsciencejournals.onlinelibrary.wiley.com/doi/10.1002/
(SICI)1099-128X(199709/10)11:5%3C393::AID-CEM483%3E3.0.CO;2-L>`_, 1999
"""
def out_cond_fn(carry):
x, p, w, it = carry
del x, it
return ~p.all() & (w > tol).any(where=~p)

def out_body_fn(carry):

def in_cond_fn(carry):
x, p, s, it = carry
del x
return (it < maxiter) & (s <= 0).any(where=p)

def in_body_fn(carry):
x, p, s, it = carry
it += 1
alpha = (x / (x - s)).min(where=p & (s <= 0), initial=jnp.inf)
x += alpha * (s - x)

p = jnp.where(x <= 0, False, p)
assert isinstance(p, jax.Array)

s = jnp.linalg.lstsq(A * p, b)[0]

return x, p, s, it

x, p, w, it = carry

j = masked_argmax(w, ~p)
p = p.at[j].set(True)

s = jnp.linalg.lstsq(A * p, b)[0]

x, p, s, it = lax.while_loop(in_cond_fn, in_body_fn, (x, p, s, it))

x = s
w = Atb - AtA @ x

return x, p, w, it

_, n = A.shape

if n == 0:
return jnp.zeros(0)

Atb = b @ A
AtA = A.T @ A

x = jnp.zeros(n)
p = jnp.zeros(n, bool)
w = Atb.astype(float)
it = 0

x, p, w, it = lax.while_loop(out_cond_fn, out_body_fn, (x, p, w, it))

return x
27 changes: 27 additions & 0 deletions optax/_src/linear_algebra_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,33 @@ def _gen_symmetrix_matrix(dim, condition_number):
# No guarantee of success after e >= 7
pass

@parameterized.parameters(
{'n': n, 'd': d}
for n in range(5)
for d in range(5)
for _ in range(5)
)
def test_nnls(self, n, d, atol=1e-4, tol=1e-5, maxiter=10**4):
"""Test non-negative least squares solver."""
A = np.random.normal(size=(n, d))
b = np.random.normal(size=(n,))

x = linear_algebra.nnls(A, b, maxiter=maxiter, tol=tol)

with self.subTest('x is non-negative'):
assert jnp.allclose(x.clip(max=0), 0, atol=atol)

xr, _ = scipy.optimize.nnls(A, b, maxiter=maxiter, atol=tol)

with self.subTest('xr is non-negative'):
assert jnp.allclose(xr.clip(max=0), 0, atol=atol)

d = jnp.square(A @ x - b).sum()
dr = jnp.square(A @ xr - b).sum()

with self.subTest('x is optimal'):
np.testing.assert_allclose(d, dr.clip(max=d), atol=atol)


if __name__ == '__main__':
absltest.main()
Loading