diff --git a/.pylintrc b/.pylintrc
index b26aeee4f..23aacadaa 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -129,6 +129,7 @@ disable=R,
         wrong-import-order,
         xrange-builtin,
         zip-builtin-not-iterating,
+        invalid-name,
 
 
 [REPORTS]
diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst
index f404113fe..f2023ded1 100644
--- a/docs/api/utilities.rst
+++ b/docs/api/utilities.rst
@@ -50,6 +50,7 @@ Linear Algebra Operators
 .. autosummary::
     matrix_inverse_pth_root
     power_iteration
+    nnls
 
 Matrix inverse pth root
 ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -59,6 +60,10 @@ Power iteration
 ~~~~~~~~~~~~~~~
 .. autofunction:: power_iteration
 
+Non-negative least squares
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: nnls
+
 
 Second Order Optimization
 -------------------------
diff --git a/optax/__init__.py b/optax/__init__.py
index 0840b1d5a..3ad05f78b 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -78,6 +78,7 @@
 from optax._src.linear_algebra import global_norm
 from optax._src.linear_algebra import matrix_inverse_pth_root
+from optax._src.linear_algebra import nnls
 from optax._src.linear_algebra import power_iteration
 from optax._src.linesearch import scale_by_backtracking_linesearch
 from optax._src.linesearch import scale_by_zoom_linesearch
 from optax._src.linesearch import ScaleByBacktrackingLinesearchState
@@ -381,6 +382,7 @@
     "MultiTransformState",
     "nadam",
     "nadamw",
+    "nnls",
     "noisy_sgd",
     "novograd",
     "NonNegativeParamsState",
diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index b2c2d1865..d251e2341 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -2482,7 +2482,7 @@ def lbfgs(
     ...   )
     ...   params = optax.apply_updates(params, updates)
     ...   print('Objective function: ', f(params))
-    Objective function:  7.5166864
+    Objective function:  7.516686...
     Objective function:  7.460699e-14
     Objective function:  2.6505726e-28
     Objective function:  0.0
diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py
index f05d5b2be..250f593f0 100644
--- a/optax/_src/linear_algebra.py
+++ b/optax/_src/linear_algebra.py
@@ -265,3 +265,121 @@ def _iter_body(state):
     resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
     resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
   return resultant_mat_h, error
+
+
+def masked_argmax(x: jax.Array, mask: jax.Array) -> jax.Array:
+  """Returns the index of the largest entry of ``x`` selected by ``mask``.
+
+  Args:
+    x: An array of values.
+    mask: A boolean array that is ``True`` for the entries of ``x`` that take
+      part in the search.
+
+  Returns:
+    The index, into the flattened ``x``, of the largest selected entry.
+  """
+  # Masked-out entries are replaced by -inf before taking the argmax.
+  y = jnp.where(mask, x, -jnp.inf)
+  assert isinstance(y, jax.Array)
+  return jnp.argmax(y)
+
+
+def nnls(A: jax.Array, b: jax.Array, maxiter: int, tol: float) -> jax.Array:
+  r"""Solves the non-negative least squares problem.
+
+  Minimizes :math:`\|A x - b\|_2` subject to :math:`x \geq 0`.
+
+  Args:
+    A: A matrix of shape ``(m, n)``.
+    b: A vector of shape ``(m,)``.
+    maxiter: The maximum number of iterations to run the algorithm for.
+    tol: The numerical tolerance to be used by the algorithm.
+
+  Returns:
+    The solution vector, of shape ``(n,)``.
+
+  Examples:
+    >>> from jax import numpy as jnp
+    >>> import optax
+    >>> A = jnp.array([[1, 2], [3, 4]])
+    >>> b = jnp.array([5, 6])
+    >>> x = optax.nnls(A, b, 1000, 1e-3)
+    >>> print(x[0])
+    0.0
+    >>> print(x[1])
+    1.7
+
+  References:
+    Lawson and Hanson, *Solving Least Squares Problems*, SIAM, 1995
+
+    Bro and de Jong, *A fast non-negativity-constrained least squares
+    algorithm*, Journal of Chemometrics, 1997
+  """
+
+  def solve_passive(p):
+    # Least squares restricted to the passive set `p`, solved through the
+    # normal equations with the rows and columns outside `p` zeroed out.
+    return jnp.linalg.lstsq(AtA * p[:, None] * p[None, :], Atb * p)[0]
+
+  def out_cond_fn(carry):
+    x, p, w, it = carry
+    del x
+    # Stop once the iteration budget is spent: the inner loop no longer runs
+    # past `maxiter`, so further outer steps could not restore feasibility.
+    return (it < maxiter) & ~p.all() & (w > tol).any(where=~p)
+
+  def out_body_fn(carry):
+
+    def in_cond_fn(carry):
+      x, p, s, it = carry
+      del x
+      return (it < maxiter) & (s <= 0).any(where=p)
+
+    def in_body_fn(carry):
+      x, p, s, it = carry
+      it += 1
+      # Step from x towards s, stopping where the first passive entry hits 0.
+      alpha = (x / (x - s)).min(where=p & (s <= 0), initial=jnp.inf)
+      x += alpha * (s - x)
+
+      # Entries that reached the boundary leave the passive set.
+      p = jnp.where(x <= 0, False, p)
+      assert isinstance(p, jax.Array)
+
+      s = solve_passive(p)
+
+      return x, p, s, it
+
+    x, p, w, it = carry
+
+    # Add the non-passive entry with the largest `w` to the passive set.
+    j = masked_argmax(w, ~p)
+    p = p.at[j].set(True)
+
+    s = solve_passive(p)
+
+    x, p, s, it = lax.while_loop(in_cond_fn, in_body_fn, (x, p, s, it))
+
+    x = s
+    w = Atb - AtA @ x
+
+    return x, p, w, it
+
+  _, n = A.shape
+
+  if n == 0:
+    return jnp.zeros(0)
+
+  Atb = b @ A
+  AtA = A.T @ A
+
+  x = jnp.zeros(n)
+  p = jnp.zeros(n, bool)
+  w = Atb.astype(float)
+  it = 0
+
+  x, p, w, it = lax.while_loop(out_cond_fn, out_body_fn, (x, p, w, it))
+
+  # `x` can hold negative entries, from round-off or because `maxiter` was
+  # reached; clip them so that the result is always feasible.
+  return jnp.maximum(x, 0)
diff --git a/optax/_src/linear_algebra_test.py b/optax/_src/linear_algebra_test.py
index a2da037e9..09fad0386 100644
--- a/optax/_src/linear_algebra_test.py
+++ b/optax/_src/linear_algebra_test.py
@@ -209,6 +209,36 @@ def _gen_symmetrix_matrix(dim, condition_number):
         # No guarantee of success after e >= 7
         pass
 
+  @parameterized.parameters(
+      {'n': n, 'd': d, 'seed': seed}
+      for n in range(5)
+      for d in range(5)
+      for seed in range(5)
+  )
+  def test_nnls(self, n, d, seed, atol=1e-4, tol=1e-5, maxiter=10**4):
+    """Test non-negative least squares solver."""
+    # Seed the generator so that every parameterized case is reproducible.
+    rng = np.random.default_rng(seed)
+    A = rng.normal(size=(n, d))
+    b = rng.normal(size=(n,))
+
+    x = linear_algebra.nnls(A, b, maxiter=maxiter, tol=tol)
+
+    with self.subTest('x is non-negative'):
+      np.testing.assert_allclose(x.clip(max=0), 0, atol=atol)
+
+    xr, _ = scipy.optimize.nnls(A, b, maxiter=maxiter, atol=tol)
+
+    with self.subTest('xr is non-negative'):
+      np.testing.assert_allclose(xr.clip(max=0), 0, atol=atol)
+
+    # Compare squared residuals rather than the solutions themselves.
+    residual = jnp.square(A @ x - b).sum()
+    residual_ref = jnp.square(A @ xr - b).sum()
+
+    with self.subTest('x is optimal'):
+      np.testing.assert_allclose(residual, residual_ref, atol=atol)
+
 
 if __name__ == '__main__':
   absltest.main()