From 2b1764c0ea1b351386d47925b4eaf1ed60163632 Mon Sep 17 00:00:00 2001
From: Chaithya G R
Date: Fri, 19 Jan 2024 13:03:52 +0100
Subject: [PATCH 01/45] Merge Develop for a new release (#322)
* Add support for tensorflow backend which allows for differentiability (#112)
* Added support for tensorflow
* Updates to get tests passing
* Or --> And
* Moving modopt to allow working with tensorflow
* Fix issues with wos
* Fix all flakes finally!
* Update modopt/base/backend.py
Co-authored-by: Samuel Farrens
* Update modopt/base/backend.py
Co-authored-by: Samuel Farrens
* Minute updates to codes
* Add dynamic module
* Fix docu
* Fix PEP
Co-authored-by: chaithyagr
Co-authored-by: Samuel Farrens
* Fix 115 (#116)
* Fix issues
* Add right tests
* Fix PEP
Co-authored-by: chaithyagr
* Minor bug fix, remove elif (#124)
Co-authored-by: chaithyagr
* Add tests for modopt.base.backend and fix minute bug uncovered (#126)
* Minor bug fix, remove elif
* Add tests for backend
* Fix tests
* Add tests
* Remove cupy
* PEP fixes
* Fix PEP
* Fix PEP and update
* Final PEP
* Update setup.cfg
Co-authored-by: Samuel Farrens
* Update test_base.py
Co-authored-by: chaithyagr
Co-authored-by: Samuel Farrens
* Release cleanup (#128)
* updated GPU dependencies
* added logo to manifest
* updated package version and release date
* Unpin package dependencies (#189)
* unpinned dependencies
* updated pinned documentation dependency versions
* Add Gradient descent algorithms (#196)
* Version 1.5.1 patch release (#114)
* Add support for tensorflow backend which allows for differentiability (#112)
* Added support for tensorflow
* Updates to get tests passing
* Or --> And
* Moving modopt to allow working with tensorflow
* Fix issues with wos
* Fix all flakes finally!
* Update modopt/base/backend.py
Co-authored-by: Samuel Farrens
* Update modopt/base/backend.py
Co-authored-by: Samuel Farrens
* Minute updates to codes
* Add dynamic module
* Fix docu
* Fix PEP
Co-authored-by: chaithyagr
Co-authored-by: Samuel Farrens
* Fix 115 (#116)
* Fix issues
* Add right tests
* Fix PEP
Co-authored-by: chaithyagr
* Minor bug fix, remove elif (#124)
Co-authored-by: chaithyagr
* Add tests for modopt.base.backend and fix minute bug uncovered (#126)
* Minor bug fix, remove elif
* Add tests for backend
* Fix tests
* Add tests
* Remove cupy
* PEP fixes
* Fix PEP
* Fix PEP and update
* Final PEP
* Update setup.cfg
Co-authored-by: Samuel Farrens
* Update test_base.py
Co-authored-by: chaithyagr
Co-authored-by: Samuel Farrens
* Release cleanup (#128)
* updated GPU dependencies
* added logo to manifest
* updated package version and release date
Co-authored-by: Chaithya G R
Co-authored-by: chaithyagr
* make algorithms a module.
* add Gradient Descent Algorithms
* enforce WPS compliance.
* add test for gradient descent
* Docstrings improvements
* Add See Also and minor corrections
* add idx initialisation for all algorithms.
* fix merge error
* fix typo
Co-authored-by: Samuel Farrens
Co-authored-by: Chaithya G R
Co-authored-by: chaithyagr
* Release cleanup (#198)
* started clean up for next release
* update progress
* further clean up
* additional clean up
* cleaned up link to logo
* fixed index.rst
* fixed conflict
* Fast Singular Value Thresholding (#209)
* add SingularValueThreshold
This Method provides 10x faster SVT estimation than the LowRankMatrix Operator.
* linting
* add test for fast computation.
* flake8 compliance
* Ignore DAR000 Error.
* Update modopt/signal/svd.py
tuples in docstring
Co-authored-by: Samuel Farrens
* Update modopt/signal/svd.py
typo
Co-authored-by: Samuel Farrens
* Update modopt/opt/proximity.py
typo
Co-authored-by: Samuel Farrens
* update docstring
* fix isort
* Update modopt/signal/svd.py
Co-authored-by: Samuel Farrens
* Update modopt/signal/svd.py
Co-authored-by: Samuel Farrens
* run isort
Co-authored-by: Samuel Farrens
* added writeable input data array feature for benchopt (#213)
* removed flake8 limit
* updated patch version
* [lint] pydocstyle compliance. (#228)
* [lint] pydocstyle compliance.
* use pytest-pydocstyle
* Power method: fix #211 (#212)
* Correct the norm update for Power Method
x_new should be divided by its norm, not by x_old_norm.
* fix test value
We are testing for eigen value of Identity. It should be one.
* fix WPS350
* fix test value for unconverged case
Co-authored-by: Samuel Farrens
* Switch from progressbar to tqdm (#231)
* switch from progressbar to tqdm.
The progress bar can be provided externally for nested usage.
* exposes the progress bar argument.
* Child classes better have to implement these.
(my linter was complaining)
* update docs for progress bar using tqdm.
* fix WPS errors
* drop progressbar requirement, add tqdm.
* [lint] disable warning for non implemented function.
* simplify progbar check and argument passthrough
* Update README for tqdm dependency (#240)
Remote progressbar, use tqdm.
* add small help for the metric argument. (#241)
* add small help for the metric argument.
* RST validation
* use single quote
* use double backticks.
Co-authored-by: Samuel Farrens
* add implementation for admm and fast admm.
Based on Goldstein2014
* add Goldstein ref.
* WPS compliance.
* Abstract class for cost function.
* add custom cost operator for admm.
* fix WPS compliance.
* Ci update (#268)
* update python version support.
* use string for CI.
* remove flake8 and wemake-python-styleguide
This anticipates the change to black formatting.
* remove wps checks
* apparently conda does not support 3.11 for now
* remove all linting testing.
* fix np.int warning/error
* fix dtype error
* fix precision for doctest
* added black and isort support
* Update python version in README
* add 3.7 for test back
* don't test 3.10 twice
* Test rewrite (#266)
* add MatrixOperator.
* move base test to pytest.
* [fixme] remove flake8 and emoji config.
* rewrite test_math module using pytest.
* use fail/skipparam helper function.
* generalize usage of failparam
* refactor test_signal.
* refactor test_signal, the end.
* lint
* fix missing parameter.
* add dummy object test helper.
* rewrite test for cost and gradients.
* show missing lines in coverage reports
* rewrite of proximity operators testing.
* add fail low rank method.
* add cases for algorithms test
* add algorithm test.
* add pytest-cases and pytest-xdists support.
* add support for testing metrics.
* improve base module coverage.
* test for wrong mask in metric module.
* add docstring.
* update email adress and authors field.
* 100% coverage for transform module.
* move linear operator to class
* update docstring.
* paramet(e)rization.
* update docstring.
* improve test_helper module.
* raises should be specified for each failparam call.
* encapsulate module's test in classes.
* skip test if sklearn is not installed.
* pin pydocstyle
* removed unnormalised Gaussian kernel option and corresponding test
* Restrict scikit-image version for testing
* added fix for basic test suite
* set behaviour for different astropy versions
* updated docstring for gaussian_kernel
* Use example scripts as tests. (#277)
* Initialize the example module.
* do not export the assert statements.
* add matplotlib as requirement.
* add support for sphinx-gallery
* Update modopt/examples/README.rst
Co-authored-by: Samuel Farrens
* Update modopt/examples/__init__.py
Co-authored-by: Samuel Farrens
* Update modopt/examples/conftest.py
Co-authored-by: Samuel Farrens
* Update modopt/examples/example_lasso_forward_backward.py
Co-authored-by: Samuel Farrens
* Update modopt/examples/example_lasso_forward_backward.py
Co-authored-by: Samuel Farrens
* ignore auto_example folder
* doc formatting.
* add pogm and basic comparison.
* fix: add matplotlib for the plotting in examples scripts.
* fix: add matplotlib for basic ci too.
* ci: run pytest with xdist for faster testing
---------
Co-authored-by: Samuel Farrens
* fix: specify data_range for ssim.
Refs: #290
* typos.
* feat(test): add test for admm.
* feat(admm): improve doc.
* refactor: rename abstract cost to CostParent.
* feat: add test for fast admm.
* feat(admm): improve docstrings.
* style: remove extra line.c
* feat: make POGM more memory efficient.
* feat: add a dummy cost for the identity operator.
* feat: create a linear operator module, add wavelet transform.
* feat: add test case for wavelet transform.
* Update setup.py
---------
Co-authored-by: chaithyagr
Co-authored-by: Samuel Farrens
Co-authored-by: Pierre-Antoine Comby <77174042+paquiteau@users.noreply.github.com>
Co-authored-by: Pierre-antoine Comby
Co-authored-by: Pierre-Antoine Comby
---
.github/workflows/cd-build.yml | 4 +-
.github/workflows/ci-build.yml | 21 +-
.gitignore | 1 +
README.md | 4 +-
develop.txt | 13 +-
docs/requirements.txt | 1 +
docs/source/conf.py | 13 +
docs/source/refs.bib | 12 +
docs/source/toc.rst | 1 +
modopt/examples/README.rst | 5 +
modopt/examples/__init__.py | 10 +
modopt/examples/conftest.py | 46 +
.../example_lasso_forward_backward.py | 153 ++
modopt/math/matrix.py | 17 +-
modopt/math/metrics.py | 6 +-
modopt/math/stats.py | 46 +-
modopt/opt/algorithms/__init__.py | 3 +-
modopt/opt/algorithms/admm.py | 337 ++++
modopt/opt/algorithms/base.py | 73 +-
modopt/opt/algorithms/forward_backward.py | 38 +-
modopt/opt/algorithms/primal_dual.py | 11 +-
modopt/opt/cost.py | 152 +-
modopt/opt/linear/__init__.py | 21 +
modopt/opt/{linear.py => linear/base.py} | 58 +-
modopt/opt/linear/wavelet.py | 216 +++
modopt/opt/proximity.py | 2 +-
modopt/signal/filter.py | 8 +-
modopt/signal/positivity.py | 2 +-
modopt/signal/svd.py | 4 +-
modopt/tests/test_algorithms.py | 673 +++-----
modopt/tests/test_base.py | 435 ++----
modopt/tests/test_helpers/__init__.py | 1 +
modopt/tests/test_helpers/utils.py | 23 +
modopt/tests/test_math.py | 661 +++-----
modopt/tests/test_opt.py | 1390 +++++------------
modopt/tests/test_signal.py | 561 +++----
requirements.txt | 2 +-
setup.cfg | 12 +-
setup.py | 4 +-
39 files changed, 2439 insertions(+), 2601 deletions(-)
create mode 100644 modopt/examples/README.rst
create mode 100644 modopt/examples/__init__.py
create mode 100644 modopt/examples/conftest.py
create mode 100644 modopt/examples/example_lasso_forward_backward.py
create mode 100644 modopt/opt/algorithms/admm.py
create mode 100644 modopt/opt/linear/__init__.py
rename modopt/opt/{linear.py => linear/base.py} (84%)
create mode 100644 modopt/opt/linear/wavelet.py
create mode 100644 modopt/tests/test_helpers/__init__.py
create mode 100644 modopt/tests/test_helpers/utils.py
diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml
index 1e49f8bc..fca9feb1 100644
--- a/.github/workflows/cd-build.yml
+++ b/.github/workflows/cd-build.yml
@@ -62,9 +62,7 @@ jobs:
- name: Set up Conda with Python 3.8
uses: conda-incubator/setup-miniconda@v2
with:
- auto-update-conda: true
- python-version: 3.8
- auto-activate-base: false
+ python-version: "3.8"
- name: Install dependencies
shell: bash -l {0}
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index 3ffcb6f4..c4ba28a0 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -16,21 +16,12 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python-version: [3.8]
+ python-version: ["3.10"]
steps:
- name: Checkout
uses: actions/checkout@v2
- - name: Report WPS Errors
- uses: wemake-services/wemake-python-styleguide@0.14.1
- continue-on-error: true
- with:
- reporter: 'github-pr-review'
- path: './modopt'
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
- name: Set up Conda with Python ${{ matrix.python-version }}
uses: conda-incubator/setup-miniconda@v2
with:
@@ -52,7 +43,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install -r develop.txt
python -m pip install -r docs/requirements.txt
- python -m pip install astropy scikit-image scikit-learn
+ python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
python -m pip install tensorflow>=2.4.1
python -m pip install twine
python -m pip install .
@@ -61,7 +52,7 @@ jobs:
shell: bash -l {0}
run: |
export PATH=/usr/share/miniconda/bin:$PATH
- python setup.py test
+ pytest -n 2
- name: Save Test Results
if: always()
@@ -98,7 +89,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python-version: [3.6, 3.7, 3.9]
+ python-version: ["3.7", "3.8", "3.9"]
steps:
- name: Checkout
@@ -117,11 +108,11 @@ jobs:
python --version
python -m pip install --upgrade pip
python -m pip install -r develop.txt
- python -m pip install astropy scikit-image scikit-learn
+ python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
python -m pip install .
- name: Run Tests
shell: bash -l {0}
run: |
export PATH=/usr/share/miniconda/bin:$PATH
- python setup.py test
+ pytest -n 2
diff --git a/.gitignore b/.gitignore
index 06dff8db..f9eaaa68 100644
--- a/.gitignore
+++ b/.gitignore
@@ -73,6 +73,7 @@ instance/
docs/_build/
docs/source/fortuna.*
docs/source/scripts.*
+docs/source/auto_examples/
docs/source/*.nblink
# PyBuilder
diff --git a/README.md b/README.md
index acb316ad..223d0b73 100644
--- a/README.md
+++ b/README.md
@@ -37,11 +37,11 @@ All packages required by ModOpt should be installed automatically. Optional pack
In order to run the code in this repository the following packages must be
installed:
-* [Python](https://www.python.org/) [> 3.6]
+* [Python](https://www.python.org/) [> 3.7]
* [importlib_metadata](https://importlib-metadata.readthedocs.io/en/latest/) [==3.7.0]
* [Numpy](http://www.numpy.org/) [==1.19.5]
* [Scipy](http://www.scipy.org/) [==1.5.4]
-* [Progressbar 2](https://progressbar-2.readthedocs.io/) [==3.53.1]
+* [tqdm](https://tqdm.github.io/) [>=4.64.0]
### Optional Packages
diff --git a/develop.txt b/develop.txt
index 8beef0ff..6ff665eb 100644
--- a/develop.txt
+++ b/develop.txt
@@ -1,9 +1,12 @@
coverage>=5.5
-flake8>=4
-nose>=1.3.7
pytest>=6.2.2
+pytest-raises>=0.10
+pytest-cases>= 3.6
+pytest-xdist>= 3.0.1
pytest-cov>=2.11.1
-pytest-pep8>=1.0.6
pytest-emoji>=0.2.0
-pytest-flake8>=1.0.7
-wemake-python-styleguide>=0.15.2
+pydocstyle==6.1.1
+pytest-pydocstyle>=2.2.0
+black
+isort
+pytest-black
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 4d2a14fb..c9e29c88 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -6,3 +6,4 @@ numpydoc==1.1.0
sphinx==4.3.1
sphinxcontrib-bibtex==2.4.1
sphinxawesome-theme==3.2.1
+sphinx-gallery==0.11.1
diff --git a/docs/source/conf.py b/docs/source/conf.py
index fb954f6d..46564b9f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -45,6 +45,7 @@
'nbsphinx',
'nbsphinx_link',
'numpydoc',
+ "sphinx_gallery.gen_gallery"
]
# Include module names for objects
@@ -145,6 +146,18 @@
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
html_show_copyright = True
+
+
+# -- Options for Sphinx Gallery ----------------------------------------------
+
+sphinx_gallery_conf = {
+ "examples_dirs": ["../../modopt/examples/"],
+ "filename_pattern": "/example_",
+ "ignore_pattern": r"/(__init__|conftest)\.py",
+}
+
+
+
# -- Options for nbshpinx output ------------------------------------------
diff --git a/docs/source/refs.bib b/docs/source/refs.bib
index d8365e71..7782ca52 100644
--- a/docs/source/refs.bib
+++ b/docs/source/refs.bib
@@ -207,3 +207,15 @@ @article{zou2005
journal = {Journal of the Royal Statistical Society Series B},
doi = {10.1111/j.1467-9868.2005.00527.x}
}
+
+@article{Goldstein2014,
+ author={Goldstein, Tom and O’Donoghue, Brendan and Setzer, Simon and Baraniuk, Richard},
+ year={2014},
+ month={Jan},
+ pages={1588–1623},
+ title={Fast Alternating Direction Optimization Methods},
+ journal={SIAM Journal on Imaging Sciences},
+ volume={7},
+ ISSN={1936-4954},
+ doi={10/gdwr49},
+}
diff --git a/docs/source/toc.rst b/docs/source/toc.rst
index 84a6af87..ef5753f5 100644
--- a/docs/source/toc.rst
+++ b/docs/source/toc.rst
@@ -25,6 +25,7 @@
plugin_example
notebooks
+ auto_examples/index
.. toctree::
:hidden:
diff --git a/modopt/examples/README.rst b/modopt/examples/README.rst
new file mode 100644
index 00000000..e6ffbe27
--- /dev/null
+++ b/modopt/examples/README.rst
@@ -0,0 +1,5 @@
+========
+Examples
+========
+
+This is a collection of Python scripts demonstrating the use of ModOpt.
diff --git a/modopt/examples/__init__.py b/modopt/examples/__init__.py
new file mode 100644
index 00000000..d7e77357
--- /dev/null
+++ b/modopt/examples/__init__.py
@@ -0,0 +1,10 @@
+"""EXAMPLES.
+
+This module contains documented examples that demonstrate the usage of various
+ModOpt tools.
+
+These examples also serve as integration tests for various methods.
+
+:Author: Pierre-Antoine Comby
+
+"""
diff --git a/modopt/examples/conftest.py b/modopt/examples/conftest.py
new file mode 100644
index 00000000..73358679
--- /dev/null
+++ b/modopt/examples/conftest.py
@@ -0,0 +1,46 @@
+"""TEST CONFIGURATION.
+
+This module contains methods for configuring the testing of the example
+scripts.
+
+:Author: Pierre-Antoine Comby
+
+Notes
+-----
+Based on:
+https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test
+
+"""
+from pathlib import Path
+import runpy
+import pytest
+
+def pytest_collect_file(path, parent):
+ """Pytest hook.
+
+ Create a collector for the given path, or None if not relevant.
+ The new node needs to have the specified parent as parent.
+ """
+ p = Path(path)
+ if p.suffix == '.py' and 'example' in p.name:
+ return Script.from_parent(parent, path=p, name=p.name)
+
+
+class Script(pytest.File):
+ """Script files collected by pytest."""
+
+ def collect(self):
+ """Collect the script as its own item."""
+ yield ScriptItem.from_parent(self, name=self.name)
+
+class ScriptItem(pytest.Item):
+ """Item script collected by pytest."""
+
+ def runtest(self):
+ """Run the script as a test."""
+ runpy.run_path(str(self.path))
+
+ def repr_failure(self, excinfo):
+ """Return only the error traceback of the script."""
+ excinfo.traceback = excinfo.traceback.cut(path=self.path)
+ return super().repr_failure(excinfo)
diff --git a/modopt/examples/example_lasso_forward_backward.py b/modopt/examples/example_lasso_forward_backward.py
new file mode 100644
index 00000000..7f820000
--- /dev/null
+++ b/modopt/examples/example_lasso_forward_backward.py
@@ -0,0 +1,153 @@
+# noqa: D205
+"""
+Solving the LASSO Problem with the Forward Backward Algorithm.
+==============================================================
+
+This is an example showing how to solve a LASSO problem
+using the Forward-Backward Algorithm.
+
+In this example we are going to use:
+ - Modopt Operators (Linear, Gradient, Proximal)
+ - Modopt implementation of solvers
+ - Modopt Metric API.
+TODO: add reference to LASSO paper.
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from modopt.opt.algorithms import ForwardBackward, POGM
+from modopt.opt.cost import costObj
+from modopt.opt.linear import LinearParent, Identity
+from modopt.opt.gradient import GradBasic
+from modopt.opt.proximity import SparseThreshold
+from modopt.math.matrix import PowerMethod
+from modopt.math.stats import mse
+
+# %%
+# Here we create an instance of the LASSO Problem
+
+BETA_TRUE = np.array(
+ [3.0, 1.5, 0, 0, 2, 0, 0, 0]
+) # 8 original values from the LASSO Paper
+DIM = len(BETA_TRUE)
+
+
+rng = np.random.default_rng()
+sigma_noise = 1
+obs = 20
+# create a measurement matrix with decaying covariance matrix.
+cov = 0.4 ** abs((np.arange(DIM) * np.ones((DIM, DIM))).T - np.arange(DIM))
+x = rng.multivariate_normal(np.zeros(DIM), cov, obs)
+
+y = x @ BETA_TRUE
+y_noise = y + (sigma_noise * np.random.standard_normal(obs))
+
+
+# %%
+# Next we create Operators for solving the problem.
+
+# MatrixOperator could also work here.
+lin_op = LinearParent(lambda b: x @ b, lambda bb: x.T @ bb)
+grad_op = GradBasic(y_noise, op=lin_op.op, trans_op=lin_op.adj_op)
+
+prox_op = SparseThreshold(Identity(), 1, thresh_type="soft")
+
+# %%
+# In order to get the best convergence rate, we first determine the Lipschitz constant of the gradient Operator
+#
+
+calc_lips = PowerMethod(grad_op.trans_op_op, 8, data_type="float32", auto_run=True)
+lip = calc_lips.spec_rad
+print("lipschitz constant:", lip)
+
+# %%
+# Solving using FISTA algorithm
+# -----------------------------
+#
+# TODO: Add description/Reference of FISTA.
+
+cost_op_fista = costObj([grad_op, prox_op], verbose=False)
+
+fb_fista = ForwardBackward(
+ np.zeros(8),
+ beta_param=1 / lip,
+ grad=grad_op,
+ prox=prox_op,
+ cost=cost_op_fista,
+ metric_call_period=1,
+ auto_iterate=False, # Just to give us the pleasure of doing things by ourselves.
+)
+
+fb_fista.iterate()
+
+# %%
+# After the run we can have a look at the results
+
+print(fb_fista.x_final)
+mse_fista = mse(fb_fista.x_final, BETA_TRUE)
+plt.stem(fb_fista.x_final, label="estimation", linefmt="C0-")
+plt.stem(BETA_TRUE, label="reference", linefmt="C1-")
+plt.legend()
+plt.title(f"FISTA Estimation MSE={mse_fista:.4f}")
+
+# sphinx_gallery_start_ignore
+assert mse(fb_fista.x_final, BETA_TRUE) < 1
+# sphinx_gallery_end_ignore
+
+
+# %%
+# Solving Using the POGM Algorithm
+# --------------------------------
+#
+# TODO: Add description/Reference to POGM.
+
+
+cost_op_pogm = costObj([grad_op, prox_op], verbose=False)
+
+fb_pogm = POGM(
+ np.zeros(8),
+ np.zeros(8),
+ np.zeros(8),
+ np.zeros(8),
+ beta_param=1 / lip,
+ grad=grad_op,
+ prox=prox_op,
+ cost=cost_op_pogm,
+ metric_call_period=1,
+ auto_iterate=False, # Just to give us the pleasure of doing things by ourselves.
+)
+
+fb_pogm.iterate()
+
+# %%
+# After the run we can have a look at the results
+
+print(fb_pogm.x_final)
+mse_pogm = mse(fb_pogm.x_final, BETA_TRUE)
+
+plt.stem(fb_pogm.x_final, label="estimation", linefmt="C0-")
+plt.stem(BETA_TRUE, label="reference", linefmt="C1-")
+plt.legend()
+plt.title(f"FISTA Estimation MSE={mse_pogm:.4f}")
+#
+# sphinx_gallery_start_ignore
+assert mse(fb_pogm.x_final, BETA_TRUE) < 1
+
+# %%
+# Comparing the Two algorithms
+# ----------------------------
+
+plt.figure()
+plt.semilogy(cost_op_fista._cost_list, label="FISTA convergence")
+plt.semilogy(cost_op_pogm._cost_list, label="POGM convergence")
+plt.xlabel("iterations")
+plt.ylabel("Cost Function")
+plt.legend()
+plt.show()
+
+
+# %%
+# We can see that the two algorithms converge quickly, and POGM requires fewer iterations.
+# However the POGM iterations are more costly, so a proper benchmark with time measurement is needed.
+# Check the benchopt benchmark for more details.
diff --git a/modopt/math/matrix.py b/modopt/math/matrix.py
index 939cf41f..8361531d 100644
--- a/modopt/math/matrix.py
+++ b/modopt/math/matrix.py
@@ -285,9 +285,9 @@ class PowerMethod(object):
>>> np.random.seed(1)
>>> pm = PowerMethod(lambda x: x.dot(x.T), (3, 3))
>>> np.around(pm.spec_rad, 6)
- 0.904292
+ 1.0
>>> np.around(pm.inv_spec_rad, 6)
- 1.105837
+ 1.0
Notes
-----
@@ -348,17 +348,21 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
# Set (or reset) values of x.
x_old = self._set_initial_x()
+ xp = get_array_module(x_old)
+ x_old_norm = xp.linalg.norm(x_old)
+
+ x_old /= x_old_norm
+
# Iterate until the L2 norm of x converges.
for i_elem in range(max_iter):
-
xp = get_array_module(x_old)
- x_old_norm = xp.linalg.norm(x_old)
-
- x_new = self._operator(x_old) / x_old_norm
+ x_new = self._operator(x_old)
x_new_norm = xp.linalg.norm(x_new)
+ x_new /= x_new_norm
+
if (xp.abs(x_new_norm - x_old_norm) < tolerance):
message = (
' - Power Method converged after {0} iterations!'
@@ -374,6 +378,7 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
print(message.format(max_iter))
xp.copyto(x_old, x_new)
+ x_old_norm = x_new_norm
self.spec_rad = x_new_norm * extra_factor
self.inv_spec_rad = 1.0 / self.spec_rad
diff --git a/modopt/math/metrics.py b/modopt/math/metrics.py
index 1b870e23..21952624 100644
--- a/modopt/math/metrics.py
+++ b/modopt/math/metrics.py
@@ -23,7 +23,7 @@
def min_max_normalize(img):
"""Min-Max Normalize.
- Centre and normalize a given array.
+ Normalize a given array in the [0,1] range.
Parameters
----------
@@ -33,7 +33,7 @@ def min_max_normalize(img):
Returns
-------
numpy.ndarray
- Centred and normalized array
+ normalized array
"""
min_img = img.min()
@@ -126,7 +126,7 @@ def ssim(test, ref, mask=None):
test, ref, mask = _preprocess_input(test, ref, mask)
test = move_to_cpu(test)
- assim, ssim_value = compare_ssim(test, ref, full=True)
+ assim, ssim_value = compare_ssim(test, ref, full=True, data_range=1.0)
if mask is None:
return assim
diff --git a/modopt/math/stats.py b/modopt/math/stats.py
index 3ac818a7..59bf6759 100644
--- a/modopt/math/stats.py
+++ b/modopt/math/stats.py
@@ -11,6 +11,8 @@
import numpy as np
try:
+ from packaging import version
+ from astropy import __version__ as astropy_version
from astropy.convolution import Gaussian2DKernel
except ImportError: # pragma: no cover
import_astropy = False
@@ -18,7 +20,7 @@
import_astropy = True
-def gaussian_kernel(data_shape, sigma, norm='max'):
+def gaussian_kernel(data_shape, sigma, norm="max"):
"""Gaussian kernel.
This method produces a Gaussian kerenal of a specified size and dispersion.
@@ -29,9 +31,8 @@ def gaussian_kernel(data_shape, sigma, norm='max'):
Desiered shape of the kernel
sigma : float
Standard deviation of the kernel
- norm : {'max', 'sum', 'none'}, optional
- Normalisation of the kerenl (options are ``'max'``, ``'sum'`` or
- ``'none'``, default is ``'max'``)
+ norm : {'max', 'sum'}, optional
+ Normalisation of the kernel (options are ``'max'`` or ``'sum'``, default is ``'max'``)
Returns
-------
@@ -60,22 +61,22 @@ def gaussian_kernel(data_shape, sigma, norm='max'):
"""
if not import_astropy: # pragma: no cover
- raise ImportError('Astropy package not found.')
+ raise ImportError("Astropy package not found.")
- if norm not in {'max', 'sum', 'none'}:
+ if norm not in {"max", "sum"}:
raise ValueError('Invalid norm, options are "max", "sum" or "none".')
kernel = np.array(
Gaussian2DKernel(sigma, x_size=data_shape[1], y_size=data_shape[0]),
)
- if norm == 'max':
+ if norm == "max":
return kernel / np.max(kernel)
- elif norm == 'sum':
+ elif version.parse(astropy_version) < version.parse("5.2"):
return kernel / np.sum(kernel)
- elif norm == 'none':
+ else:
return kernel
@@ -147,7 +148,7 @@ def mse(data1, data2):
return np.mean((data1 - data2) ** 2)
-def psnr(data1, data2, method='starck', max_pix=255):
+def psnr(data1, data2, method="starck", max_pix=255):
r"""Peak Signal-to-Noise Ratio.
This method calculates the Peak Signal-to-Noise Ratio between two data
@@ -202,23 +203,21 @@ def psnr(data1, data2, method='starck', max_pix=255):
10\log_{10}(\mathrm{MSE}))
"""
- if method == 'starck':
- return (
- 20 * np.log10(
- (data1.shape[0] * np.abs(np.max(data1) - np.min(data1)))
- / np.linalg.norm(data1 - data2),
- )
+ if method == "starck":
+ return 20 * np.log10(
+ (data1.shape[0] * np.abs(np.max(data1) - np.min(data1)))
+ / np.linalg.norm(data1 - data2),
)
- elif method == 'wiki':
- return (20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2)))
+ elif method == "wiki":
+ return 20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2))
raise ValueError(
'Invalid PSNR method. Options are "starck" and "wiki"',
)
-def psnr_stack(data1, data2, metric=np.mean, method='starck'):
+def psnr_stack(data1, data2, metric=np.mean, method="starck"):
"""Peak Signa-to-Noise for stack of images.
This method calculates the PSNRs for two stacks of 2D arrays.
@@ -261,12 +260,11 @@ def psnr_stack(data1, data2, metric=np.mean, method='starck'):
"""
if data1.ndim != 3 or data2.ndim != 3:
- raise ValueError('Input data must be a 3D np.ndarray')
+ raise ValueError("Input data must be a 3D np.ndarray")
- return metric([
- psnr(i_elem, j_elem, method=method)
- for i_elem, j_elem in zip(data1, data2)
- ])
+ return metric(
+ [psnr(i_elem, j_elem, method=method) for i_elem, j_elem in zip(data1, data2)]
+ )
def sigma_mad(input_data):
diff --git a/modopt/opt/algorithms/__init__.py b/modopt/opt/algorithms/__init__.py
index e0ac2572..d4e7082b 100644
--- a/modopt/opt/algorithms/__init__.py
+++ b/modopt/opt/algorithms/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-r"""OPTIMISATION ALGOTITHMS.
+r"""OPTIMISATION ALGORITHMS.
This module contains class implementations of various optimisation algoritms.
@@ -57,3 +57,4 @@
SAGAOptGradOpt,
VanillaGenericGradOpt)
from modopt.opt.algorithms.primal_dual import Condat
+from modopt.opt.algorithms.admm import ADMM, FastADMM
diff --git a/modopt/opt/algorithms/admm.py b/modopt/opt/algorithms/admm.py
new file mode 100644
index 00000000..b881b770
--- /dev/null
+++ b/modopt/opt/algorithms/admm.py
@@ -0,0 +1,337 @@
+"""ADMM Algorithms."""
+import numpy as np
+
+from modopt.base.backend import get_array_module
+from modopt.opt.algorithms.base import SetUp
+from modopt.opt.cost import CostParent
+
+
+class ADMMcostObj(CostParent):
+ r"""Cost Object for the ADMM problem class.
+
+ Parameters
+ ----------
+ cost_funcs: 2-tuples of callable
+ f and g function.
+ A : OperatorBase
+ First Operator
+ B : OperatorBase
+ Second Operator
+ b : numpy.ndarray
+ Observed data
+ **kwargs : dict
+ Extra parameters for cost operator configuration
+
+ Notes
+ -----
+ Compute :math:`f(u)+g(v) + \tau \| Au +Bv - b\|^2`
+
+ See Also
+ --------
+ CostParent: parent class
+ """
+
+ def __init__(self, cost_funcs, A, B, b, tau, **kwargs):
+        super().__init__(**kwargs)
+ self.cost_funcs = cost_funcs
+ self.A = A
+ self.B = B
+ self.b = b
+ self.tau = tau
+
+ def _calc_cost(self, u, v, **kwargs):
+ """Calculate the cost.
+
+ This method calculates the cost from each of the input operators.
+
+ Parameters
+ ----------
+ u: numpy.ndarray
+ First primal variable of ADMM
+ v: numpy.ndarray
+ Second primal variable of ADMM
+
+ Returns
+ -------
+ float
+ Cost value
+
+ """
+ xp = get_array_module(u)
+ cost = self.cost_funcs[0](u)
+ cost += self.cost_funcs[1](v)
+ cost += self.tau * xp.linalg.norm(self.A.op(u) + self.B.op(v) - self.b)
+ return cost
+
+
+class ADMM(SetUp):
+    r"""ADMM Optimisation Algorithm.
+
+    This class implements the ADMM algorithm described in :cite:`Goldstein2014` (Algorithm 1).
+
+ Parameters
+ ----------
+ u: numpy.ndarray
+ Initial value for first primal variable of ADMM
+ v: numpy.ndarray
+ Initial value for second primal variable of ADMM
+ mu: numpy.ndarray
+ Initial value for lagrangian multiplier.
+ A : modopt.opt.linear.LinearOperator
+ Linear operator for u
+ B: modopt.opt.linear.LinearOperator
+ Linear operator for v
+ b : numpy.ndarray
+ Constraint vector
+ optimizers: tuple
+ 2-tuple of callable, that are the optimizers for the u and v.
+        Each callable should accept ``init`` and ``obs`` arguments and return an estimate of:
+        .. math:: u_{k+1} = \arg\min H(u) + \frac{\tau}{2}\|A u - y\|^2
+        .. math:: v_{k+1} = \arg\min G(v) + \frac{\tau}{2}\|Bv - y \|^2
+ cost_funcs: tuple
+ 2-tuple of callable, that compute values of H and G.
+ tau: float, default=1
+ Coupling parameter for ADMM.
+
+ Notes
+ -----
+ The algorithm solve the problem:
+
+ .. math:: u, v = \arg\min H(u) + G(v) + \frac\tau2 \|Au + Bv - b \|_2^2
+
+ with the following augmented lagrangian:
+
+ .. math :: \mathcal{L}_{\tau}(u,v, \lambda) = H(u) + G(v)
+ +\langle\lambda |Au + Bv -b \rangle + \frac\tau2 \| Au + Bv -b \|^2
+
+ To allow easy iterative solving, the change of variable
+ :math:`\mu=\lambda/\tau` is used. Hence, the lagrangian of interest is:
+
+ .. math :: \tilde{\mathcal{L}}_{\tau}(u,v, \mu) = H(u) + G(v)
+ + \frac\tau2 \left(\|\mu + Au +Bv - b\|^2 - \|\mu\|^2\right)
+
+ See Also
+ --------
+ SetUp: parent class
+ """
+
+ def __init__(
+ self,
+ u,
+ v,
+ mu,
+ A,
+ B,
+ b,
+ optimizers,
+ tau=1,
+ cost_funcs=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.A = A
+ self.B = B
+ self.b = b
+ self._opti_H = optimizers[0]
+ self._opti_G = optimizers[1]
+ self._tau = tau
+ if cost_funcs is not None:
+ self._cost_func = ADMMcostObj(cost_funcs, A, B, b, tau)
+ else:
+ self._cost_func = None
+
+ # init iteration variables.
+ self._u_old = self.xp.copy(u)
+ self._u_new = self.xp.copy(u)
+ self._v_old = self.xp.copy(v)
+ self._v_new = self.xp.copy(v)
+ self._mu_new = self.xp.copy(mu)
+ self._mu_old = self.xp.copy(mu)
+
+ def _update(self):
+ self._u_new = self._opti_H(
+ init=self._u_old,
+            obs=self.B.op(self._v_old) + self._mu_old - self.b,
+ )
+ tmp = self.A.op(self._u_new)
+ self._v_new = self._opti_G(
+ init=self._v_old,
+            obs=tmp + self._mu_old - self.b,
+ )
+
+ self._mu_new = self._mu_old + (tmp + self.B.op(self._v_new) - self.b)
+
+ # update cycle
+ self._u_old = self.xp.copy(self._u_new)
+ self._v_old = self.xp.copy(self._v_new)
+ self._mu_old = self.xp.copy(self._mu_new)
+
+ # Test cost function for convergence.
+ if self._cost_func:
+ self.converge = self.any_convergence_flag()
+ self.converge |= self._cost_func.get_cost(self._u_new, self._v_new)
+
+ def iterate(self, max_iter=150):
+ """Iterate.
+
+ This method calls update until either convergence criteria is met or
+ the maximum number of iterations is reached.
+
+ Parameters
+ ----------
+ max_iter : int, optional
+ Maximum number of iterations (default is ``150``)
+ """
+ self._run_alg(max_iter)
+
+ # retrieve metrics results
+ self.retrieve_outputs()
+ # rename outputs as attributes
+ self.u_final = self._u_new
+ self.x_final = self.u_final # for backward compatibility
+ self.v_final = self._v_new
+
+ def get_notify_observers_kwargs(self):
+ """Notify observers.
+
+ Return the mapping between the metrics call and the iterated
+ variables.
+
+ Returns
+ -------
+ dict
+ The mapping between the iterated variables
+ """
+ return {
+ 'x_new': self._u_new,
+ 'v_new': self._v_new,
+ 'idx': self.idx,
+ }
+
+ def retrieve_outputs(self):
+ """Retrieve outputs.
+
+ Declare the outputs of the algorithms as attributes: x_final,
+ y_final, metrics.
+ """
+ metrics = {}
+ for obs in self._observers['cv_metrics']:
+ metrics[obs.name] = obs.retrieve_metrics()
+ self.metrics = metrics
+
+
+class FastADMM(ADMM):
+    r"""Fast ADMM Optimisation Algorithm.
+
+    This class implements the fast ADMM algorithm
+ (Algorithm 8 from :cite:`Goldstein2014`)
+
+ Parameters
+ ----------
+ u: numpy.ndarray
+ Initial value for first primal variable of ADMM
+ v: numpy.ndarray
+ Initial value for second primal variable of ADMM
+ mu: numpy.ndarray
+ Initial value for lagrangian multiplier.
+ A : modopt.opt.linear.LinearOperator
+ Linear operator for u
+ B: modopt.opt.linear.LinearOperator
+ Linear operator for v
+ b : numpy.ndarray
+ Constraint vector
+ optimizers: tuple
+ 2-tuple of callable, that are the optimizers for the u and v.
+        Each callable should accept ``init`` and ``obs`` arguments and return an estimate of:
+        .. math:: u_{k+1} = \arg\min H(u) + \frac{\tau}{2}\|A u - y\|^2
+        .. math:: v_{k+1} = \arg\min G(v) + \frac{\tau}{2}\|Bv - y \|^2
+ cost_funcs: tuple
+ 2-tuple of callable, that compute values of H and G.
+ tau: float, default=1
+ Coupling parameter for ADMM.
+ eta: float, default=0.999
+ Convergence parameter for ADMM.
+ alpha: float, default=1.
+ Initial value for the FISTA-like acceleration parameter.
+
+ Notes
+ -----
+    This is an accelerated version of the ADMM algorithm. The convergence hypotheses are stronger than for the ADMM algorithm.
+
+ See Also
+ --------
+ ADMM: parent class
+ """
+
+ def __init__(
+ self,
+ u,
+ v,
+ mu,
+ A,
+ B,
+ b,
+ optimizers,
+ cost_funcs=None,
+ alpha=1,
+ eta=0.999,
+ tau=1,
+ **kwargs,
+ ):
+ super().__init__(
+ u=u,
+            v=v,
+ mu=mu,
+ A=A,
+ B=B,
+ b=b,
+            optimizers=optimizers, tau=tau,
+ cost_funcs=cost_funcs,
+ **kwargs,
+ )
+ self._c_old = np.inf
+ self._c_new = 0
+ self._eta = eta
+ self._alpha_old = alpha
+ self._alpha_new = alpha
+ self._v_hat = self.xp.copy(self._v_new)
+ self._mu_hat = self.xp.copy(self._mu_new)
+
+ def _update(self):
+ # Classical ADMM steps
+ self._u_new = self._opti_H(
+ init=self._u_old,
+            obs=self.B.op(self._v_hat) + self._mu_hat - self.b,
+ )
+ tmp = self.A.op(self._u_new)
+ self._v_new = self._opti_G(
+ init=self._v_hat,
+            obs=tmp + self._mu_hat - self.b,
+ )
+
+ self._mu_new = self._mu_hat + (tmp + self.B.op(self._v_new) - self.b)
+
+ # restarting condition
+ self._c_new = self.xp.linalg.norm(self._mu_new - self._mu_hat)
+ self._c_new += self._tau * self.xp.linalg.norm(
+ self.B.op(self._v_new - self._v_hat),
+ )
+        if self._c_new < self._eta * self._c_old:
+            # Goldstein et al. (2014), Alg. 8: alpha_{k+1} = (1 + sqrt(1 + 4 alpha_k^2)) / 2
+            self._alpha_new = (1 + np.sqrt(1 + 4 * self._alpha_old**2)) / 2
+            beta = (self._alpha_old - 1) / self._alpha_new
+            self._v_hat = self._v_new + (self._v_new - self._v_old) * beta
+            self._mu_hat = self._mu_new + (self._mu_new - self._mu_old) * beta
+        else:  # restart: reboot to old iteration
+            self._alpha_new = 1
+            self._v_hat = self._v_old
+            self._mu_hat = self._mu_old
+            self._c_new = self._c_old / self._eta
+        self._alpha_old, self._c_old = self._alpha_new, self._c_new
+        self.xp.copyto(self._u_old, self._u_new)
+        self.xp.copyto(self._v_old, self._v_new)
+        self.xp.copyto(self._mu_old, self._mu_new)
+ # Test cost function for convergence.
+ if self._cost_func:
+ self.converge = self.any_convergence_flag()
+            self.converge |= self._cost_func.get_cost(self._u_new, self._v_new)
diff --git a/modopt/opt/algorithms/base.py b/modopt/opt/algorithms/base.py
index 85c36306..c5a4b101 100644
--- a/modopt/opt/algorithms/base.py
+++ b/modopt/opt/algorithms/base.py
@@ -4,7 +4,7 @@
from inspect import getmro
import numpy as np
-from progressbar import ProgressBar
+from tqdm.auto import tqdm
from modopt.base import backend
from modopt.base.observable import MetricObserver, Observable
@@ -12,17 +12,17 @@
class SetUp(Observable):
- r"""Algorithm Set-Up.
+ """Algorithm Set-Up.
This class contains methods for checking the set-up of an optimisation
- algotithm and produces warnings if they do not comply.
+ algorithm and produces warnings if they do not comply.
Parameters
----------
metric_call_period : int, optional
Metric call period (default is ``5``)
metrics : dict, optional
- Metrics to be used (default is ``\{\}``)
+ Metrics to be used (default is ``None``)
verbose : bool, optional
Option for verbose output (default is ``False``)
progress : bool, optional
@@ -34,11 +34,32 @@ class SetUp(Observable):
use_gpu : bool, optional
Option to use available GPU
+ Notes
+ -----
+ If provided, the ``metrics`` argument should be a nested dictionary of the
+ following form::
+
+ metrics = {
+ 'metric_name': {
+ 'metric': callable,
+ 'mapping': {'x_new': 'test'},
+ 'cst_kwargs': {'ref': ref_image},
+ 'early_stopping': False,
+ }
+ }
+
+ Where ``callable`` is a function with arguments being for instance
+ ``test`` and ``ref``. The mapping of the argument uses the same keys as the
+ output of ``get_notify_observer_kwargs``, ``cst_kwargs`` defines constant
+ arguments that will always be passed to the metric call.
+ If ``early_stopping`` is True, the metric will be used to check for
+ convergence of the algorithm, in that case it is recommended to have
+ ``metric_call_period = 1``
+
See Also
--------
modopt.base.observable.Observable : parent class
modopt.base.observable.MetricObserver : definition of metrics
-
"""
def __init__(
@@ -240,9 +261,8 @@ def _iterations(self, max_iter, progbar=None):
----------
max_iter : int
Maximum number of iterations
- progbar : progressbar.bar.ProgressBar
- Progress bar (default is ``None``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
for idx in range(max_iter):
self.idx = idx
@@ -268,10 +288,10 @@ def _iterations(self, max_iter, progbar=None):
print(' - Converged!')
break
- if not isinstance(progbar, type(None)):
- progbar.update(idx)
+ if progbar:
+ progbar.update()
- def _run_alg(self, max_iter):
+ def _run_alg(self, max_iter, progbar=None):
"""Run algorithm.
Run the update step of a given algorithm up to the maximum number of
@@ -281,17 +301,34 @@ def _run_alg(self, max_iter):
----------
max_iter : int
Maximum number of iterations
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
See Also
--------
- progressbar.bar.ProgressBar
+ tqdm.tqdm
"""
- if self.progress:
- with ProgressBar(
- redirect_stdout=True,
- max_value=max_iter,
- ) as progbar:
- self._iterations(max_iter, progbar=progbar)
+ if self.progress and progbar is None:
+ with tqdm(total=max_iter) as pb:
+ self._iterations(max_iter, progbar=pb)
+ elif progbar:
+ self._iterations(max_iter, progbar=progbar)
else:
self._iterations(max_iter)
+
+ def _update(self):
+ raise NotImplementedError
+
+ def get_notify_observers_kwargs(self):
+ """Notify Observers.
+
+ Return the mapping between the metrics call and the iterated
+ variables.
+
+ Raises
+ ------
+ NotImplementedError
+            This method should be overridden by subclasses.
+ """
+ raise NotImplementedError
diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py
index e18f66c3..702799c6 100644
--- a/modopt/opt/algorithms/forward_backward.py
+++ b/modopt/opt/algorithms/forward_backward.py
@@ -467,7 +467,7 @@ def _update(self):
or self._cost_func.get_cost(self._x_new)
)
- def iterate(self, max_iter=150):
+ def iterate(self, max_iter=150, progbar=None):
"""Iterate.
This method calls update until either the convergence criteria is met
@@ -477,9 +477,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
- self._run_alg(max_iter)
+ self._run_alg(max_iter, progbar)
# retrieve metrics results
self.retrieve_outputs()
@@ -750,7 +751,7 @@ def _update(self):
if self._cost_func:
self.converge = self._cost_func.get_cost(self._x_new)
- def iterate(self, max_iter=150):
+ def iterate(self, max_iter=150, progbar=None):
"""Iterate.
This method calls update until either convergence criteria is met or
@@ -760,9 +761,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
- self._run_alg(max_iter)
+ self._run_alg(max_iter, progbar)
# retrieve metrics results
self.retrieve_outputs()
@@ -815,9 +817,9 @@ class POGM(SetUp):
Initial guess for the :math:`y` variable
z : numpy.ndarray
Initial guess for the :math:`z` variable
- grad
+ grad : GradBasic
Gradient operator class
- prox
+ prox : ProximalParent
Proximity operator class
cost : class instance or str, optional
Cost function class instance (default is ``'auto'``); Use ``'auto'`` to
@@ -942,7 +944,9 @@ def _update(self):
"""
# Step 4 from alg. 3
self._grad.get_grad(self._x_old)
- self._u_new = self._x_old - self._beta * self._grad.grad
+        # Compute x_old - beta * grad via in-place ops to avoid a temporary array.
+ self._u_new = -self._beta * self._grad.grad
+ self._u_new += self._x_old
# Step 5 from alg. 3
self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2))
@@ -964,10 +968,15 @@ def _update(self):
# Restarting and gamma-Decreasing
# Step 9 from alg. 3
- self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
+        # Compute grad - (x_new - z) / xi via in-place ops to avoid temporaries.
+ self._g_new = (self._z - self._x_new)
+ self._g_new /= self._xi
+ self._g_new += self._grad.grad
# Step 10 from alg 3.
- self._y_new = self._x_old - self._beta * self._g_new
+        # Compute x_old - beta * g_new via in-place ops to avoid a temporary array.
+ self._y_new = - self._beta * self._g_new
+ self._y_new += self._x_old
# Step 11 from alg. 3
restart_crit = (
@@ -995,7 +1004,7 @@ def _update(self):
or self._cost_func.get_cost(self._x_new)
)
- def iterate(self, max_iter=150):
+ def iterate(self, max_iter=150, progbar=None):
"""Iterate.
This method calls update until either convergence criteria is met or
@@ -1005,9 +1014,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
- self._run_alg(max_iter)
+ self._run_alg(max_iter, progbar)
# retrieve metrics results
self.retrieve_outputs()
diff --git a/modopt/opt/algorithms/primal_dual.py b/modopt/opt/algorithms/primal_dual.py
index c8566969..d5bdd431 100644
--- a/modopt/opt/algorithms/primal_dual.py
+++ b/modopt/opt/algorithms/primal_dual.py
@@ -225,7 +225,7 @@ def _update(self):
or self._cost_func.get_cost(self._x_new, self._y_new)
)
- def iterate(self, max_iter=150, n_rewightings=1):
+ def iterate(self, max_iter=150, n_rewightings=1, progbar=None):
"""Iterate.
This method calls update until either convergence criteria is met or
@@ -237,14 +237,17 @@ def iterate(self, max_iter=150, n_rewightings=1):
Maximum number of iterations (default is ``150``)
n_rewightings : int, optional
Number of reweightings to perform (default is ``1``)
-
+ progbar: tqdm.tqdm
+ Progress bar handle (default is ``None``)
"""
- self._run_alg(max_iter)
+ self._run_alg(max_iter, progbar)
if not isinstance(self._reweight, type(None)):
for _ in range(n_rewightings):
self._reweight.reweight(self._linear.op(self._x_new))
- self._run_alg(max_iter)
+ if progbar:
+ progbar.reset(total=max_iter)
+ self._run_alg(max_iter, progbar)
# retrieve metrics results
self.retrieve_outputs()
diff --git a/modopt/opt/cost.py b/modopt/opt/cost.py
index 3cdfcc50..688a3959 100644
--- a/modopt/opt/cost.py
+++ b/modopt/opt/cost.py
@@ -6,6 +6,8 @@
"""
+import abc
+
import numpy as np
from modopt.base.backend import get_array_module
@@ -13,8 +15,8 @@
from modopt.plot.cost_plot import plotCost
-class costObj(object):
- """Generic cost function object.
+class CostParent(abc.ABC):
+ """Abstract cost function object.
This class updates the cost according to the input operator classes and
tests for convergence.
@@ -40,7 +42,8 @@ class costObj(object):
Notes
-----
- The costFunc class must contain a method called ``cost``.
+ All child classes should implement a ``_calc_cost`` method (returning
+ a float) or a ``get_cost`` for more complex behavior on convergence test.
Examples
--------
@@ -71,7 +74,6 @@ class costObj(object):
def __init__(
self,
- operators,
initial_cost=1e6,
tolerance=1e-4,
cost_interval=1,
@@ -80,9 +82,6 @@ def __init__(
plot_output=None,
):
- self._operators = operators
- if not isinstance(operators, type(None)):
- self._check_operators()
self.cost = initial_cost
self._cost_list = []
self._cost_interval = cost_interval
@@ -93,30 +92,6 @@ def __init__(
self._plot_output = plot_output
self._verbose = verbose
- def _check_operators(self):
- """Check operators.
-
- This method checks if the input operators have a ``cost`` method.
-
- Raises
- ------
- TypeError
- For invalid operators type
- ValueError
- For operators without ``cost`` method
-
- """
- if not isinstance(self._operators, (list, tuple, np.ndarray)):
- message = (
- 'Input operators must be provided as a list, not {0}'
- )
- raise TypeError(message.format(type(self._operators)))
-
- for op in self._operators:
- if not hasattr(op, 'cost'):
- raise ValueError('Operators must contain "cost" method.')
- op.cost = check_callable(op.cost)
-
def _check_cost(self):
"""Check cost function.
@@ -167,6 +142,7 @@ def _check_cost(self):
return False
+ @abc.abstractmethod
def _calc_cost(self, *args, **kwargs):
"""Calculate the cost.
@@ -178,14 +154,7 @@ def _calc_cost(self, *args, **kwargs):
Positional arguments
**kwargs : dict
Keyword arguments
-
- Returns
- -------
- float
- Cost value
-
"""
- return np.sum([op.cost(*args, **kwargs) for op in self._operators])
def get_cost(self, *args, **kwargs):
"""Get cost function.
@@ -241,3 +210,110 @@ def plot_cost(self): # pragma: no cover
"""
plotCost(self._cost_list, self._plot_output)
+
+
+class costObj(CostParent):
+    """Cost function object.
+
+ This class updates the cost according to the input operator classes and
+ tests for convergence.
+
+ Parameters
+ ----------
+    operators : list, tuple or numpy.ndarray
+ List of operators classes containing ``cost`` method
+ initial_cost : float, optional
+ Initial value of the cost (default is ``1e6``)
+ tolerance : float, optional
+ Tolerance threshold for convergence (default is ``1e-4``)
+ cost_interval : int, optional
+ Iteration interval to calculate cost (default is ``1``).
+ If ``cost_interval`` is ``None`` the cost is never calculated,
+ thereby saving on computation time.
+ test_range : int, optional
+ Number of cost values to be used in test (default is ``4``)
+ verbose : bool, optional
+ Option for verbose output (default is ``True``)
+ plot_output : str, optional
+ Output file name for cost function plot
+
+ Examples
+ --------
+ >>> from modopt.opt.cost import *
+ >>> class dummy(object):
+ ... def cost(self, x):
+ ... return x ** 2
+ ...
+ ...
+ >>> inst = costObj([dummy(), dummy()])
+ >>> inst.get_cost(2)
+ - ITERATION: 1
+ - COST: 8
+
+ False
+ >>> inst.get_cost(2)
+ - ITERATION: 2
+ - COST: 8
+
+ False
+ >>> inst.get_cost(2)
+ - ITERATION: 3
+ - COST: 8
+
+ False
+ """
+
+ def __init__(
+ self,
+ operators,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self._operators = operators
+ if not isinstance(operators, type(None)):
+ self._check_operators()
+
+ def _check_operators(self):
+ """Check operators.
+
+ This method checks if the input operators have a ``cost`` method.
+
+ Raises
+ ------
+ TypeError
+ For invalid operators type
+ ValueError
+ For operators without ``cost`` method
+
+ """
+ if not isinstance(self._operators, (list, tuple, np.ndarray)):
+ message = (
+ 'Input operators must be provided as a list, not {0}'
+ )
+ raise TypeError(message.format(type(self._operators)))
+
+ for op in self._operators:
+ if not hasattr(op, 'cost'):
+ raise ValueError('Operators must contain "cost" method.')
+ op.cost = check_callable(op.cost)
+
+ def _calc_cost(self, *args, **kwargs):
+ """Calculate the cost.
+
+ This method calculates the cost from each of the input operators.
+
+ Parameters
+ ----------
+ *args : tuple
+ Positional arguments
+ **kwargs : dict
+ Keyword arguments
+
+ Returns
+ -------
+ float
+ Cost value
+
+ """
+ return np.sum([op.cost(*args, **kwargs) for op in self._operators])
diff --git a/modopt/opt/linear/__init__.py b/modopt/opt/linear/__init__.py
new file mode 100644
index 00000000..d5c0d21f
--- /dev/null
+++ b/modopt/opt/linear/__init__.py
@@ -0,0 +1,21 @@
+"""LINEAR OPERATORS.
+
+This module contains linear operator classes.
+
+:Author: Samuel Farrens
+:Author: Pierre-Antoine Comby
+"""
+
+from .base import LinearParent, Identity, MatrixOperator, LinearCombo
+
+from .wavelet import WaveletConvolve, WaveletTransform
+
+
+__all__ = [
+ "LinearParent",
+ "Identity",
+ "MatrixOperator",
+ "LinearCombo",
+ "WaveletConvolve",
+ "WaveletTransform",
+]
diff --git a/modopt/opt/linear.py b/modopt/opt/linear/base.py
similarity index 84%
rename from modopt/opt/linear.py
rename to modopt/opt/linear/base.py
index d8679998..e347970d 100644
--- a/modopt/opt/linear.py
+++ b/modopt/opt/linear/base.py
@@ -1,18 +1,9 @@
-# -*- coding: utf-8 -*-
-
-"""LINEAR OPERATORS.
-
-This module contains linear operator classes.
-
-:Author: Samuel Farrens
-
-"""
+"""Base classes for linear operators."""
import numpy as np
-from modopt.base.types import check_callable, check_float
-from modopt.signal.wavelet import filter_convolve_stack
-
+from modopt.base.types import check_callable
+from modopt.base.backend import get_array_module
class LinearParent(object):
"""Linear Operator Parent Class.
@@ -78,42 +69,24 @@ def __init__(self):
self.op = lambda input_data: input_data
self.adj_op = self.op
+        self.cost = lambda *args, **kwargs: 0
-class WaveletConvolve(LinearParent):
- """Wavelet Convolution Class.
-
- This class defines the wavelet transform operators via convolution with
- predefined filters.
-
- Parameters
- ----------
- filters: numpy.ndarray
- Array of wavelet filter coefficients
- method : str, optional
- Convolution method (default is ``'scipy'``)
-
- See Also
- --------
- LinearParent : parent class
- modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution
+class MatrixOperator(LinearParent):
+ """
+ Matrix Operator class.
+ This class transforms an array into a suitable linear operator.
"""
- def __init__(self, filters, method='scipy'):
+ def __init__(self, array):
+ self.op = lambda x: array @ x
+ xp = get_array_module(array)
- self._filters = check_float(filters)
- self.op = lambda input_data: filter_convolve_stack(
- input_data,
- self._filters,
- method=method,
- )
- self.adj_op = lambda input_data: filter_convolve_stack(
- input_data,
- self._filters,
- filter_rot=True,
- method=method,
- )
+ if xp.any(xp.iscomplex(array)):
+ self.adj_op = lambda x: array.T.conjugate() @ x
+ else:
+ self.adj_op = lambda x: array.T @ x
class LinearCombo(LinearParent):
@@ -150,7 +123,6 @@ class LinearCombo(LinearParent):
See Also
--------
LinearParent : parent class
-
"""
def __init__(self, operators, weights=None):
diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py
new file mode 100644
index 00000000..6e22a2b0
--- /dev/null
+++ b/modopt/opt/linear/wavelet.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+"""Wavelet operator, using either scipy filter or pywavelet."""
+import warnings
+
+import numpy as np
+
+from modopt.base.types import check_float
+from modopt.signal.wavelet import filter_convolve_stack
+
+from .base import LinearParent
+
+pywt_available = True
+try:
+ import pywt
+ from joblib import Parallel, cpu_count, delayed
+except ImportError:
+ pywt_available = False
+
+
+class WaveletConvolve(LinearParent):
+ """Wavelet Convolution Class.
+
+ This class defines the wavelet transform operators via convolution with
+ predefined filters.
+
+ Parameters
+ ----------
+ filters: numpy.ndarray
+ Array of wavelet filter coefficients
+ method : str, optional
+ Convolution method (default is ``'scipy'``)
+
+ See Also
+ --------
+ LinearParent : parent class
+ modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution
+
+ """
+
+ def __init__(self, filters, method='scipy'):
+
+ self._filters = check_float(filters)
+ self.op = lambda input_data: filter_convolve_stack(
+ input_data,
+ self._filters,
+ method=method,
+ )
+ self.adj_op = lambda input_data: filter_convolve_stack(
+ input_data,
+ self._filters,
+ filter_rot=True,
+ method=method,
+ )
+
+
+
+class WaveletTransform(LinearParent):
+ """
+ 2D and 3D wavelet transform class.
+
+ This is a light wrapper around PyWavelet, with multicoil support.
+
+ Parameters
+ ----------
+ wavelet_name: str
+ the wavelet name to be used during the decomposition.
+ shape: tuple[int,...]
+ Shape of the input data. The shape should be a tuple of length 2 or 3.
+ It should not contains coils or batch dimension.
+    level: int, default 4
+        the number of decomposition levels (scales).
+    n_batch: int, default 1
+        the number of channel/batch dimensions.
+ n_jobs: int, default 1
+ the number of cores to use for multichannel.
+ backend: str, default "threading"
+ the backend to use for parallel multichannel linear operation.
+ verbose: int, default 0
+ the verbosity level.
+
+ Attributes
+ ----------
+    level: int
+        number of decomposition levels in wavelet space.
+    n_jobs: int
+        number of jobs for parallel computation
+    n_batch: int
+        number of batches (e.g. coils) processed in parallel.
+ backend: str
+ Backend use for parallel computation
+ verbose: int
+ Verbosity level
+ """
+
+ def __init__(
+ self,
+ wavelet_name,
+ shape,
+ level=4,
+ n_batch=1,
+ n_jobs=1,
+ decimated=True,
+ backend="threading",
+ mode="symmetric",
+ ):
+ if not pywt_available:
+ raise ImportError(
+ "PyWavelet and/or joblib are not available. Please install it to use WaveletTransform."
+ )
+ if wavelet_name not in pywt.wavelist(kind="all"):
+ raise ValueError(
+                "Invalid wavelet name. Available options: ``pywt.wavelist(kind='all')``"
+ )
+
+ self.wavelet = wavelet_name
+ if isinstance(shape, int):
+ shape = (shape,)
+ self.shape = shape
+ self.n_jobs = n_jobs
+ self.mode = mode
+ self.level = level
+ if not decimated:
+ raise NotImplementedError(
+ "Undecimated Wavelet Transform is not implemented yet."
+ )
+ ca, *cds = pywt.wavedecn_shapes(
+ self.shape, wavelet=self.wavelet, mode=self.mode, level=self.level
+ )
+ self.coeffs_shape = [ca] + [s for cd in cds for s in cd.values()]
+
+ if len(shape) > 1:
+ self.dwt = pywt.wavedecn
+ self.idwt = pywt.waverecn
+ self._pywt_fun = "wavedecn"
+ else:
+ self.dwt = pywt.wavedec
+ self.idwt = pywt.waverec
+ self._pywt_fun = "wavedec"
+
+ self.n_batch = n_batch
+ if self.n_batch == 1 and self.n_jobs != 1:
+ warnings.warn("Making n_jobs = 1 for WaveletTransform as n_batchs = 1")
+ self.n_jobs = 1
+ self.backend = backend
+        self.verbose = 0  # joblib verbosity; referenced by op/adj_op
+        if self.n_jobs < 0:
+            self.n_jobs = cpu_count() + self.n_jobs + 1
+
+ def op(self, data):
+ """Define the wavelet operator.
+
+ This method returns the input data convolved with the wavelet filter.
+
+ Parameters
+ ----------
+ data: ndarray or Image
+ input 2D data array.
+
+ Returns
+ -------
+ coeffs: ndarray
+ the wavelet coefficients.
+ """
+ if self.n_batch > 1:
+ coeffs, self.coeffs_slices, self.raw_coeffs_shape = zip(
+ *Parallel(
+ n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose
+ )(delayed(self._op)(data[i]) for i in np.arange(self.n_batch))
+ )
+ coeffs = np.asarray(coeffs)
+ else:
+ coeffs, self.coeffs_slices, self.raw_coeffs_shape = self._op(data)
+ return coeffs
+
+ def _op(self, data):
+ """Single coil wavelet transform."""
+ return pywt.ravel_coeffs(
+ self.dwt(data, mode=self.mode, level=self.level, wavelet=self.wavelet)
+ )
+
+ def adj_op(self, coeffs):
+ """Define the wavelet adjoint operator.
+
+ This method returns the reconstructed image.
+
+ Parameters
+ ----------
+ coeffs: ndarray
+ the wavelet coefficients.
+
+ Returns
+ -------
+ data: ndarray
+ the reconstructed data.
+ """
+ if self.n_batch > 1:
+ images = Parallel(
+ n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose
+ )(
+ delayed(self._adj_op)(coeffs[i], self.coeffs_shape[i])
+ for i in np.arange(self.n_batch)
+ )
+ images = np.asarray(images)
+ else:
+ images = self._adj_op(coeffs)
+ return images
+
+    def _adj_op(self, coeffs, coeffs_shape=None):
+ """Single coil inverse wavelet transform."""
+ return self.idwt(
+ pywt.unravel_coeffs(
+ coeffs, self.coeffs_slices, self.raw_coeffs_shape, self._pywt_fun
+ ),
+ wavelet=self.wavelet,
+ mode=self.mode,
+ )
diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py
index f8f368ef..e8492367 100644
--- a/modopt/opt/proximity.py
+++ b/modopt/opt/proximity.py
@@ -993,7 +993,7 @@ def _interpolate(self, alpha0, alpha1, sum0, sum1):
:math:`\sum\theta(\alpha^*)=k` via a linear interpolation.
Parameters
- -----------
+ ----------
alpha0: float
A value for wich :math:`\sum\theta(\alpha^0) \leq k`
alpha1: float
diff --git a/modopt/signal/filter.py b/modopt/signal/filter.py
index 8e24768c..84dd8160 100644
--- a/modopt/signal/filter.py
+++ b/modopt/signal/filter.py
@@ -73,8 +73,8 @@ def mex_hat(data_point, sigma):
Examples
--------
>>> from modopt.signal.filter import mex_hat
- >>> mex_hat(2, 1)
- -0.3521390522571337
+ >>> round(mex_hat(2, 1), 15)
+ -0.352139052257134
"""
data_point = check_float(data_point)
@@ -108,8 +108,8 @@ def mex_hat_dir(data_gauss, data_mex, sigma):
Examples
--------
>>> from modopt.signal.filter import mex_hat_dir
- >>> mex_hat_dir(1, 2, 1)
- 0.17606952612856686
+ >>> round(mex_hat_dir(1, 2, 1), 16)
+ 0.1760695261285668
"""
data_gauss = check_float(data_gauss)
diff --git a/modopt/signal/positivity.py b/modopt/signal/positivity.py
index e4ec098d..c19ba62c 100644
--- a/modopt/signal/positivity.py
+++ b/modopt/signal/positivity.py
@@ -48,7 +48,7 @@ def pos_recursive(input_data):
"""
if input_data.dtype == 'O':
- res = np.array([pos_recursive(elem) for elem in input_data])
+ res = np.array([pos_recursive(elem) for elem in input_data], dtype="object")
else:
res = pos_thresh(input_data)
diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py
index 6dcb9eda..f3d40a51 100644
--- a/modopt/signal/svd.py
+++ b/modopt/signal/svd.py
@@ -57,7 +57,7 @@ def find_n_pc(u_vec, factor=0.5):
)
# Get the shape of the array
- array_shape = np.repeat(np.int(np.sqrt(u_vec.shape[0])), 2)
+ array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)
# Find the auto correlation of the left singular vector.
u_auto = [
@@ -299,7 +299,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
a_matrix = np.dot(s_values, v_vec)
# Get the shape of the array
- array_shape = np.repeat(np.int(np.sqrt(u_vec.shape[0])), 2)
+ array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)
# Compute threshold matrix.
ti = np.array([
diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py
index 7ff96a8b..5671b8e3 100644
--- a/modopt/tests/test_algorithms.py
+++ b/modopt/tests/test_algorithms.py
@@ -1,470 +1,279 @@
# -*- coding: utf-8 -*-
-"""UNIT TESTS FOR OPT.ALGORITHMS.
+"""UNIT TESTS FOR Algorithms.
-This module contains unit tests for the modopt.opt.algorithms module.
-
-:Author: Samuel Farrens
+This module contains unit tests for the modopt.opt.algorithms module.
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
"""
-from unittest import TestCase
-
import numpy as np
import numpy.testing as npt
-
+import pytest
from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
-
-# Basic functions to be used as operators or as dummy functions
-func_identity = lambda x_val: x_val
-func_double = lambda x_val: x_val * 2
-func_sq = lambda x_val: x_val ** 2
-func_cube = lambda x_val: x_val ** 3
-
-
-class Dummy(object):
- """Dummy class for tests."""
-
- pass
-
-
-class AlgorithmTestCase(TestCase):
- """Test case for algorithms module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6
- self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1
-
- grad_inst = gradient.GradBasic(
- self.data1,
- func_identity,
- func_identity,
- )
-
+from pytest_cases import (
+ case,
+ fixture,
+ fixture_ref,
+ lazy_value,
+ parametrize,
+ parametrize_with_cases,
+)
+
+from test_helpers import Dummy
+
+SKLEARN_AVAILABLE = True
+try:
+ import sklearn
+except ImportError:
+ SKLEARN_AVAILABLE = False
+
+
+@fixture
+def idty():
+ """Identity function."""
+ return lambda x: x
+
+
+@fixture
+def reweight_op():
+ """Reweight operator."""
+ data3 = np.arange(9).reshape(3, 3).astype(float) + 1
+ return reweight.cwbReweight(data3)
+
+
+def build_kwargs(kwargs, use_metrics):
+ """Build the kwargs for each algorithm, replacing placeholders by true values.
+
+    This function has to be called for each test, as direct parametrization
+    somehow does not work with pytest-xdist and pytest-cases.
+    It also adds a dummy metric measurement to validate the metric API.
+ """
+ update_value = {
+ "idty": lambda x: x,
+ "lin_idty": linear.Identity(),
+ "reweight_op": reweight.cwbReweight(
+ np.arange(9).reshape(3, 3).astype(float) + 1
+ ),
+ }
+ new_kwargs = dict()
+ print(kwargs)
+    # Update the values of the dict where possible.
+ for key in kwargs:
+ new_kwargs[key] = update_value.get(kwargs[key], kwargs[key])
+
+ if use_metrics:
+ new_kwargs["linear"] = linear.Identity()
+ new_kwargs["metrics"] = {
+ "diff": {
+ "metric": lambda test, ref: np.sum(test - ref),
+ "mapping": {"x_new": "test"},
+ "cst_kwargs": {"ref": np.arange(9).reshape((3, 3))},
+ "early_stopping": False,
+ }
+ }
+
+ return new_kwargs
+
+
+@parametrize(use_metrics=[True, False])
+class AlgoCases:
+ """Cases for algorithms.
+
+    Most of the tests solve the trivial problem
+
+ .. math::
+ \\min_x \\frac{1}{2} \\| y - x \\|_2^2 \\quad\\text{s.t.} x \\geq 0
+
+    More complex and concrete use cases are shown in the examples.
+ """
+
+ data1 = np.arange(9).reshape(3, 3).astype(float)
+ data2 = data1 + np.random.randn(*data1.shape) * 1e-6
+ max_iter = 20
+
+ @parametrize(
+ kwargs=[
+ {"beta_update": "idty", "auto_iterate": False, "cost": None},
+ {"beta_update": "idty"},
+ {"cost": None, "lambda_update": None},
+ {"beta_update": "idty", "a_cd": 3},
+ {"beta_update": "idty", "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7},
+ {"restart_strategy": "adaptive", "xi_restart": 0.9},
+ {
+ "restart_strategy": "greedy",
+ "xi_restart": 0.9,
+ "min_beta": 1.0,
+ "s_greedy": 1.1,
+ },
+ ]
+ )
+ def case_forward_backward(self, kwargs, idty, use_metrics):
+ """Forward Backward case.
+ """
+ update_kwargs = build_kwargs(kwargs, use_metrics)
+ algo = algorithms.ForwardBackward(
+ self.data1,
+ grad=gradient.GradBasic(self.data1, idty, idty),
+ prox=proximity.Positivity(),
+ **update_kwargs,
+ )
+ if update_kwargs.get("auto_iterate", None) is False:
+ algo.iterate(self.max_iter)
+ return algo, update_kwargs
+
+ @parametrize(
+ kwargs=[
+ {
+ "cost": None,
+ "auto_iterate": False,
+ "gamma_update": "idty",
+ "beta_update": "idty",
+ },
+ {"gamma_update": "idty", "lambda_update": "idty"},
+ {"cost": True},
+ {"cost": True, "step_size": 2},
+ ]
+ )
+ def case_gen_forward_backward(self, kwargs, use_metrics, idty):
+ """General FB setup."""
+ update_kwargs = build_kwargs(kwargs, use_metrics)
+ grad_inst = gradient.GradBasic(self.data1, idty, idty)
prox_inst = proximity.Positivity()
prox_dual_inst = proximity.IdentityProx()
- linear_inst = linear.Identity()
- reweight_inst = reweight.cwbReweight(self.data3)
- cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
- self.setup = algorithms.SetUp()
- self.max_iter = 20
-
- self.fb_all_iter = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=None,
- auto_iterate=False,
- beta_update=func_identity,
- )
- self.fb_all_iter.iterate(self.max_iter)
-
- self.fb1 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- )
-
- self.fb2 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- lambda_update=None,
- )
-
- self.fb3 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- a_cd=3,
- )
-
- self.fb4 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- r_lazy=3,
- p_lazy=0.7,
- q_lazy=0.7,
- )
-
- self.fb5 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- restart_strategy='adaptive',
- xi_restart=0.9,
- )
-
- self.fb6 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- restart_strategy='greedy',
- xi_restart=0.9,
- min_beta=1.0,
- s_greedy=1.1,
- )
-
- self.gfb_all_iter = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=None,
- auto_iterate=False,
- gamma_update=func_identity,
- beta_update=func_identity,
- )
- self.gfb_all_iter.iterate(self.max_iter)
-
- self.gfb1 = algorithms.GenForwardBackward(
+ if update_kwargs.get("cost", None) is True:
+ update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
+ algo = algorithms.GenForwardBackward(
self.data1,
grad=grad_inst,
prox_list=[prox_inst, prox_dual_inst],
- gamma_update=func_identity,
- lambda_update=func_identity,
- )
-
- self.gfb2 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=cost_inst,
- )
-
- self.gfb3 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=cost_inst,
- step_size=2,
- )
-
- self.condat_all_iter = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- cost=None,
- prox_dual=prox_dual_inst,
- sigma_update=func_identity,
- tau_update=func_identity,
- rho_update=func_identity,
- auto_iterate=False,
- )
- self.condat_all_iter.iterate(self.max_iter)
-
- self.condat1 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- sigma_update=func_identity,
- tau_update=func_identity,
- rho_update=func_identity,
- )
-
- self.condat2 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- linear=linear_inst,
- cost=cost_inst,
- reweight=reweight_inst,
- )
+ **update_kwargs,
+ )
+ if update_kwargs.get("auto_iterate", None) is False:
+ algo.iterate(self.max_iter)
+ return algo, update_kwargs
+
+ @parametrize(
+ kwargs=[
+ {
+ "sigma_dual": "idty",
+ "tau_update": "idty",
+ "rho_update": "idty",
+ "auto_iterate": False,
+ },
+ {
+ "sigma_dual": "idty",
+ "tau_update": "idty",
+ "rho_update": "idty",
+ },
+ {
+ "linear": "lin_idty",
+ "cost": True,
+ "reweight": "reweight_op",
+ },
+ ]
+ )
+ def case_condat(self, kwargs, use_metrics, idty):
+ """Condat Vu Algorithm setup."""
+ update_kwargs = build_kwargs(kwargs, use_metrics)
+ grad_inst = gradient.GradBasic(self.data1, idty, idty)
+ prox_inst = proximity.Positivity()
+ prox_dual_inst = proximity.IdentityProx()
+ if update_kwargs.get("cost", None) is True:
+ update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
- self.condat3 = algorithms.Condat(
+ algo = algorithms.Condat(
self.data1,
self.data2,
grad=grad_inst,
prox=prox_inst,
prox_dual=prox_dual_inst,
- linear=Dummy(),
- cost=cost_inst,
- auto_iterate=False,
+ **update_kwargs,
)
+ if update_kwargs.get("auto_iterate", None) is False:
+ algo.iterate(self.max_iter)
+ return algo, update_kwargs
- self.pogm_all_iter = algorithms.POGM(
+ @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}])
+ def case_pogm(self, kwargs, use_metrics, idty):
+ """POGM setup."""
+ update_kwargs = build_kwargs(kwargs, use_metrics)
+ grad_inst = gradient.GradBasic(self.data1, idty, idty)
+ prox_inst = proximity.Positivity()
+ algo = algorithms.POGM(
u=self.data1,
x=self.data1,
y=self.data1,
z=self.data1,
grad=grad_inst,
prox=prox_inst,
- auto_iterate=False,
- cost=None,
+ **update_kwargs,
)
- self.pogm_all_iter.iterate(self.max_iter)
- self.pogm1 = algorithms.POGM(
- u=self.data1,
- x=self.data1,
- y=self.data1,
- z=self.data1,
- grad=grad_inst,
- prox=prox_inst,
- )
+ if update_kwargs.get("auto_iterate", None) is False:
+ algo.iterate(self.max_iter)
+ return algo, update_kwargs
- self.vanilla_grad = algorithms.VanillaGenericGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.ada_grad = algorithms.AdaGenericGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.adam_grad = algorithms.ADAMGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.momentum_grad = algorithms.MomentumGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.rms_grad = algorithms.RMSpropGradOpt(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- )
- self.saga_grad = algorithms.SAGAOptGradOpt(
+ @parametrize(
+ GradDescent=[
+ algorithms.VanillaGenericGradOpt,
+ algorithms.AdaGenericGradOpt,
+ algorithms.ADAMGradOpt,
+ algorithms.MomentumGradOpt,
+ algorithms.RMSpropGradOpt,
+ algorithms.SAGAOptGradOpt,
+ ]
+ )
+ def case_grad(self, GradDescent, use_metrics, idty):
+ """Gradient Descent algorithm test."""
+ update_kwargs = build_kwargs({}, use_metrics)
+ grad_inst = gradient.GradBasic(self.data1, idty, idty)
+ prox_inst = proximity.Positivity()
+ cost_inst = cost.costObj([grad_inst, prox_inst])
+
+ algo = GradDescent(
self.data1,
grad=grad_inst,
prox=prox_inst,
cost=cost_inst,
+ **update_kwargs,
)
+ algo.iterate()
+ return algo, update_kwargs
+ @parametrize(admm=[algorithms.ADMM,algorithms.FastADMM])
+ def case_admm(self, admm, use_metrics, idty):
+ """ADMM setup."""
+ def optim1(init, obs):
+ return obs
- self.dummy = Dummy()
- self.dummy.cost = func_identity
- self.setup._check_operator(self.dummy.cost)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.setup = None
- self.fb_all_iter = None
- self.fb1 = None
- self.fb2 = None
- self.gfb_all_iter = None
- self.gfb1 = None
- self.gfb2 = None
- self.condat_all_iter = None
- self.condat1 = None
- self.condat2 = None
- self.condat3 = None
- self.pogm1 = None
- self.pogm_all_iter = None
- self.dummy = None
-
- def test_set_up(self):
- """Test set_up."""
- npt.assert_raises(TypeError, self.setup._check_input_data, 1)
-
- npt.assert_raises(TypeError, self.setup._check_param, 1)
-
- npt.assert_raises(TypeError, self.setup._check_param_update, 1)
-
- def test_all_iter(self):
- """Test if all opt run for all iterations."""
- opts = [
- self.fb_all_iter,
- self.gfb_all_iter,
- self.condat_all_iter,
- self.pogm_all_iter,
- ]
- for opt in opts:
- npt.assert_equal(opt.idx, self.max_iter - 1)
-
- def test_forward_backward(self):
- """Test forward_backward."""
- npt.assert_array_equal(
- self.fb1.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb2.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb3.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb4.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb5.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb6.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
+ def optim2(init, obs):
+ return obs
- def test_gen_forward_backward(self):
- """Test gen_forward_backward."""
- npt.assert_array_equal(
- self.gfb1.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.gfb2.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.gfb3.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_equal(
- self.gfb3.step_size,
- 2,
- err_msg='Incorrect step size.',
- )
-
- npt.assert_raises(
- TypeError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=1,
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[1],
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[0.5, 0.5],
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[0.5],
- )
-
- def test_condat(self):
- """Test gen_condat."""
- npt.assert_almost_equal(
- self.condat1.x_final,
- self.data1,
- err_msg='Incorrect Condat result.',
- )
-
- npt.assert_almost_equal(
- self.condat2.x_final,
- self.data1,
- err_msg='Incorrect Condat result.',
- )
-
- def test_pogm(self):
- """Test pogm."""
- npt.assert_almost_equal(
- self.pogm1.x_final,
- self.data1,
- err_msg='Incorrect POGM result.',
- )
-
- def test_ada_grad(self):
- """Test ADA Gradient Descent."""
- self.ada_grad.iterate()
- npt.assert_almost_equal(
- self.ada_grad.x_final,
- self.data1,
- err_msg='Incorrect ADAGrad results.',
- )
-
- def test_adam_grad(self):
- """Test ADAM Gradient Descent."""
- self.adam_grad.iterate()
- npt.assert_almost_equal(
- self.adam_grad.x_final,
- self.data1,
- err_msg='Incorrect ADAMGrad results.',
- )
-
- def test_momemtum_grad(self):
- """Test Momemtum Gradient Descent."""
- self.momentum_grad.iterate()
- npt.assert_almost_equal(
- self.momentum_grad.x_final,
- self.data1,
- err_msg='Incorrect MomentumGrad results.',
- )
-
- def test_rmsprop_grad(self):
- """Test RMSProp Gradient Descent."""
- self.rms_grad.iterate()
- npt.assert_almost_equal(
- self.rms_grad.x_final,
- self.data1,
- err_msg='Incorrect RMSPropGrad results.',
- )
-
- def test_saga_grad(self):
- """Test SAGA Descent."""
- self.saga_grad.iterate()
- npt.assert_almost_equal(
- self.saga_grad.x_final,
- self.data1,
- err_msg='Incorrect SAGA Grad results.',
- )
-
- def test_vanilla_grad(self):
- """Test Vanilla Gradient Descent."""
- self.vanilla_grad.iterate()
- npt.assert_almost_equal(
- self.vanilla_grad.x_final,
- self.data1,
- err_msg='Incorrect VanillaGrad results.',
- )
+ update_kwargs = build_kwargs({}, use_metrics)
+ algo = admm(
+ u=self.data1,
+ v=self.data1,
+ mu=np.zeros_like(self.data1),
+ A=linear.Identity(),
+ B=linear.Identity(),
+ b=self.data1,
+ optimizers=(optim1, optim2),
+ **update_kwargs,
+ )
+ algo.iterate()
+ return algo, update_kwargs
+
+@parametrize_with_cases("algo, kwargs", cases=AlgoCases)
+def test_algo(algo, kwargs):
+ """Test algorithms."""
+ if kwargs.get("auto_iterate") is False:
+ # algo already run
+ npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1)
+ else:
+ npt.assert_almost_equal(algo.x_final, AlgoCases.data1)
+
+ if kwargs.get("metrics"):
+ print(algo.metrics)
+ npt.assert_almost_equal(algo.metrics["diff"]["values"][-1], 0, 3)
diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py
index 873a4506..e32ff94b 100644
--- a/modopt/tests/test_base.py
+++ b/modopt/tests/test_base.py
@@ -1,192 +1,139 @@
-# -*- coding: utf-8 -*-
-
-"""UNIT TESTS FOR BASE.
-
-This module contains unit tests for the modopt.base module.
-
-:Author: Samuel Farrens
-
"""
+Test for base module.
-from builtins import range
-from unittest import TestCase, skipIf
-
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
+"""
import numpy as np
import numpy.testing as npt
+import pytest
+from test_helpers import failparam, skipparam
-from modopt.base import np_adjust, transform, types
-from modopt.base.backend import (LIBRARIES, change_backend, get_array_module,
- get_backend)
+from modopt.base import backend, np_adjust, transform, types
+from modopt.base.backend import LIBRARIES
-class NPAdjustTestCase(TestCase):
- """Test case for np_adjust module."""
+class TestNpAdjust:
+ """Test for npadjust."""
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape((3, 3))
- self.data2 = np.arange(18).reshape((2, 3, 3))
- self.data3 = np.array([
+ array33 = np.arange(9).reshape((3, 3))
+ array233 = np.arange(18).reshape((2, 3, 3))
+ arraypad = np.array(
+ [
[0, 0, 0, 0, 0],
[0, 0, 1, 2, 0],
[0, 3, 4, 5, 0],
[0, 6, 7, 8, 0],
[0, 0, 0, 0, 0],
- ])
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
+ ]
+ )
def test_rotate(self):
"""Test rotate."""
npt.assert_array_equal(
- np_adjust.rotate(self.data1),
- np.array([[8, 7, 6], [5, 4, 3], [2, 1, 0]]),
- err_msg='Incorrect rotation',
+ np_adjust.rotate(self.array33),
+ np.rot90(np.rot90(self.array33)),
+ err_msg="Incorrect rotation.",
)
def test_rotate_stack(self):
"""Test rotate_stack."""
npt.assert_array_equal(
- np_adjust.rotate_stack(self.data2),
- np.array([
- [[8, 7, 6], [5, 4, 3], [2, 1, 0]],
- [[17, 16, 15], [14, 13, 12], [11, 10, 9]],
- ]),
- err_msg='Incorrect stack rotation',
+ np_adjust.rotate_stack(self.array233),
+ np.rot90(self.array233, k=2, axes=(1, 2)),
+ err_msg="Incorrect stack rotation.",
)
- def test_pad2d(self):
+ @pytest.mark.parametrize(
+ "padding",
+ [
+ 1,
+ [1, 1],
+ np.array([1, 1]),
+ failparam("1", raises=ValueError),
+ ],
+ )
+ def test_pad2d(self, padding):
"""Test pad2d."""
- npt.assert_array_equal(
- np_adjust.pad2d(self.data1, (1, 1)),
- self.data3,
- err_msg='Incorrect padding',
- )
-
- npt.assert_array_equal(
- np_adjust.pad2d(self.data1, 1),
- self.data3,
- err_msg='Incorrect padding',
- )
-
- npt.assert_array_equal(
- np_adjust.pad2d(self.data1, np.array([1, 1])),
- self.data3,
- err_msg='Incorrect padding',
- )
-
- npt.assert_raises(ValueError, np_adjust.pad2d, self.data1, '1')
+ npt.assert_equal(np_adjust.pad2d(self.array33, padding), self.arraypad)
def test_fancy_transpose(self):
- """Test fancy_transpose."""
+ """Test fancy transpose."""
npt.assert_array_equal(
- np_adjust.fancy_transpose(self.data2),
- np.array([
- [[0, 3, 6], [9, 12, 15]],
- [[1, 4, 7], [10, 13, 16]],
- [[2, 5, 8], [11, 14, 17]],
- ]),
- err_msg='Incorrect fancy transpose',
+ np_adjust.fancy_transpose(self.array233),
+ np.array(
+ [
+ [[0, 3, 6], [9, 12, 15]],
+ [[1, 4, 7], [10, 13, 16]],
+ [[2, 5, 8], [11, 14, 17]],
+ ]
+ ),
+ err_msg="Incorrect fancy transpose",
)
def test_ftr(self):
"""Test ftr."""
npt.assert_array_equal(
- np_adjust.ftr(self.data2),
- np.array([
- [[0, 3, 6], [9, 12, 15]],
- [[1, 4, 7], [10, 13, 16]],
- [[2, 5, 8], [11, 14, 17]],
- ]),
- err_msg='Incorrect fancy transpose: ftr',
+ np_adjust.ftr(self.array233),
+ np.array(
+ [
+ [[0, 3, 6], [9, 12, 15]],
+ [[1, 4, 7], [10, 13, 16]],
+ [[2, 5, 8], [11, 14, 17]],
+ ]
+ ),
+ err_msg="Incorrect fancy transpose: ftr",
)
def test_ftl(self):
- """Test ftl."""
- npt.assert_array_equal(
- np_adjust.ftl(self.data2),
- np.array([
- [[0, 9], [1, 10], [2, 11]],
- [[3, 12], [4, 13], [5, 14]],
- [[6, 15], [7, 16], [8, 17]],
- ]),
- err_msg='Incorrect fancy transpose: ftl',
- )
-
-
-class TransformTestCase(TestCase):
- """Test case for transform module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.cube = np.arange(16).reshape((4, 2, 2))
- self.map = np.array(
- [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]],
- )
- self.matrix = np.array(
- [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]],
- )
- self.layout = (2, 2)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.cube = None
- self.map = None
- self.layout = None
-
- def test_cube2map(self):
+ """Test fancy transpose left."""
+ npt.assert_array_equal(
+ np_adjust.ftl(self.array233),
+ np.array(
+ [
+ [[0, 9], [1, 10], [2, 11]],
+ [[3, 12], [4, 13], [5, 14]],
+ [[6, 15], [7, 16], [8, 17]],
+ ]
+ ),
+ err_msg="Incorrect fancy transpose: ftl",
+ )
+
+
+class TestTransforms:
+ """Test for the transform module."""
+
+ cube = np.arange(16).reshape((4, 2, 2))
+ map = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]])
+ matrix = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]])
+ layout = (2, 2)
+ fail_layout = (3, 3)
+
+ @pytest.mark.parametrize(
+ ("func", "indata", "layout", "outdata"),
+ [
+ (transform.cube2map, cube, layout, map),
+ failparam(transform.cube2map, np.eye(2), layout, map, raises=ValueError),
+ (transform.map2cube, map, layout, cube),
+ (transform.map2matrix, map, layout, matrix),
+ (transform.matrix2map, matrix, matrix.shape, map),
+ ],
+ )
+ def test_map(self, func, indata, layout, outdata):
"""Test cube2map."""
npt.assert_array_equal(
- transform.cube2map(self.cube, self.layout),
- self.map,
- err_msg='Incorrect transformation: cube2map',
- )
-
- npt.assert_raises(
- ValueError,
- transform.cube2map,
- self.map,
- self.layout,
- )
-
- npt.assert_raises(ValueError, transform.cube2map, self.cube, (3, 3))
-
- def test_map2cube(self):
- """Test map2cube."""
- npt.assert_array_equal(
- transform.map2cube(self.map, self.layout),
- self.cube,
- err_msg='Incorrect transformation: map2cube',
- )
-
- npt.assert_raises(ValueError, transform.map2cube, self.map, (3, 3))
-
- def test_map2matrix(self):
- """Test map2matrix."""
- npt.assert_array_equal(
- transform.map2matrix(self.map, self.layout),
- self.matrix,
- err_msg='Incorrect transformation: map2matrix',
- )
-
- def test_matrix2map(self):
- """Test matrix2map."""
- npt.assert_array_equal(
- transform.matrix2map(self.matrix, self.map.shape),
- self.map,
- err_msg='Incorrect transformation: matrix2map',
+ func(indata, layout),
+ outdata,
)
+ if func.__name__ != "map2matrix":
+ npt.assert_raises(ValueError, func, indata, self.fail_layout)
def test_cube2matrix(self):
"""Test cube2matrix."""
npt.assert_array_equal(
transform.cube2matrix(self.cube),
self.matrix,
- err_msg='Incorrect transformation: cube2matrix',
)
def test_matrix2cube(self):
@@ -194,136 +141,78 @@ def test_matrix2cube(self):
npt.assert_array_equal(
transform.matrix2cube(self.matrix, self.cube[0].shape),
self.cube,
- err_msg='Incorrect transformation: matrix2cube',
- )
-
-
-class TypesTestCase(TestCase):
- """Test case for types module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = list(range(5))
- self.data2 = np.arange(5)
- self.data3 = np.arange(5).astype(float)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
-
- def test_check_float(self):
- """Test check_float."""
- npt.assert_array_equal(
- types.check_float(1.0),
- 1.0,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_float(1),
- 1.0,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_float(self.data1),
- self.data3,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_float(self.data2),
- self.data3,
- err_msg='Float check failed',
- )
-
- npt.assert_raises(TypeError, types.check_float, '1')
-
- def test_check_int(self):
- """Test check_int."""
- npt.assert_array_equal(
- types.check_int(1),
- 1,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_int(1.0),
- 1,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_int(self.data1),
- self.data2,
- err_msg='Float check failed',
- )
-
- npt.assert_array_equal(
- types.check_int(self.data3),
- self.data2,
- err_msg='Int check failed',
- )
-
- npt.assert_raises(TypeError, types.check_int, '1')
-
- def test_check_npndarray(self):
+ err_msg="Incorrect transformation: matrix2cube",
+ )
+
+
+class TestType:
+ """Test for type module."""
+
+ data_list = list(range(5))
+ data_int = np.arange(5)
+ data_flt = np.arange(5).astype(float)
+
+ @pytest.mark.parametrize(
+ ("data", "checked"),
+ [
+ (1.0, 1.0),
+ (1, 1.0),
+ (data_list, data_flt),
+ (data_int, data_flt),
+ failparam("1.0", 1.0, raises=TypeError),
+ ],
+ )
+ def test_check_float(self, data, checked):
+ """Test check float."""
+ npt.assert_array_equal(types.check_float(data), checked)
+
+ @pytest.mark.parametrize(
+ ("data", "checked"),
+ [
+ (1.0, 1),
+ (1, 1),
+ (data_list, data_int),
+ (data_flt, data_int),
+ failparam("1", None, raises=TypeError),
+ ],
+ )
+ def test_check_int(self, data, checked):
+ """Test check int."""
+ npt.assert_array_equal(types.check_int(data), checked)
+
+ @pytest.mark.parametrize(
+ ("data", "dtype"), [(data_flt, np.integer), (data_int, np.floating)]
+ )
+ def test_check_npndarray(self, data, dtype):
"""Test check_npndarray."""
npt.assert_raises(
TypeError,
types.check_npndarray,
- self.data3,
- dtype=np.integer,
- )
-
-
-class TestBackend(TestCase):
- """Test the backend codes."""
-
- def setUp(self):
- """Set test parameter values."""
- self.input = np.array([10, 10])
-
- @skipIf(LIBRARIES['tensorflow'] is None, 'tensorflow library not installed')
- def test_tf_backend(self):
- """Test tensorflow backend."""
- xp, backend = get_backend('tensorflow')
- if backend != 'tensorflow' or xp != LIBRARIES['tensorflow']:
- raise AssertionError('tensorflow get_backend fails!')
- tf_input = change_backend(self.input, 'tensorflow')
- if (
- get_array_module(LIBRARIES['tensorflow'].ones(1)) != LIBRARIES['tensorflow']
- or get_array_module(tf_input) != LIBRARIES['tensorflow']
- ):
- raise AssertionError('tensorflow backend fails!')
-
- @skipIf(LIBRARIES['cupy'] is None, 'cupy library not installed')
- def test_cp_backend(self):
- """Test cupy backend."""
- xp, backend = get_backend('cupy')
- if backend != 'cupy' or xp != LIBRARIES['cupy']:
- raise AssertionError('cupy get_backend fails!')
- cp_input = change_backend(self.input, 'cupy')
- if (
- get_array_module(LIBRARIES['cupy'].ones(1)) != LIBRARIES['cupy']
- or get_array_module(cp_input) != LIBRARIES['cupy']
- ):
- raise AssertionError('cupy backend fails!')
-
- def test_np_backend(self):
- """Test numpy backend."""
- xp, backend = get_backend('numpy')
- if backend != 'numpy' or xp != LIBRARIES['numpy']:
- raise AssertionError('numpy get_backend fails!')
- np_input = change_backend(self.input, 'numpy')
- if (
- get_array_module(LIBRARIES['numpy'].ones(1)) != LIBRARIES['numpy']
- or get_array_module(np_input) != LIBRARIES['numpy']
- ):
- raise AssertionError('numpy backend fails!')
-
- def tearDown(self):
- """Tear Down of objects."""
- self.input = None
+ data,
+ dtype=dtype,
+ )
+
+ def test_check_callable(self):
+ """Test callable."""
+ npt.assert_raises(TypeError, types.check_callable, 1)
+
+
+@pytest.mark.parametrize(
+ "backend_name",
+ [
+ skipparam(name, cond=LIBRARIES[name] is None, reason=f"{name} not installed")
+ for name in LIBRARIES
+ ],
+)
+def test_tf_backend(backend_name):
+ """Test Modopt computational backends."""
+ xp, checked_backend_name = backend.get_backend(backend_name)
+ if checked_backend_name != backend_name or xp != LIBRARIES[backend_name]:
+ raise AssertionError(f"{backend_name} get_backend fails!")
+ xp_input = backend.change_backend(np.array([10, 10]), backend_name)
+ if (
+ backend.get_array_module(LIBRARIES[backend_name].ones(1))
+ != backend.LIBRARIES[backend_name]
+ or backend.get_array_module(xp_input) != LIBRARIES[backend_name]
+ ):
+ raise AssertionError(f"{backend_name} backend fails!")
diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py
new file mode 100644
index 00000000..3886b877
--- /dev/null
+++ b/modopt/tests/test_helpers/__init__.py
@@ -0,0 +1 @@
+from .utils import failparam, skipparam, Dummy
diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py
new file mode 100644
index 00000000..d8227640
--- /dev/null
+++ b/modopt/tests/test_helpers/utils.py
@@ -0,0 +1,23 @@
+"""
+Some helper functions for the test parametrization.
+They should be used inside ``@pytest.mark.parametrize`` calls.
+
+:Author: Pierre-Antoine Comby
+"""
+import pytest
+
+
+def failparam(*args, raises=None):
+ """Return a pytest parameterization that should raise an error."""
+ if not issubclass(raises, Exception):
+ raise ValueError("raises should be an expected Exception.")
+ return pytest.param(*args, marks=pytest.mark.raises(exception=raises))
+
+
+def skipparam(*args, cond=True, reason=""):
+ """Return a pytest parameterization that should be skip if cond is valid."""
+ return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason))
+
+
+class Dummy:
+ pass
diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py
index 99908e02..ea177b15 100644
--- a/modopt/tests/test_math.py
+++ b/modopt/tests/test_math.py
@@ -1,215 +1,181 @@
-# -*- coding: utf-8 -*-
-
"""UNIT TESTS FOR MATH.
This module contains unit tests for the modopt.math module.
-:Author: Samuel Farrens
-
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
"""
-
-from unittest import TestCase, skipIf, skipUnless
+import pytest
+from test_helpers import failparam, skipparam
import numpy as np
import numpy.testing as npt
+
from modopt.math import convolve, matrix, metrics, stats
try:
import astropy
except ImportError: # pragma: no cover
- import_astropy = False
+ ASTROPY_AVAILABLE = False
else: # pragma: no cover
- import_astropy = True
+ ASTROPY_AVAILABLE = True
try:
from skimage.metrics import structural_similarity as compare_ssim
except ImportError: # pragma: no cover
- import_skimage = False
+ SKIMAGE_AVAILABLE = False
else:
- import_skimage = True
-
-
-class ConvolveTestCase(TestCase):
- """Test case for convolve module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(18).reshape(2, 3, 3)
- self.data2 = self.data1 + 1
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
-
- @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover
- def test_convolve_astropy(self):
- """Test convolve using astropy."""
- npt.assert_allclose(
- convolve.convolve(self.data1[0], self.data2[0], method='astropy'),
- np.array([
- [210.0, 201.0, 210.0],
- [129.0, 120.0, 129.0],
- [210.0, 201.0, 210.0],
- ]),
- err_msg='Incorrect convolution: astropy',
- )
-
- npt.assert_raises(
- ValueError,
- convolve.convolve,
- self.data1[0],
- self.data2,
- )
-
- npt.assert_raises(
- ValueError,
- convolve.convolve,
- self.data1[0],
- self.data2[0],
- method='bla',
- )
-
- def test_convolve_scipy(self):
- """Test convolve using scipy."""
- npt.assert_allclose(
- convolve.convolve(self.data1[0], self.data2[0], method='scipy'),
- np.array([
+ SKIMAGE_AVAILABLE = True
+
+
+class TestConvolve:
+ """Test convolve functions."""
+
+ array233 = np.arange(18).reshape((2, 3, 3))
+ array233_1 = array233 + 1
+ result_astropy = np.array(
+ [
+ [210.0, 201.0, 210.0],
+ [129.0, 120.0, 129.0],
+ [210.0, 201.0, 210.0],
+ ]
+ )
+ result_scipy = np.array(
+ [
+ [
[14.0, 35.0, 38.0],
[57.0, 120.0, 111.0],
[110.0, 197.0, 158.0],
- ]),
- err_msg='Incorrect convolution: scipy',
- )
-
- def test_convolve_stack(self):
- """Test convolve_stack."""
+ ],
+ [
+ [518.0, 845.0, 614.0],
+ [975.0, 1578.0, 1137.0],
+ [830.0, 1331.0, 950.0],
+ ],
+ ]
+ )
+
+ result_rot_kernel = np.array(
+ [
+ [
+ [66.0, 115.0, 82.0],
+ [153.0, 240.0, 159.0],
+ [90.0, 133.0, 82.0],
+ ],
+ [
+ [714.0, 1087.0, 730.0],
+ [1125.0, 1698.0, 1131.0],
+ [738.0, 1105.0, 730.0],
+ ],
+ ]
+ )
+
+ @pytest.mark.parametrize(
+ ("input_data", "kernel", "method", "result"),
+ [
+ skipparam(
+ array233[0],
+ array233_1[0],
+ "astropy",
+ result_astropy,
+ cond=not ASTROPY_AVAILABLE,
+ reason="astropy not available",
+ ),
+ failparam(
+ array233[0], array233_1, "astropy", result_astropy, raises=ValueError
+ ),
+ failparam(
+ array233[0], array233_1[0], "fail!", result_astropy, raises=ValueError
+ ),
+ (array233[0], array233_1[0], "scipy", result_scipy[0]),
+ ],
+ )
+ def test_convolve(self, input_data, kernel, method, result):
+ """Test convolve function."""
+ npt.assert_allclose(convolve.convolve(input_data, kernel, method), result)
+
+ @pytest.mark.parametrize(
+ ("result", "rot_kernel"),
+ [
+ (result_scipy, False),
+ (result_rot_kernel, True),
+ ],
+ )
+ def test_convolve_stack(self, result, rot_kernel):
+ """Test convolve stack function."""
npt.assert_allclose(
- convolve.convolve_stack(self.data1, self.data2),
- np.array([
- [
- [14.0, 35.0, 38.0],
- [57.0, 120.0, 111.0],
- [110.0, 197.0, 158.0],
- ],
- [
- [518.0, 845.0, 614.0],
- [975.0, 1578.0, 1137.0],
- [830.0, 1331.0, 950.0],
- ],
- ]),
- err_msg='Incorrect convolution: stack',
+ convolve.convolve_stack(
+ self.array233, self.array233_1, rot_kernel=rot_kernel
+ ),
+ result,
)
- def test_convolve_stack_rot(self):
- """Test convolve_stack rotated."""
- npt.assert_allclose(
- convolve.convolve_stack(self.data1, self.data2, rot_kernel=True),
- np.array([
- [
- [66.0, 115.0, 82.0],
- [153.0, 240.0, 159.0],
- [90.0, 133.0, 82.0],
- ],
- [
- [714.0, 1087.0, 730.0],
- [1125.0, 1698.0, 1131.0],
- [738.0, 1105.0, 730.0],
- ],
- ]),
- err_msg='Incorrect convolution: stack rot',
- )
+class TestMatrix:
+ """Test matrix module."""
-class MatrixTestCase(TestCase):
- """Test case for matrix module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3)
- self.data2 = np.arange(3)
- self.data3 = np.arange(6).reshape(2, 3)
- np.random.seed(1)
- self.pmInstance1 = matrix.PowerMethod(
- lambda x_val: x_val.dot(x_val.T),
- self.data1.shape,
- verbose=True,
- )
- np.random.seed(1)
- self.pmInstance2 = matrix.PowerMethod(
- lambda x_val: x_val.dot(x_val.T),
- self.data1.shape,
- auto_run=False,
- verbose=True,
- )
- self.pmInstance2.get_spec_rad(max_iter=1)
- self.gram_schmidt_out = (
- np.array([
+ array3 = np.arange(3)
+ array33 = np.arange(9).reshape((3, 3))
+ array23 = np.arange(6).reshape((2, 3))
+ gram_schmidt_out = (
+ np.array(
+ [
[0, 1.0, 2.0],
[3.0, 1.2, -6e-1],
[-1.77635684e-15, 0, 0],
- ]),
- np.array([
+ ]
+ ),
+ np.array(
+ [
[0, 0.4472136, 0.89442719],
[0.91287093, 0.36514837, -0.18257419],
[-1.0, 0, 0],
- ]),
- )
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.pmInstance1 = None
- self.pmInstance2 = None
- self.gram_schmidt_out = None
-
- def test_gram_schmidt_orthonormal(self):
- """Test gram_schmidt with orthonormal output."""
- npt.assert_allclose(
- matrix.gram_schmidt(self.data1),
- self.gram_schmidt_out[1],
- err_msg='Incorrect Gram-Schmidt: orthonormal',
- )
+ ]
+ ),
+ )
- npt.assert_raises(
- ValueError,
- matrix.gram_schmidt,
- self.data1,
- return_opt='bla',
- )
-
- def test_gram_schmidt_orthogonal(self):
- """Test gram_schmidt with orthogonal output."""
- npt.assert_allclose(
- matrix.gram_schmidt(self.data1, return_opt='orthogonal'),
- self.gram_schmidt_out[0],
- err_msg='Incorrect Gram-Schmidt: orthogonal',
+ @pytest.fixture
+ def pm_instance(self, request):
+ """Power Method instance."""
+ np.random.seed(1)
+ pm = matrix.PowerMethod(
+ lambda x_val: x_val.dot(x_val.T),
+ self.array33.shape,
+ auto_run=request.param,
+ verbose=True,
)
-
- def test_gram_schmidt_both(self):
- """Test gram_schmidt with both outputs."""
+ if not request.param:
+ pm.get_spec_rad(max_iter=1)
+ return pm
+
+ @pytest.mark.parametrize(
+ ("return_opt", "output"),
+ [
+ ("orthonormal", gram_schmidt_out[1]),
+ ("orthogonal", gram_schmidt_out[0]),
+ ("both", gram_schmidt_out),
+ failparam("fail!", gram_schmidt_out, raises=ValueError),
+ ],
+ )
+ def test_gram_schmidt(self, return_opt, output):
+        """Test Gram-Schmidt orthogonalization for each return option."""
npt.assert_allclose(
- matrix.gram_schmidt(self.data1, return_opt='both'),
- self.gram_schmidt_out,
- err_msg='Incorrect Gram-Schmidt: both',
+ matrix.gram_schmidt(self.array33, return_opt=return_opt), output
)
def test_nuclear_norm(self):
- """Test nuclear_norm."""
+ """Test nuclear norm."""
npt.assert_almost_equal(
- matrix.nuclear_norm(self.data1),
+ matrix.nuclear_norm(self.array33),
15.49193338482967,
- err_msg='Incorrect nuclear norm',
)
def test_project(self):
"""Test project."""
npt.assert_array_equal(
- matrix.project(self.data2, self.data2 + 3),
+ matrix.project(self.array3, self.array3 + 3),
np.array([0, 2.8, 5.6]),
- err_msg='Incorrect projection',
)
def test_rot_matrix(self):
@@ -217,280 +183,149 @@ def test_rot_matrix(self):
npt.assert_allclose(
matrix.rot_matrix(np.pi / 6),
np.array([[0.8660254, -0.5], [0.5, 0.8660254]]),
- err_msg='Incorrect rotation matrix',
)
def test_rotate(self):
"""Test rotate."""
npt.assert_array_equal(
- matrix.rotate(self.data1, np.pi / 2),
+ matrix.rotate(self.array33, np.pi / 2),
np.array([[2, 5, 8], [1, 4, 7], [0, 3, 6]]),
- err_msg='Incorrect rotation',
- )
-
- npt.assert_raises(ValueError, matrix.rotate, self.data3, np.pi / 2)
-
- def test_powermethod_converged(self):
- """Test PowerMethod converged."""
- npt.assert_almost_equal(
- self.pmInstance1.spec_rad,
- 0.90429242629600837,
- err_msg='Incorrect spectral radius: converged',
)
- npt.assert_almost_equal(
- self.pmInstance1.inv_spec_rad,
- 1.1058369736612865,
- err_msg='Incorrect inverse spectral radius: converged',
- )
-
- def test_powermethod_unconverged(self):
- """Test PowerMethod unconverged."""
- npt.assert_almost_equal(
- self.pmInstance2.spec_rad,
- 0.92048833577059219,
- err_msg='Incorrect spectral radius: unconverged',
- )
-
- npt.assert_almost_equal(
- self.pmInstance2.inv_spec_rad,
- 1.0863798715741946,
- err_msg='Incorrect inverse spectral radius: unconverged',
- )
-
-
-class MetricsTestCase(TestCase):
- """Test case for metrics module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(49).reshape(7, 7)
- self.mask = np.ones(self.data1.shape)
- self.ssim_res = 0.8963363560519094
- self.ssim_mask_res = 0.805154442543846
- self.snr_res = 10.134554256920536
- self.psnr_res = 14.860761791850397
- self.mse_res = 0.03265305507330247
- self.nrmse_res = 0.31136678840022625
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.mask = None
- self.ssim_res = None
- self.ssim_mask_res = None
- self.psnr_res = None
- self.mse_res = None
- self.nrmse_res = None
-
- @skipIf(import_skimage, 'skimage is installed.') # pragma: no cover
- def test_ssim_skimage_error(self):
- """Test ssim skimage error."""
- npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1)
-
- @skipUnless(import_skimage, 'skimage not installed.') # pragma: no cover
- def test_ssim(self):
+ npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2)
+
+ @pytest.mark.parametrize(
+ ("pm_instance", "value"),
+ [(True, 1.0), (False, 0.8675467477372257)],
+ indirect=["pm_instance"],
+ )
+ def test_power_method(self, pm_instance, value):
+ """Test power method."""
+ npt.assert_almost_equal(pm_instance.spec_rad, value)
+ npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value)
+
+
+class TestMetrics:
+ """Test metrics module."""
+
+ data1 = np.arange(49).reshape(7, 7)
+ mask = np.ones(data1.shape)
+ ssim_res = 0.8963363560519094
+ ssim_mask_res = 0.805154442543846
+ snr_res = 10.134554256920536
+ psnr_res = 14.860761791850397
+ mse_res = 0.03265305507330247
+ nrmse_res = 0.31136678840022625
+
+ @pytest.mark.skipif(not SKIMAGE_AVAILABLE, reason="skimage not installed")
+ @pytest.mark.parametrize(
+ ("data1", "data2", "result", "mask"),
+ [
+ (data1, data1**2, ssim_res, None),
+ (data1, data1**2, ssim_mask_res, mask),
+ failparam(data1, data1, None, 1, raises=ValueError),
+ ],
+ )
+ def test_ssim(self, data1, data2, result, mask):
"""Test ssim."""
- npt.assert_almost_equal(
- metrics.ssim(self.data1, self.data1 ** 2),
- self.ssim_res,
- err_msg='Incorrect SSIM result',
- )
+ npt.assert_almost_equal(metrics.ssim(data1, data2, mask=mask), result)
- npt.assert_almost_equal(
- metrics.ssim(self.data1, self.data1 ** 2, mask=self.mask),
- self.ssim_mask_res,
- err_msg='Incorrect SSIM result',
- )
-
- npt.assert_raises(
- ValueError,
- metrics.ssim,
- self.data1,
- self.data1,
- mask=1,
- )
+ @pytest.mark.skipif(SKIMAGE_AVAILABLE, reason="skimage installed")
+ def test_ssim_fail(self):
+        """Test that ssim raises ImportError when skimage is missing."""
+ npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1)
- def test_snr(self):
+ @pytest.mark.parametrize(
+ ("metric", "data", "result", "mask"),
+ [
+ (metrics.snr, data1, snr_res, None),
+ (metrics.snr, data1, snr_res, mask),
+ (metrics.psnr, data1, psnr_res, None),
+ (metrics.psnr, data1, psnr_res, mask),
+ (metrics.mse, data1, mse_res, None),
+ (metrics.mse, data1, mse_res, mask),
+ (metrics.nrmse, data1, nrmse_res, None),
+ (metrics.nrmse, data1, nrmse_res, mask),
+ failparam(metrics.snr, data1, snr_res, "maskfail", raises=ValueError),
+ ],
+ )
+ def test_metric(self, metric, data, result, mask):
        """Test a metric function (snr, psnr, mse or nrmse) against its expected value."""
- npt.assert_almost_equal(
- metrics.snr(self.data1, self.data1 ** 2),
- self.snr_res,
- err_msg='Incorrect SNR result',
- )
-
- npt.assert_almost_equal(
- metrics.snr(self.data1, self.data1 ** 2, mask=self.mask),
- self.snr_res,
- err_msg='Incorrect SNR result',
- )
-
- def test_psnr(self):
- """Test psnr."""
- npt.assert_almost_equal(
- metrics.psnr(self.data1, self.data1 ** 2),
- self.psnr_res,
- err_msg='Incorrect PSNR result',
- )
-
- npt.assert_almost_equal(
- metrics.psnr(self.data1, self.data1 ** 2, mask=self.mask),
- self.psnr_res,
- err_msg='Incorrect PSNR result',
- )
-
- def test_mse(self):
- """Test mse."""
- npt.assert_almost_equal(
- metrics.mse(self.data1, self.data1 ** 2),
- self.mse_res,
- err_msg='Incorrect MSE result',
- )
-
- npt.assert_almost_equal(
- metrics.mse(self.data1, self.data1 ** 2, mask=self.mask),
- self.mse_res,
- err_msg='Incorrect MSE result',
- )
-
- def test_nrmse(self):
- """Test nrmse."""
- npt.assert_almost_equal(
- metrics.nrmse(self.data1, self.data1 ** 2),
- self.nrmse_res,
- err_msg='Incorrect NRMSE result',
- )
-
- npt.assert_almost_equal(
- metrics.nrmse(self.data1, self.data1 ** 2, mask=self.mask),
- self.nrmse_res,
- err_msg='Incorrect NRMSE result',
- )
-
-
-class StatsTestCase(TestCase):
- """Test case for stats module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3)
- self.data2 = np.arange(18).reshape(2, 3, 3)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
-
- @skipIf(import_astropy, 'Astropy is installed.') # pragma: no cover
- def test_gaussian_kernel_astropy_error(self):
- """Test gaussian_kernel astropy error."""
- npt.assert_raises(
- ImportError,
- stats.gaussian_kernel,
- self.data1.shape,
- 1,
- )
-
- @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover
- def test_gaussian_kernel_max(self):
- """Test gaussian_kernel with max norm."""
+ npt.assert_almost_equal(metric(data, data**2, mask=mask), result)
+
+
+class TestStats:
+ """Test stats module."""
+
+ array33 = np.arange(9).reshape(3, 3)
+ array233 = np.arange(18).reshape(2, 3, 3)
+
+ @pytest.mark.skipif(not ASTROPY_AVAILABLE, reason="astropy not installed")
+ @pytest.mark.parametrize(
+ ("norm", "result"),
+ [
+ (
+ "max",
+ np.array(
+ [
+ [0.36787944, 0.60653066, 0.36787944],
+ [0.60653066, 1.0, 0.60653066],
+ [0.36787944, 0.60653066, 0.36787944],
+ ]
+ ),
+ ),
+ (
+ "sum",
+ np.array(
+ [
+ [0.07511361, 0.1238414, 0.07511361],
+ [0.1238414, 0.20417996, 0.1238414],
+ [0.07511361, 0.1238414, 0.07511361],
+ ]
+ ),
+ ),
+ failparam("fail", None, raises=ValueError),
+ ],
+ )
+ def test_gaussian_kernel(self, norm, result):
+ """Test Gaussian kernel."""
npt.assert_allclose(
- stats.gaussian_kernel(self.data1.shape, 1),
- np.array([
- [0.36787944, 0.60653066, 0.36787944],
- [0.60653066, 1.0, 0.60653066],
- [0.36787944, 0.60653066, 0.36787944],
- ]),
- err_msg='Incorrect gaussian kernel: max norm',
+ stats.gaussian_kernel(self.array33.shape, 1, norm=norm), result
)
- npt.assert_raises(
- ValueError,
- stats.gaussian_kernel,
- self.data1.shape,
- 1,
- norm='bla',
- )
-
- @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover
- def test_gaussian_kernel_sum(self):
- """Test gaussian_kernel with sum norm."""
- npt.assert_allclose(
- stats.gaussian_kernel(self.data1.shape, 1, norm='sum'),
- np.array([
- [0.07511361, 0.1238414, 0.07511361],
- [0.1238414, 0.20417996, 0.1238414],
- [0.07511361, 0.1238414, 0.07511361],
- ]),
- err_msg='Incorrect gaussian kernel: sum norm',
- )
-
- @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover
- def test_gaussian_kernel_none(self):
- """Test gaussian_kernel with no norm."""
- npt.assert_allclose(
- stats.gaussian_kernel(self.data1.shape, 1, norm='none'),
- np.array([
- [0.05854983, 0.09653235, 0.05854983],
- [0.09653235, 0.15915494, 0.09653235],
- [0.05854983, 0.09653235, 0.05854983],
- ]),
- err_msg='Incorrect gaussian kernel: sum norm',
- )
+ @pytest.mark.skipif(ASTROPY_AVAILABLE, reason="astropy installed")
+ def test_import_astropy(self):
+ """Test missing astropy."""
+ npt.assert_raises(ImportError, stats.gaussian_kernel, self.array33.shape, 1)
def test_mad(self):
"""Test mad."""
- npt.assert_equal(
- stats.mad(self.data1),
- 2.0,
- err_msg='Incorrect median absolute deviation',
- )
-
- def test_mse(self):
- """Test mse."""
- npt.assert_equal(
- stats.mse(self.data1, self.data1 + 2),
- 4.0,
- err_msg='Incorrect mean squared error',
- )
+ npt.assert_equal(stats.mad(self.array33), 2.0)
- def test_psnr_starck(self):
- """Test psnr."""
+ def test_sigma_mad(self):
+ """Test sigma_mad."""
npt.assert_almost_equal(
- stats.psnr(self.data1, self.data1 + 2),
- 12.041199826559248,
- err_msg='Incorrect PSNR: starck',
- )
-
- npt.assert_raises(
- ValueError,
- stats.psnr,
- self.data1,
- self.data1,
- method='bla',
+ stats.sigma_mad(self.array33),
+ 2.9651999999999998,
)
- def test_psnr_wiki(self):
- """Test psnr wiki method."""
- npt.assert_almost_equal(
- stats.psnr(self.data1, self.data1 + 2, method='wiki'),
- 42.110203695399477,
- err_msg='Incorrect PSNR: wiki',
- )
+ @pytest.mark.parametrize(
+ ("data1", "data2", "method", "result"),
+ [
+ (array33, array33 + 2, "starck", 12.041199826559248),
+ failparam(array33, array33, "fail", 0, raises=ValueError),
+ (array33, array33 + 2, "wiki", 42.110203695399477),
+ ],
+ )
+ def test_psnr(self, data1, data2, method, result):
+ """Test PSNR."""
+ npt.assert_almost_equal(stats.psnr(data1, data2, method=method), result)
def test_psnr_stack(self):
"""Test psnr stack."""
npt.assert_almost_equal(
- stats.psnr_stack(self.data2, self.data2 + 2),
+ stats.psnr_stack(self.array233, self.array233 + 2),
12.041199826559248,
- err_msg='Incorrect PSNR stack',
)
- npt.assert_raises(ValueError, stats.psnr_stack, self.data1, self.data1)
-
- def test_sigma_mad(self):
- """Test sigma_mad."""
- npt.assert_almost_equal(
- stats.sigma_mad(self.data1),
- 2.9651999999999998,
- err_msg='Incorrect sigma from MAD',
- )
+ npt.assert_raises(ValueError, stats.psnr_stack, self.array33, self.array33)
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
index d5547783..4a82e33c 100644
--- a/modopt/tests/test_opt.py
+++ b/modopt/tests/test_opt.py
@@ -1,718 +1,293 @@
-# -*- coding: utf-8 -*-
-
"""UNIT TESTS FOR OPT.
-This module contains unit tests for the modopt.opt module.
-
-:Author: Samuel Farrens
+This module contains tests for the modopt.opt module.
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
"""
-from builtins import zip
-from unittest import TestCase, skipIf, skipUnless
-
import numpy as np
import numpy.testing as npt
+import pytest
+from pytest_cases import parametrize, parametrize_with_cases, case, fixture, fixture_ref
+
+from modopt.opt import cost, gradient, linear, proximity, reweight
-from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
+from test_helpers import Dummy
+SKLEARN_AVAILABLE = True
try:
import sklearn
-except ImportError: # pragma: no cover
- import_sklearn = False
-else:
- import_sklearn = True
+except ImportError:
+ SKLEARN_AVAILABLE = False
+PYWT_AVAILABLE = True
+try:
+ import pywt
+ import joblib
+except ImportError:
+ PYWT_AVAILABLE = False
# Basic functions to be used as operators or as dummy functions
func_identity = lambda x_val: x_val
func_double = lambda x_val: x_val * 2
-func_sq = lambda x_val: x_val ** 2
-func_cube = lambda x_val: x_val ** 3
-
-
-class Dummy(object):
- """Dummy class for tests."""
-
- pass
-
-
-class AlgorithmTestCase(TestCase):
- """Test case for algorithms module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6
- self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1
-
- grad_inst = gradient.GradBasic(
- self.data1,
- func_identity,
- func_identity,
- )
-
- prox_inst = proximity.Positivity()
- prox_dual_inst = proximity.IdentityProx()
- linear_inst = linear.Identity()
- reweight_inst = reweight.cwbReweight(self.data3)
- cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
- self.setup = algorithms.SetUp()
- self.max_iter = 20
-
- self.fb_all_iter = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=None,
- auto_iterate=False,
- beta_update=func_identity,
- )
- self.fb_all_iter.iterate(self.max_iter)
-
- self.fb1 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- )
-
- self.fb2 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- cost=cost_inst,
- lambda_update=None,
- )
-
- self.fb3 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- a_cd=3,
- )
-
- self.fb4 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- beta_update=func_identity,
- r_lazy=3,
- p_lazy=0.7,
- q_lazy=0.7,
- )
-
- self.fb5 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- restart_strategy='adaptive',
- xi_restart=0.9,
- )
-
- self.fb6 = algorithms.ForwardBackward(
- self.data1,
- grad=grad_inst,
- prox=prox_inst,
- restart_strategy='greedy',
- xi_restart=0.9,
- min_beta=1.0,
- s_greedy=1.1,
- )
-
- self.gfb_all_iter = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=None,
- auto_iterate=False,
- gamma_update=func_identity,
- beta_update=func_identity,
- )
- self.gfb_all_iter.iterate(self.max_iter)
-
- self.gfb1 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- gamma_update=func_identity,
- lambda_update=func_identity,
- )
-
- self.gfb2 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=cost_inst,
- )
-
- self.gfb3 = algorithms.GenForwardBackward(
- self.data1,
- grad=grad_inst,
- prox_list=[prox_inst, prox_dual_inst],
- cost=cost_inst,
- step_size=2,
- )
-
- self.condat_all_iter = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- cost=None,
- prox_dual=prox_dual_inst,
- sigma_update=func_identity,
- tau_update=func_identity,
- rho_update=func_identity,
- auto_iterate=False,
- )
- self.condat_all_iter.iterate(self.max_iter)
-
- self.condat1 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- sigma_update=func_identity,
- tau_update=func_identity,
- rho_update=func_identity,
- )
-
- self.condat2 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- linear=linear_inst,
- cost=cost_inst,
- reweight=reweight_inst,
- )
-
- self.condat3 = algorithms.Condat(
- self.data1,
- self.data2,
- grad=grad_inst,
- prox=prox_inst,
- prox_dual=prox_dual_inst,
- linear=Dummy(),
- cost=cost_inst,
- auto_iterate=False,
- )
-
- self.pogm_all_iter = algorithms.POGM(
- u=self.data1,
- x=self.data1,
- y=self.data1,
- z=self.data1,
- grad=grad_inst,
- prox=prox_inst,
- auto_iterate=False,
- cost=None,
- )
- self.pogm_all_iter.iterate(self.max_iter)
-
- self.pogm1 = algorithms.POGM(
- u=self.data1,
- x=self.data1,
- y=self.data1,
- z=self.data1,
- grad=grad_inst,
- prox=prox_inst,
- )
-
- self.dummy = Dummy()
- self.dummy.cost = func_identity
- self.setup._check_operator(self.dummy.cost)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.setup = None
- self.fb_all_iter = None
- self.fb1 = None
- self.fb2 = None
- self.gfb_all_iter = None
- self.gfb1 = None
- self.gfb2 = None
- self.condat_all_iter = None
- self.condat1 = None
- self.condat2 = None
- self.condat3 = None
- self.pogm1 = None
- self.pogm_all_iter = None
- self.dummy = None
-
- def test_set_up(self):
- """Test set_up."""
- npt.assert_raises(TypeError, self.setup._check_input_data, 1)
-
- npt.assert_raises(TypeError, self.setup._check_param, 1)
-
- npt.assert_raises(TypeError, self.setup._check_param_update, 1)
-
- def test_all_iter(self):
- """Test if all opt run for all iterations."""
- opts = [
- self.fb_all_iter,
- self.gfb_all_iter,
- self.condat_all_iter,
- self.pogm_all_iter,
- ]
- for opt in opts:
- npt.assert_equal(opt.idx, self.max_iter - 1)
-
- def test_forward_backward(self):
- """Test forward_backward."""
- npt.assert_array_equal(
- self.fb1.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb2.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb3.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb4.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb5.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.fb6.x_final,
- self.data1,
- err_msg='Incorrect ForwardBackward result.',
- )
-
- def test_gen_forward_backward(self):
- """Test gen_forward_backward."""
- npt.assert_array_equal(
- self.gfb1.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.gfb2.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_array_equal(
- self.gfb3.x_final,
- self.data1,
- err_msg='Incorrect GenForwardBackward result.',
- )
-
- npt.assert_equal(
- self.gfb3.step_size,
- 2,
- err_msg='Incorrect step size.',
- )
-
- npt.assert_raises(
- TypeError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=1,
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[1],
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[0.5, 0.5],
- )
-
- npt.assert_raises(
- ValueError,
- algorithms.GenForwardBackward,
- self.data1,
- self.dummy,
- [self.dummy],
- weights=[0.5],
- )
-
- def test_condat(self):
- """Test gen_condat."""
- npt.assert_almost_equal(
- self.condat1.x_final,
- self.data1,
- err_msg='Incorrect Condat result.',
- )
+func_sq = lambda x_val: x_val**2
+func_cube = lambda x_val: x_val**3
+
+
+@case(tags="cost")
+@parametrize(
+ ("cost_interval", "n_calls", "converged"),
+ [(1, 1, False), (1, 2, True), (2, 5, False), (None, 6, False)],
+)
+def case_cost_op(cost_interval, n_calls, converged):
+ """Case function for costs."""
+ dummy_inst1 = Dummy()
+ dummy_inst1.cost = func_sq
+ dummy_inst2 = Dummy()
+ dummy_inst2.cost = func_cube
+
+ cost_obj = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=cost_interval)
+
+ for _ in range(n_calls + 1):
+ cost_obj.get_cost(2)
+ return cost_obj, converged
+
+
+@parametrize_with_cases("cost_obj, converged", cases=".", has_tag="cost")
+def test_costs(cost_obj, converged):
+ """Test cost."""
+ npt.assert_equal(cost_obj.get_cost(2), converged)
+ if cost_obj._cost_interval:
+ npt.assert_equal(cost_obj.cost, 12)
+
+
+def test_raise_cost():
+ """Test error raising for cost."""
+ npt.assert_raises(TypeError, cost.costObj, 1)
+ npt.assert_raises(ValueError, cost.costObj, [Dummy(), Dummy()])
+
+
+@case(tags="grad")
+@parametrize(call=("op", "trans_op", "trans_op_op"))
+def case_grad_parent(call):
+ """Case for gradient parent."""
+ input_data = np.arange(9).reshape(3, 3)
+ callables = {
+ "op": func_sq,
+ "trans_op": func_cube,
+ "get_grad": func_identity,
+ "cost": lambda input_val: 1.0,
+ }
+
+ grad_op = gradient.GradParent(
+ input_data,
+ **callables,
+ data_type=np.floating,
+ )
+ if call != "trans_op_op":
+ result = callables[call](input_data)
+ else:
+ result = callables["trans_op"](callables["op"](input_data))
+
+ grad_call = getattr(grad_op, call)(input_data)
+ return grad_call, result
+
+
+@parametrize_with_cases("grad_values, result", cases=".", has_tag="grad")
+def test_grad_op(grad_values, result):
+ """Test Gradient operator."""
+ npt.assert_equal(grad_values, result)
+
+
+@pytest.fixture
+def grad_basic():
+ """Case for GradBasic."""
+ input_data = np.arange(9).reshape(3, 3)
+ grad_op = gradient.GradBasic(
+ input_data,
+ func_sq,
+ func_cube,
+ verbose=True,
+ )
+ grad_op.get_grad(input_data)
+ return grad_op
+
+
+def test_grad_basic(grad_basic):
+ """Test grad basic."""
+ npt.assert_array_equal(
+ grad_basic.grad,
+ np.array(
+ [
+ [0, 0, 8.0],
+ [2.16000000e2, 1.72800000e3, 8.0e3],
+ [2.70000000e4, 7.40880000e4, 1.75616000e5],
+ ]
+ ),
+ err_msg="Incorrect gradient.",
+ )
- npt.assert_almost_equal(
- self.condat2.x_final,
- self.data1,
- err_msg='Incorrect Condat result.',
- )
- def test_pogm(self):
- """Test pogm."""
- npt.assert_almost_equal(
- self.pogm1.x_final,
- self.data1,
- err_msg='Incorrect POGM result.',
- )
+def test_grad_basic_cost(grad_basic):
+ """Test grad_basic cost."""
+ npt.assert_almost_equal(grad_basic.cost(np.arange(9).reshape(3, 3)), 3192.0)
-class CostTestCase(TestCase):
- """Test case for cost module."""
+def test_grad_op_raises():
+ """Test raise error."""
+ npt.assert_raises(
+ TypeError,
+ gradient.GradParent,
+ 1,
+ func_sq,
+ func_cube,
+ )
- def setUp(self):
- """Set test parameter values."""
- dummy_inst1 = Dummy()
- dummy_inst1.cost = func_sq
- dummy_inst2 = Dummy()
- dummy_inst2.cost = func_cube
- self.inst1 = cost.costObj([dummy_inst1, dummy_inst2])
- self.inst2 = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=2)
- # Test that by default cost of False if interval is None
- self.inst_none = cost.costObj(
- [dummy_inst1, dummy_inst2],
- cost_interval=None,
- )
- for _ in range(2):
- self.inst1.get_cost(2)
- for _ in range(6):
- self.inst2.get_cost(2)
- self.inst_none.get_cost(2)
- self.dummy = Dummy()
-
- def tearDown(self):
- """Unset test parameter values."""
- self.inst = None
-
- def test_cost_object(self):
- """Test cost_object."""
- npt.assert_equal(
- self.inst1.get_cost(2),
- False,
- err_msg='Incorrect cost test result.',
- )
- npt.assert_equal(
- self.inst1.get_cost(2),
- True,
- err_msg='Incorrect cost test result.',
- )
- npt.assert_equal(
- self.inst_none.get_cost(2),
- False,
- err_msg='Incorrect cost test result.',
- )
-
- npt.assert_equal(self.inst1.cost, 12, err_msg='Incorrect cost value.')
+#############
+# LINEAR OP #
+#############
- npt.assert_equal(self.inst2.cost, 12, err_msg='Incorrect cost value.')
- npt.assert_raises(TypeError, cost.costObj, 1)
+class LinearCases:
+ """Linear operator cases."""
- npt.assert_raises(ValueError, cost.costObj, [self.dummy, self.dummy])
+ def case_linear_identity(self):
+ """Case linear operator identity."""
+ linop = linear.Identity()
+ data_op, data_adj_op, res_op, res_adj_op = 1, 1, 1, 1
-class GradientTestCase(TestCase):
- """Test case for gradient module."""
+ return linop, data_op, data_adj_op, res_op, res_adj_op
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.gp = gradient.GradParent(
- self.data1,
- func_sq,
- func_cube,
- func_identity,
- lambda input_val: 1.0,
- data_type=np.floating,
- )
- self.gp.grad = self.gp.get_grad(self.data1)
- self.gb = gradient.GradBasic(
- self.data1,
- func_sq,
- func_cube,
- )
- self.gb.get_grad(self.data1)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.gp = None
- self.gb = None
-
- def test_grad_parent_operators(self):
- """Test GradParent."""
- npt.assert_array_equal(
- self.gp.op(self.data1),
- np.array([[0, 1.0, 4.0], [9.0, 16.0, 25.0], [36.0, 49.0, 64.0]]),
- err_msg='Incorrect gradient operation.',
- )
-
- npt.assert_array_equal(
- self.gp.trans_op(self.data1),
- np.array(
- [[0, 1.0, 8.0], [27.0, 64.0, 125.0], [216.0, 343.0, 512.0]],
- ),
- err_msg='Incorrect gradient transpose operation.',
+ def case_linear_wavelet_convolve(self):
+ """Case linear operator wavelet."""
+ linop = linear.WaveletConvolve(
+ filters=np.arange(8).reshape(2, 2, 2).astype(float)
)
+ data_op = np.arange(4).reshape(1, 2, 2).astype(float)
+ data_adj_op = np.arange(8).reshape(1, 2, 2, 2).astype(float)
+ res_op = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]])
+ res_adj_op = np.array([[[28.0, 62.0], [68.0, 140.0]]])
- npt.assert_array_equal(
- self.gp.trans_op_op(self.data1),
- np.array([
- [0, 1.0, 6.40000000e1],
- [7.29000000e2, 4.09600000e3, 1.56250000e4],
- [4.66560000e4, 1.17649000e5, 2.62144000e5],
- ]),
- err_msg='Incorrect gradient transpose operation operation.',
- )
+ return linop, data_op, data_adj_op, res_op, res_adj_op
- npt.assert_equal(
- self.gp.cost(self.data1),
- 1.0,
- err_msg='Incorrect cost.',
+ @pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")
+ def case_linear_wavelet_transform(self):
+ linop = linear.WaveletTransform(
+ wavelet_name="haar",
+ shape=(8, 8),
+ level=2,
)
+ data_op = np.arange(64).reshape(8, 8).astype(float)
+ res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2))
+ data_adj_op = linop.op(data_op)
+ res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar")
+ return linop, data_op, data_adj_op, res_op, res_adj_op
- npt.assert_raises(
- TypeError,
- gradient.GradParent,
- 1,
+ @parametrize(weights=[[1.0, 1.0], None])
+ def case_linear_combo(self, weights):
+ """Case linear operator combo with weights."""
+ parent = linear.LinearParent(
func_sq,
func_cube,
)
+ linop = linear.LinearCombo([parent, parent], weights)
- def test_grad_basic_gradient(self):
- """Test GradBasic."""
- npt.assert_array_equal(
- self.gb.grad,
- np.array([
- [0, 0, 8.0],
- [2.16000000e2, 1.72800000e3, 8.0e3],
- [2.70000000e4, 7.40880000e4, 1.75616000e5],
- ]),
- err_msg='Incorrect gradient.',
+ data_op, data_adj_op, res_op, res_adj_op = (
+ 2,
+ np.array([2, 2]),
+ np.array([4, 4]),
+ 8.0 * (2 if weights else 1),
)
+ return linop, data_op, data_adj_op, res_op, res_adj_op
-class LinearTestCase(TestCase):
- """Test case for linear module."""
+ @parametrize(factor=[1, 1 + 1j])
+ def case_linear_matrix(self, factor):
+ """Case linear operator from matrix."""
+ linop = linear.MatrixOperator(np.eye(5) * factor)
+ data_op = np.arange(5)
+ data_adj_op = np.arange(5)
+ res_op = np.arange(5) * factor
+ res_adj_op = np.arange(5) * np.conjugate(factor)
- def setUp(self):
- """Set test parameter values."""
- self.parent = linear.LinearParent(
- func_sq,
- func_cube,
- )
- self.ident = linear.Identity()
- filters = np.arange(8).reshape(2, 2, 2).astype(float)
- self.wave = linear.WaveletConvolve(filters)
- self.combo = linear.LinearCombo([self.parent, self.parent])
- self.combo_weight = linear.LinearCombo(
- [self.parent, self.parent],
- [1.0, 1.0],
- )
- self.data1 = np.arange(18).reshape(2, 3, 3).astype(float)
- self.data2 = np.arange(4).reshape(1, 2, 2).astype(float)
- self.data3 = np.arange(8).reshape(1, 2, 2, 2).astype(float)
- self.data4 = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]])
- self.data5 = np.array([[[28.0, 62.0], [68.0, 140.0]]])
- self.dummy = Dummy()
-
- def tearDown(self):
- """Unset test parameter values."""
- self.parent = None
- self.ident = None
- self.combo = None
- self.combo_weight = None
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.data5 = None
- self.dummy = None
-
- def test_linear_parent(self):
- """Test LinearParent."""
- npt.assert_equal(
- self.parent.op(2),
- 4,
- err_msg='Incorrect linear parent operation.',
- )
+ return linop, data_op, data_adj_op, res_op, res_adj_op
- npt.assert_equal(
- self.parent.adj_op(2),
- 8,
- err_msg='Incorrect linear parent adjoint operation.',
- )
- npt.assert_raises(TypeError, linear.LinearParent, 0, 0)
+@fixture
+@parametrize_with_cases(
+ "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases
+)
+def lin_adj_op(linop, data_op, data_adj_op, res_op, res_adj_op):
+ """Get adj_op relative data."""
+ return linop.adj_op, data_adj_op, res_adj_op
- def test_identity(self):
- """Test Identity."""
- npt.assert_equal(
- self.ident.op(1.0),
- 1.0,
- err_msg='Incorrect identity operation.',
- )
- npt.assert_equal(
- self.ident.adj_op(1.0),
- 1.0,
- err_msg='Incorrect identity adjoint operation.',
- )
+@fixture
+@parametrize_with_cases(
+ "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases
+)
+def lin_op(linop, data_op, data_adj_op, res_op, res_adj_op):
+ """Get op relative data."""
+ return linop.op, data_op, res_op
- def test_wavelet_convolve(self):
- """Test WaveletConvolve."""
- npt.assert_almost_equal(
- self.wave.op(self.data2),
- self.data4,
- err_msg='Incorrect wavelet convolution operation.',
- )
- npt.assert_almost_equal(
- self.wave.adj_op(self.data3),
- self.data5,
- err_msg='Incorrect wavelet convolution adjoint operation.',
- )
+@parametrize(
+ ("action", "data", "result"), [fixture_ref(lin_op), fixture_ref(lin_adj_op)]
+)
+def test_linear_operator(action, data, result):
+ """Test linear operator."""
+ npt.assert_almost_equal(action(data), result)
- def test_linear_combo(self):
- """Test LinearCombo."""
- npt.assert_equal(
- self.combo.op(2),
- np.array([4, 4]).astype(object),
- err_msg='Incorrect combined linear operation',
- )
- npt.assert_equal(
- self.combo.adj_op([2, 2]),
- 8.0,
- err_msg='Incorrect combined linear adjoint operation',
- )
+dummy_with_op = Dummy()
+dummy_with_op.op = lambda x: x
- npt.assert_raises(TypeError, linear.LinearCombo, self.parent)
- npt.assert_raises(ValueError, linear.LinearCombo, [])
+@pytest.mark.parametrize(
+ ("args", "error"),
+ [
+ ([linear.LinearParent(func_sq, func_cube)], TypeError),
+ ([[]], ValueError),
+ ([[Dummy()]], ValueError),
+ ([[dummy_with_op]], ValueError),
+ ([[]], ValueError),
+ ([[linear.LinearParent(func_sq, func_cube)] * 2, [1.0]], ValueError),
+ ([[linear.LinearParent(func_sq, func_cube)] * 2, ["1", "1"]], TypeError),
+ ],
+)
+def test_linear_combo_errors(args, error):
+ """Test linear combo_errors."""
+ npt.assert_raises(error, linear.LinearCombo, *args)
- npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy])
- self.dummy.op = func_identity
+#############
+# Proximity #
+#############
- npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy])
- def test_linear_combo_weight(self):
- """Test LinearCombo with weight ."""
- npt.assert_equal(
- self.combo_weight.op(2),
- np.array([4, 4]).astype(object),
- err_msg='Incorrect combined linear operation',
- )
-
- npt.assert_equal(
- self.combo_weight.adj_op([2, 2]),
- 16.0,
- err_msg='Incorrect combined linear adjoint operation',
- )
+class ProxCases:
+ """Class containing all proximal operator cases.
- npt.assert_raises(
- ValueError,
- linear.LinearCombo,
- [self.parent, self.parent],
- [1.0],
- )
-
- npt.assert_raises(
- TypeError,
- linear.LinearCombo,
- [self.parent, self.parent],
- ['1', '1'],
- )
+ Each case should return 4 parameters:
+ 1. The proximal operator
+ 2. test input data
+ 3. Expected result data
+ 4. Expected cost value.
+ """
+ weights = np.ones(9).reshape(3, 3).astype(float) * 3
+ array33 = np.arange(9).reshape(3, 3).astype(float)
+ array33_st = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]])
+ array33_st2 = array33_st * -1
-class ProximityTestCase(TestCase):
- """Test case for proximity module."""
+ array33_support = np.asarray([[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]])
- def setUp(self):
- """Set test parameter values."""
- self.parent = proximity.ProximityParent(
- func_sq,
- func_double,
- )
- self.identity = proximity.IdentityProx()
- self.positivity = proximity.Positivity()
- weights = np.ones(9).reshape(3, 3).astype(float) * 3
- self.sparsethresh = proximity.SparseThreshold(
- linear.Identity(),
- weights,
- )
- self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard')
- self.lowrank_rank = proximity.LowRankMatrix(
- 10.0,
- initial_rank=1,
- thresh_type='hard',
- )
- self.lowrank_ngole = proximity.LowRankMatrix(
- 10.0,
- lowr_type='ngole',
- operator=func_double,
- )
- self.linear_comp = proximity.LinearCompositionProx(
- linear_op=linear.Identity(),
- prox_op=self.sparsethresh,
- )
- self.combo = proximity.ProximityCombo([self.identity, self.positivity])
- if import_sklearn:
- self.owl = proximity.OrderedWeightedL1Norm(weights.flatten())
- self.ridge = proximity.Ridge(linear.Identity(), weights)
- self.elasticnet_alpha0 = proximity.ElasticNet(
- linear.Identity(),
- alpha=0,
- beta=weights,
- )
- self.elasticnet_beta0 = proximity.ElasticNet(
- linear.Identity(),
- alpha=weights,
- beta=0,
- )
- self.one_support = proximity.KSupportNorm(beta=0.2, k_value=1)
- self.five_support_norm = proximity.KSupportNorm(beta=3, k_value=5)
- self.d_support = proximity.KSupportNorm(beta=3.0 * 2, k_value=19)
- self.group_lasso = proximity.GroupLASSO(
- weights=np.tile(weights, (4, 1, 1)),
- )
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]])
- self.data3 = np.arange(18).reshape(2, 3, 3).astype(float)
- self.data4 = np.array([
+ array233 = np.arange(18).reshape(2, 3, 3).astype(float)
+ array233_2 = np.array(
[
[2.73843189, 3.14594066, 3.55344943],
[3.9609582, 4.36846698, 4.77597575],
@@ -723,349 +298,230 @@ def setUp(self):
[11.67394789, 12.87497954, 14.07601119],
[15.27704284, 16.47807449, 17.67910614],
],
- ])
- self.data5 = np.array([
+ ]
+ )
+ array233_3 = np.array(
+ [
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[
[4.00795282, 4.60438026, 5.2008077],
[5.79723515, 6.39366259, 6.99009003],
[7.58651747, 8.18294492, 8.77937236],
],
- ])
- self.data6 = self.data3 * -1
- self.data7 = self.combo.op(self.data6)
- self.data8 = np.empty(2, dtype=np.ndarray)
- self.data8[0] = np.array(
- [[-0, -1.0, -2.0], [-3.0, -4.0, -5.0], [-6.0, -7.0, -8.0]],
- )
- self.data8[1] = np.array(
- [[-0, -0, -0], [-0, -0, -0], [-0, -0, -0]],
- )
- self.data9 = self.data1 * (1 + 1j)
- self.data10 = self.data9 / (2 * 3 + 1)
- self.data11 = np.asarray(
- [[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]],
- )
- self.random_data = 3 * np.random.random(
- self.group_lasso.weights[0].shape,
- )
- self.random_data_tile = np.tile(
- self.random_data,
- (self.group_lasso.weights.shape[0], 1, 1),
- )
- self.gl_result_data = 2 * self.random_data_tile - 3
- self.gl_result_data = np.array(
- (self.gl_result_data * (self.gl_result_data > 0).astype('int'))
- / 2,
- )
-
- self.dummy = Dummy()
-
- def tearDown(self):
- """Unset test parameter values."""
- self.parent = None
- self.identity = None
- self.positivity = None
- self.sparsethresh = None
- self.lowrank = None
- self.lowrank_rank = None
- self.lowrank_ngole = None
- self.combo = None
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.data5 = None
- self.data6 = None
- self.data7 = None
- self.data8 = None
- self.dummy = None
- self.random_data = None
- self.random_data_tile = None
- self.gl_result_data = None
-
- def test_proximity_parent(self):
- """Test ProximityParent."""
- npt.assert_equal(
- self.parent.op(3),
+ ]
+ )
+
+ def case_prox_parent(self):
+ """Case prox parent."""
+ return (
+ proximity.ProximityParent(
+ func_sq,
+ func_double,
+ ),
+ 3,
9,
- err_msg='Inccoret proximity parent operation.',
- )
-
- npt.assert_equal(
- self.parent.cost(3),
6,
- err_msg='Incorrect proximity parent cost.',
- )
-
- def test_identity(self):
- """Test IdentityProx."""
- npt.assert_equal(
- self.identity.op(3),
- 3,
- err_msg='Incorrect proximity identity operation.',
- )
-
- npt.assert_equal(
- self.identity.cost(3),
- 0,
- err_msg='Incorrect proximity identity cost.',
- )
-
- def test_positivity(self):
- """Test Positivity."""
- npt.assert_equal(
- self.positivity.op(-3),
- 0,
- err_msg='Incorrect proximity positivity operation.',
- )
-
- npt.assert_equal(
- self.positivity.cost(-3, verbose=True),
- 0,
- err_msg='Incorrect proximity positivity cost.',
)
- def test_sparse_threshold(self):
- """Test SparseThreshold."""
- npt.assert_array_equal(
- self.sparsethresh.op(self.data1),
- self.data2,
- err_msg='Incorrect sparse threshold operation.',
- )
-
- npt.assert_equal(
- self.sparsethresh.cost(self.data1, verbose=True),
- 108.0,
- err_msg='Incorrect sparse threshold cost.',
- )
-
- def test_low_rank_matrix(self):
- """Test LowRankMatrix."""
- npt.assert_almost_equal(
- self.lowrank.op(self.data3),
- self.data4,
- err_msg='Incorrect low rank operation: standard',
- )
-
- npt.assert_almost_equal(
- self.lowrank_rank.op(self.data3),
- self.data4,
- err_msg='Incorrect low rank operation: standard with rank',
- )
- npt.assert_almost_equal(
- self.lowrank_ngole.op(self.data3),
- self.data5,
- err_msg='Incorrect low rank operation: ngole',
- )
-
- npt.assert_almost_equal(
- self.lowrank.cost(self.data3, verbose=True),
- 469.39132942464983,
- err_msg='Incorrect low rank cost.',
- )
-
- def test_linear_comp_prox(self):
- """Test LinearCompositionProx."""
- npt.assert_array_equal(
- self.linear_comp.op(self.data1),
- self.data2,
- err_msg='Incorrect sparse threshold operation.',
- )
-
- npt.assert_equal(
- self.linear_comp.cost(self.data1, verbose=True),
- 108.0,
- err_msg='Incorrect sparse threshold cost.',
+ def case_prox_identity(self):
+ """Case prox identity."""
+ return proximity.IdentityProx(), 3, 3, 0
+
+ def case_prox_positivity(self):
+ """Case prox positivity."""
+ return proximity.Positivity(), -3, 0, 0
+
+ def case_prox_sparsethresh(self):
+ """Case prox sparsethreshosld."""
+ return (
+ proximity.SparseThreshold(linear.Identity(), weights=self.weights),
+ self.array33,
+ self.array33_st,
+ 108,
+ )
+
+ @parametrize(
+ "lowr_type, initial_rank, operator, result, cost",
+ [
+ ("standard", None, None, array233_2, 469.3913294246498),
+ ("standard", 1, None, array233_2, 469.3913294246498),
+ ("ngole", None, func_double, array233_3, 469.3913294246498),
+ ],
+ )
+ def case_prox_lowrank(self, lowr_type, initial_rank, operator, result, cost):
+ """Case prox lowrank."""
+ return (
+ proximity.LowRankMatrix(
+ 10,
+ lowr_type=lowr_type,
+ initial_rank=initial_rank,
+ operator=operator,
+ thresh_type="hard" if lowr_type == "standard" else "soft",
+ ),
+ self.array233,
+ result,
+ cost,
)
- def test_proximity_combo(self):
- """Test ProximityCombo."""
- for data7, data8 in zip(self.data7, self.data8):
- npt.assert_array_equal(
- data7,
- data8,
- err_msg='Incorrect combined operation',
+ def case_prox_linear_comp(self):
+ """Case prox linear comp."""
+ return (
+ proximity.LinearCompositionProx(
+ linear_op=linear.Identity(), prox_op=self.case_prox_sparsethresh()[0]
+ ),
+ self.array33,
+ self.array33_st,
+ 108,
+ )
+
+ def case_prox_ridge(self):
+ """Case prox ridge."""
+ return (
+ proximity.Ridge(linear.Identity(), self.weights),
+ self.array33 * (1 + 1j),
+ self.array33 * (1 + 1j) / 7,
+ 1224,
+ )
+
+ @parametrize("alpha, beta", [(0, weights), (weights, 0)])
+ def case_prox_elasticnet(self, alpha, beta):
+ """Case prox elastic net."""
+ if np.all(alpha == 0):
+ data = self.case_prox_sparsethresh()[1:]
+ else:
+ data = self.case_prox_ridge()[1:]
+ return (proximity.ElasticNet(linear.Identity(), alpha, beta), *data)
+
+ @parametrize(
+ "beta, k_value, data, result, cost",
+ [
+ (0.2, 1, array33.flatten(), array33_st.flatten(), 259.2),
+ (3, 5, array33.flatten(), array33_support.flatten(), 684.0),
+ (
+ 6.0,
+ 9,
+ array33.flatten() * (1 + 1j),
+ array33.flatten() * (1 + 1j) / 7,
+ 1224,
+ ),
+ ],
+ )
+ def case_prox_Ksupport(self, beta, k_value, data, result, cost):
+ """Case prox K-support norm."""
+ return (proximity.KSupportNorm(beta=beta, k_value=k_value), data, result, cost)
+
+ @parametrize(use_weights=[True, False])
+ def case_prox_grouplasso(self, use_weights):
+ """Case GroupLasso proximity."""
+ if use_weights:
+ weights = np.tile(self.weights, (4, 1, 1))
+ else:
+ weights = np.tile(np.zeros((3, 3)), (4, 1, 1))
+
+ random_data = 3 * np.random.random(weights[0].shape)
+ random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1))
+ if use_weights:
+ gl_result_data = 2 * random_data_tile - 3
+ gl_result_data = (
+ np.array(gl_result_data * (gl_result_data > 0).astype("int")) / 2
)
-
- npt.assert_equal(
- self.combo.cost(self.data6),
- 0,
- err_msg='Incorrect combined cost.',
- )
-
- npt.assert_raises(TypeError, proximity.ProximityCombo, 1)
-
- npt.assert_raises(ValueError, proximity.ProximityCombo, [])
-
- npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy])
-
- self.dummy.op = func_identity
-
- npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy])
-
- @skipIf(import_sklearn, 'sklearn is installed.') # pragma: no cover
- def test_owl_sklearn_error(self):
- """Test OrderedWeightedL1Norm with Scikit-Learn."""
- npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1)
-
- @skipUnless(import_sklearn, 'sklearn not installed.') # pragma: no cover
- def test_sparse_owl(self):
- """Test OrderedWeightedL1Norm."""
- npt.assert_array_equal(
- self.owl.op(self.data1.flatten()),
- self.data2.flatten(),
- err_msg='Incorrect sparse threshold operation.',
- )
-
- npt.assert_equal(
- self.owl.cost(self.data1.flatten(), verbose=True),
+ cost = np.sum(random_data_tile) * 6
+ else:
+ gl_result_data = random_data_tile
+ cost = 0
+ return (
+ proximity.GroupLASSO(
+ weights=weights,
+ ),
+ random_data_tile,
+ gl_result_data,
+ cost,
+ )
+
+ @pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn not available.")
+ def case_prox_owl(self):
+ """Case prox for Ordered Weighted L1 Norm."""
+ return (
+ proximity.OrderedWeightedL1Norm(self.weights.flatten()),
+ self.array33.flatten(),
+ self.array33_st.flatten(),
108.0,
- err_msg='Incorrect sparse threshold cost.',
)
- npt.assert_raises(
- ValueError,
- proximity.OrderedWeightedL1Norm,
- np.arange(10),
- )
- def test_ridge(self):
- """Test Ridge."""
- npt.assert_array_equal(
- self.ridge.op(self.data9),
- self.data10,
- err_msg='Incorect shrinkage operation.',
- )
+@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases)
+def test_prox_op(operator, input_data, op_result, cost_result):
+ """Test proximity operator op."""
+ npt.assert_almost_equal(operator.op(input_data), op_result)
- npt.assert_equal(
- self.ridge.cost(self.data9, verbose=True),
- 408.0 * 3.0,
- err_msg='Incorect shrinkage cost.',
- )
- def test_elastic_net_alpha0(self):
- """Test ElasticNet."""
- npt.assert_array_equal(
- self.elasticnet_alpha0.op(self.data1),
- self.data2,
- err_msg='Incorect sparse threshold operation ElasticNet class.',
- )
+@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases)
+def test_prox_cost(operator, input_data, op_result, cost_result):
+ """Test proximity operator cost."""
+ npt.assert_almost_equal(operator.cost(input_data, verbose=True), cost_result)
- npt.assert_equal(
- self.elasticnet_alpha0.cost(self.data1),
- 108.0,
- err_msg='Incorect shrinkage cost in ElasticNet class.',
- )
- def test_elastic_net_beta0(self):
- """Test ElasticNet with beta=0."""
- npt.assert_array_equal(
- self.elasticnet_beta0.op(self.data9),
- self.data10,
- err_msg='Incorect ridge operation ElasticNet class.',
- )
+@parametrize(
+ "arg, error",
+ [
+ (1, TypeError),
+ ([], ValueError),
+ ([Dummy()], ValueError),
+ ([dummy_with_op], ValueError),
+ ],
+)
+def test_error_prox_combo(arg, error):
+ """Test errors for proximity combo."""
+ npt.assert_raises(error, proximity.ProximityCombo, arg)
- npt.assert_equal(
- self.elasticnet_beta0.cost(self.data9, verbose=True),
- 408.0 * 3.0,
- err_msg='Incorect shrinkage cost in ElasticNet class.',
- )
- def test_one_support_norm(self):
- """Test KSupportNorm with k=1."""
- npt.assert_allclose(
- self.one_support.op(self.data1.flatten()),
- self.data2.flatten(),
- err_msg='Incorect sparse threshold operation for 1-support norm',
- rtol=1e-6,
- )
-
- npt.assert_equal(
- self.one_support.cost(self.data1.flatten(), verbose=True),
- 259.2,
- err_msg='Incorect sparse threshold cost.',
- )
+@pytest.mark.skipif(SKLEARN_AVAILABLE, reason="sklearn is installed")
+def test_fail_sklearn():
+ """Test fail OWL with sklearn."""
+ npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1)
- npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0)
- def test_five_support_norm(self):
- """Test KSupportNorm with k=5."""
- npt.assert_allclose(
- self.five_support_norm.op(self.data1.flatten()),
- self.data11.flatten(),
- err_msg='Incorect sparse Ksupport norm operation',
- rtol=1e-6,
- )
+@pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn is not installed.")
+def test_fail_owl():
+ """Test errors for Ordered Weighted L1 Norm."""
+ npt.assert_raises(
+ ValueError,
+ proximity.OrderedWeightedL1Norm,
+ np.arange(10),
+ )
- npt.assert_equal(
- self.five_support_norm.cost(self.data1.flatten(), verbose=True),
- 684.0,
- err_msg='Incorrect 5-support norm cost.',
- )
+ npt.assert_raises(
+ ValueError,
+ proximity.OrderedWeightedL1Norm,
+ -np.arange(10),
+ )
- npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0)
- def test_d_support_norm(self):
- """Test KSupportNorm with k=19."""
- npt.assert_allclose(
- self.d_support.op(self.data9.flatten()),
- self.data10.flatten(),
- err_msg='Incorect shrinkage operation for d-support norm',
- rtol=1e-6,
- )
+def test_fail_lowrank():
+ """Test fail for lowrank."""
+ prox_op = proximity.LowRankMatrix(10, lowr_type="fail")
+ npt.assert_raises(ValueError, prox_op.op, 0)
- npt.assert_almost_equal(
- self.d_support.cost(self.data9.flatten(), verbose=True),
- 408.0 * 3.0,
- err_msg='Incorrect shrinkage cost for d-support norm.',
- )
- npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0)
+def test_fail_Ksupport_norm():
+ """Test fail for K-support norm."""
+ npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0)
- def test_group_lasso(self):
- """Test GroupLASSO."""
- npt.assert_allclose(
- self.group_lasso.op(self.random_data_tile),
- self.gl_result_data,
- )
- npt.assert_equal(
- self.group_lasso.cost(self.random_data_tile),
- np.sum(6 * self.random_data_tile),
- )
- # Check that for 0 weights operator doesnt change result
- self.group_lasso.weights = np.zeros_like(self.group_lasso.weights)
- npt.assert_equal(
- self.group_lasso.op(self.random_data_tile),
- self.random_data_tile,
- )
- npt.assert_equal(self.group_lasso.cost(self.random_data_tile), 0)
+def test_reweight():
+ """Test for reweight module."""
+ data1 = np.arange(9).reshape(3, 3).astype(float) + 1
+ data2 = np.array(
+ [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]],
+ )
-class ReweightTestCase(TestCase):
- """Test case for reweight module."""
+ rw = reweight.cwbReweight(data1)
+ rw.reweight(data1)
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float) + 1
- self.data2 = np.array(
- [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]],
- )
- self.rw = reweight.cwbReweight(self.data1)
- self.rw.reweight(self.data1)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.rw = None
-
- def test_cwbreweight(self):
- """Test cwbReweight."""
- npt.assert_array_equal(
- self.rw.weights,
- self.data2,
- err_msg='Incorrect CWB re-weighting.',
- )
+ npt.assert_array_equal(
+ rw.weights,
+ data2,
+ err_msg="Incorrect CWB re-weighting.",
+ )
- npt.assert_raises(ValueError, self.rw.reweight, self.data1[0])
+ npt.assert_raises(ValueError, rw.reweight, data1[0])
diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py
index 7490b98c..202e541b 100644
--- a/modopt/tests/test_signal.py
+++ b/modopt/tests/test_signal.py
@@ -1,322 +1,240 @@
-# -*- coding: utf-8 -*-
-
"""UNIT TESTS FOR SIGNAL.
This module contains unit tests for the modopt.signal module.
-:Author: Samuel Farrens
-
+:Authors:
+ Samuel Farrens
+ Pierre-Antoine Comby
"""
-from unittest import TestCase
-
import numpy as np
import numpy.testing as npt
+import pytest
+from test_helpers import failparam
from modopt.signal import filter, noise, positivity, svd, validation, wavelet
-class FilterTestCase(TestCase):
- """Test case for filter module."""
-
- def test_guassian_filter(self):
- """Test guassian_filter."""
- npt.assert_almost_equal(
- filter.gaussian_filter(1, 1),
- 0.24197072451914337,
- err_msg='Incorrect Gaussian filter',
- )
+class TestFilter:
+ """Test filter module"""
+ @pytest.mark.parametrize(
+ ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)]
+ )
+ def test_gaussian_filter(self, norm, result):
+ """Test gaussian filter."""
+ npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result)
- npt.assert_almost_equal(
- filter.gaussian_filter(1, 1, norm=False),
- 0.60653065971263342,
- err_msg='Incorrect Gaussian filter',
- )
def test_mex_hat(self):
- """Test mex_hat."""
+ """Test mexican hat filter."""
npt.assert_almost_equal(
filter.mex_hat(2, 1),
-0.35213905225713371,
- err_msg='Incorrect Mexican hat filter',
)
+
def test_mex_hat_dir(self):
- """Test mex_hat_dir."""
+ """Test directional mexican hat filter."""
npt.assert_almost_equal(
filter.mex_hat_dir(1, 2, 1),
0.17606952612856686,
- err_msg='Incorrect directional Mexican hat filter',
)
-class NoiseTestCase(TestCase):
- """Test case for noise module."""
+class TestNoise:
+ """Test noise module."""
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = np.array(
- [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]],
- )
- self.data3 = np.array([
+ data1 = np.arange(9).reshape(3, 3).astype(float)
+ data2 = np.array(
+ [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]],
+ )
+ data3 = np.array(
+ [
[1.62434536, 0.38824359, 1.47182825],
[1.92703138, 4.86540763, 2.6984613],
[7.74481176, 6.2387931, 8.3190391],
- ])
- self.data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
- self.data5 = np.array(
- [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
- )
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.data5 = None
-
- def test_add_noise_poisson(self):
- """Test add_noise with Poisson noise."""
- np.random.seed(1)
- npt.assert_array_equal(
- noise.add_noise(self.data1, noise_type='poisson'),
- self.data2,
- err_msg='Incorrect noise: Poisson',
- )
-
- npt.assert_raises(
- ValueError,
- noise.add_noise,
- self.data1,
- noise_type='bla',
- )
-
- npt.assert_raises(ValueError, noise.add_noise, self.data1, (1, 1))
-
- def test_add_noise_gaussian(self):
- """Test add_noise with Gaussian noise."""
- np.random.seed(1)
- npt.assert_almost_equal(
- noise.add_noise(self.data1),
- self.data3,
- err_msg='Incorrect noise: Gaussian',
- )
-
+ ]
+ )
+ data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
+ data5 = np.array(
+ [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
+ )
+
+ @pytest.mark.parametrize(
+ ("data", "noise_type", "sigma", "data_noise"),
+ [
+ (data1, "poisson", 1, data2),
+ (data1, "gauss", 1, data3),
+ (data1, "gauss", (1, 1, 1), data3),
+ failparam(data1, "fail", 1, data1, raises=ValueError),
+ ],
+ )
+ def test_add_noise(self, data, noise_type, sigma, data_noise):
+ """Test add_noise."""
np.random.seed(1)
npt.assert_almost_equal(
- noise.add_noise(self.data1, sigma=(1, 1, 1)),
- self.data3,
- err_msg='Incorrect noise: Gaussian',
- )
-
- def test_thresh_hard(self):
- """Test thresh with hard threshold."""
- npt.assert_array_equal(
- noise.thresh(self.data1, 5),
- self.data4,
- err_msg='Incorrect threshold: hard',
- )
-
- npt.assert_raises(
- ValueError,
- noise.thresh,
- self.data1,
- 5,
- threshold_type='bla',
+ noise.add_noise(data, sigma=sigma, noise_type=noise_type), data_noise
)
- def test_thresh_soft(self):
- """Test thresh with soft threshold."""
+ @pytest.mark.parametrize(
+ ("threshold_type", "result"),
+ [("hard", data4), ("soft", data5), failparam("fail", None, raises=ValueError)],
+ )
+ def test_thresh(self, threshold_type, result):
+ """Test threshold."""
npt.assert_array_equal(
- noise.thresh(self.data1, 5, threshold_type='soft'),
- self.data5,
- err_msg='Incorrect threshold: soft',
+ noise.thresh(self.data1, 5, threshold_type=threshold_type), result
)
-
-class PositivityTestCase(TestCase):
- """Test case for positivity module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3) - 5
- self.data2 = np.array([[0, 0, 0], [0, 0, 0], [1, 2, 3]])
- self.data3 = np.array(
- [np.arange(5) - 3, np.arange(4) - 2],
- dtype=object,
- )
- self.data4 = np.array(
- [np.array([0, 0, 0, 0, 1]), np.array([0, 0, 0, 1])],
+class TestPositivity:
+ """Test positivity module."""
+ data1 = np.arange(9).reshape(3, 3).astype(float)
+ data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
+ data5 = np.array(
+ [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
+ )
+ @pytest.mark.parametrize(
+ ("value", "expected"),
+ [
+ (-1.0, -float(0)),
+ (-1, 0),
+ (data1 - 5, data5),
+ (
+ np.array([np.arange(3) - 1, np.arange(2) - 1], dtype=object),
+ np.array([np.array([0, 0, 1]), np.array([0, 0])], dtype=object),
+ ),
+ failparam("-1", None, raises=TypeError),
+ ],
+ )
+ def test_positive(self, value, expected):
+ """Test positive."""
+ if isinstance(value, np.ndarray) and value.dtype == "O":
+ for v, e in zip(positivity.positive(value), expected):
+ npt.assert_array_equal(v, e)
+ else:
+ npt.assert_array_equal(positivity.positive(value), expected)
+
+
+class TestSVD:
+ """Test for svd module."""
+
+ @pytest.fixture
+ def data(self):
+ """Initialize test data."""
+ data1 = np.arange(18).reshape(9, 2).astype(float)
+ data2 = np.arange(32).reshape(16, 2).astype(float)
+ data3 = np.array(
+ [
+ np.array(
+ [
+ [-0.01744594, -0.61438865],
+ [-0.08435304, -0.50397984],
+ [-0.15126014, -0.39357102],
+ [-0.21816724, -0.28316221],
+ [-0.28507434, -0.17275339],
+ [-0.35198144, -0.06234457],
+ [-0.41888854, 0.04806424],
+ [-0.48579564, 0.15847306],
+ [-0.55270274, 0.26888188],
+ ]
+ ),
+ np.array([42.23492742, 1.10041151]),
+ np.array(
+ [
+ [-0.67608034, -0.73682791],
+ [0.73682791, -0.67608034],
+ ]
+ ),
+ ],
dtype=object,
)
- self.pos_dtype_obj = positivity.positive(self.data3)
- self.err = 'Incorrect positivity'
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
-
- def test_positivity(self):
- """Test positivity."""
- npt.assert_equal(positivity.positive(-1), 0, err_msg=self.err)
-
- npt.assert_equal(
- positivity.positive(-1.0),
- -float(0),
- err_msg=self.err,
+ data4 = np.array(
+ [
+ [-1.05426832e-16, 1.0],
+ [2.0, 3.0],
+ [4.0, 5.0],
+ [6.0, 7.0],
+ [8.0, 9.0],
+ [1.0e1, 1.1e1],
+ [1.2e1, 1.3e1],
+ [1.4e1, 1.5e1],
+ [1.6e1, 1.7e1],
+ ]
)
- npt.assert_equal(
- positivity.positive(self.data1),
- self.data2,
- err_msg=self.err,
+ data5 = np.array(
+ [
+ [0.49815487, 0.54291537],
+ [2.40863386, 2.62505584],
+ [4.31911286, 4.70719631],
+ [6.22959185, 6.78933678],
+ [8.14007085, 8.87147725],
+ [10.05054985, 10.95361772],
+ [11.96102884, 13.03575819],
+ [13.87150784, 15.11789866],
+ [15.78198684, 17.20003913],
+ ]
)
+ return (data1, data2, data3, data4, data5)
- for expected, output in zip(self.data4, self.pos_dtype_obj):
- print(expected, output)
- npt.assert_array_equal(expected, output, err_msg=self.err)
+ @pytest.fixture
+ def svd0(self, data):
+ """Compute SVD of first data sample."""
+ return svd.calculate_svd(data[0])
- npt.assert_raises(TypeError, positivity.positive, '-1')
-
-
-class SVDTestCase(TestCase):
- """Test case for svd module."""
-
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(18).reshape(9, 2).astype(float)
- self.data2 = np.arange(32).reshape(16, 2).astype(float)
- self.data3 = np.array(
- [
- np.array([
- [-0.01744594, -0.61438865],
- [-0.08435304, -0.50397984],
- [-0.15126014, -0.39357102],
- [-0.21816724, -0.28316221],
- [-0.28507434, -0.17275339],
- [-0.35198144, -0.06234457],
- [-0.41888854, 0.04806424],
- [-0.48579564, 0.15847306],
- [-0.55270274, 0.26888188],
- ]),
- np.array([42.23492742, 1.10041151]),
- np.array([
- [-0.67608034, -0.73682791],
- [0.73682791, -0.67608034],
- ]),
- ],
- dtype=object,
- )
- self.data4 = np.array([
- [-1.05426832e-16, 1.0],
- [2.0, 3.0],
- [4.0, 5.0],
- [6.0, 7.0],
- [8.0, 9.0],
- [1.0e1, 1.1e1],
- [1.2e1, 1.3e1],
- [1.4e1, 1.5e1],
- [1.6e1, 1.7e1],
- ])
- self.data5 = np.array([
- [0.49815487, 0.54291537],
- [2.40863386, 2.62505584],
- [4.31911286, 4.70719631],
- [6.22959185, 6.78933678],
- [8.14007085, 8.87147725],
- [10.05054985, 10.95361772],
- [11.96102884, 13.03575819],
- [13.87150784, 15.11789866],
- [15.78198684, 17.20003913],
- ])
- self.svd = svd.calculate_svd(self.data1)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.svd = None
-
- def test_find_n_pc(self):
- """Test find_n_pc."""
+ def test_find_n_pc(self, data):
+ """Test find number of principal component."""
npt.assert_equal(
- svd.find_n_pc(svd.svd(self.data2)[0]),
+ svd.find_n_pc(svd.svd(data[1])[0]),
2,
- err_msg='Incorrect number of principal components.',
+ err_msg="Incorrect number of principal components.",
)
+ def test_n_pc_fail_non_square(self):
+ """Test find_n_pc."""
npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3))
- def test_calculate_svd(self):
+ def test_calculate_svd(self, data, svd0):
"""Test calculate_svd."""
+ errors = []
+ for i, name in enumerate("USV"):
+ try:
+ npt.assert_almost_equal(svd0[i], data[2][i])
+ except AssertionError:
+ errors.append(name)
+ if errors:
+ raise AssertionError("Incorrect SVD calculation for: " + ", ".join(errors))
+
+ @pytest.mark.parametrize(
+ ("n_pc", "idx_res"),
+ [(None, 3), (1, 4), ("all", 0), failparam("fail", 1, raises=ValueError)],
+ )
+ def test_svd_thresh(self, data, n_pc, idx_res):
+ """Test svd_tresh."""
npt.assert_almost_equal(
- self.svd[0],
- np.array(self.data3)[0],
- err_msg='Incorrect SVD calculation: U',
- )
-
- npt.assert_almost_equal(
- self.svd[1],
- np.array(self.data3)[1],
- err_msg='Incorrect SVD calculation: S',
- )
-
- npt.assert_almost_equal(
- self.svd[2],
- np.array(self.data3)[2],
- err_msg='Incorrect SVD calculation: V',
- )
-
- def test_svd_thresh(self):
- """Test svd_thresh."""
- npt.assert_almost_equal(
- svd.svd_thresh(self.data1),
- self.data4,
- err_msg='Incorrect SVD tresholding',
- )
-
- npt.assert_almost_equal(
- svd.svd_thresh(self.data1, n_pc=1),
- self.data5,
- err_msg='Incorrect SVD tresholding',
- )
-
- npt.assert_almost_equal(
- svd.svd_thresh(self.data1, n_pc='all'),
- self.data1,
- err_msg='Incorrect SVD tresholding',
+ svd.svd_thresh(data[0], n_pc=n_pc),
+ data[idx_res],
)
+ def test_svd_tresh_invalid_type(self):
+ """Test svd_tresh failure."""
npt.assert_raises(TypeError, svd.svd_thresh, 1)
- npt.assert_raises(ValueError, svd.svd_thresh, self.data1, n_pc='bla')
-
- def test_svd_thresh_coef(self):
- """Test svd_thresh_coef."""
+ @pytest.mark.parametrize("operator", [lambda x: x, failparam(0, raises=TypeError)])
+ def test_svd_thresh_coef(self, data, operator):
+ """Test svd_tresh_coef."""
npt.assert_almost_equal(
- svd.svd_thresh_coef(self.data1, lambda x_val: x_val, 0),
- self.data1,
- err_msg='Incorrect SVD coefficient tresholding',
+ svd.svd_thresh_coef(data[0], operator, 0),
+ data[0],
+ err_msg="Incorrect SVD coefficient tresholding",
)
- npt.assert_raises(TypeError, svd.svd_thresh_coef, self.data1, 0, 0)
-
+ # TODO test_svd_thresh_coef_fast
-class ValidationTestCase(TestCase):
- """Test case for validation module."""
+class TestValidation:
+ """Test validation Module."""
- def setUp(self):
- """Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
+ array33 = np.arange(9).reshape(3, 3)
def test_transpose_test(self):
"""Test transpose_test."""
@@ -325,90 +243,81 @@ def test_transpose_test(self):
validation.transpose_test(
lambda x_val, y_val: x_val.dot(y_val),
lambda x_val, y_val: x_val.dot(y_val.T),
- self.data1.shape,
- x_args=self.data1,
+ self.array33.shape,
+ x_args=self.array33,
),
None,
)
- npt.assert_raises(
- TypeError,
- validation.transpose_test,
- 0,
- 0,
- self.data1.shape,
- x_args=self.data1,
- )
-
-class WaveletTestCase(TestCase):
- """Test case for wavelet module."""
+class TestWavelet:
+ """Test Wavelet Module."""
- def setUp(self):
+ @pytest.fixture
+ def data(self):
"""Set test parameter values."""
- self.data1 = np.arange(9).reshape(3, 3).astype(float)
- self.data2 = np.arange(36).reshape(4, 3, 3).astype(float)
- self.data3 = np.array([
- [
- [6.0, 20, 26.0],
- [36.0, 84.0, 84.0],
- [90, 164.0, 134.0],
- ],
+ data1 = np.arange(9).reshape(3, 3).astype(float)
+ data2 = np.arange(36).reshape(4, 3, 3).astype(float)
+ data3 = np.array(
[
- [78.0, 155.0, 134.0],
- [225.0, 408.0, 327.0],
- [270, 461.0, 350],
- ],
+ [
+ [6.0, 20, 26.0],
+ [36.0, 84.0, 84.0],
+ [90, 164.0, 134.0],
+ ],
+ [
+ [78.0, 155.0, 134.0],
+ [225.0, 408.0, 327.0],
+ [270, 461.0, 350],
+ ],
+ [
+ [150, 290, 242.0],
+ [414.0, 732.0, 570],
+ [450, 758.0, 566.0],
+ ],
+ [
+ [222.0, 425.0, 350],
+ [603.0, 1056.0, 813.0],
+ [630, 1055.0, 782.0],
+ ],
+ ]
+ )
+
+ data4 = np.array(
[
- [150, 290, 242.0],
- [414.0, 732.0, 570],
- [450, 758.0, 566.0],
- ],
+ [6496.0, 9796.0, 6544.0],
+ [9924.0, 14910, 9924.0],
+ [6544.0, 9796.0, 6496.0],
+ ]
+ )
+
+ data5 = np.array(
[
- [222.0, 425.0, 350],
- [603.0, 1056.0, 813.0],
- [630, 1055.0, 782.0],
- ],
- ])
-
- self.data4 = np.array([
- [6496.0, 9796.0, 6544.0],
- [9924.0, 14910, 9924.0],
- [6544.0, 9796.0, 6496.0],
- ])
-
- self.data5 = np.array([
- [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]],
- [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]],
- [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]],
- ])
-
- def tearDown(self):
- """Unset test parameter values."""
- self.data1 = None
- self.data2 = None
- self.data3 = None
- self.data4 = None
- self.data5 = None
-
- def test_filter_convolve(self):
- """Test filter_convolve."""
- npt.assert_almost_equal(
- wavelet.filter_convolve(self.data1, self.data2),
- self.data3,
- err_msg='Inccorect filter comvolution.',
+ [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]],
+ [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]],
+ [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]],
+ ]
)
+ return (data1, data2, data3, data4, data5)
+ @pytest.mark.parametrize(
+ ("idx_data", "idx_filter", "idx_res", "filter_rot"),
+ [(0, 1, 2, False), (1, 1, 3, True)],
+ )
+ def test_filter_convolve(self, data, idx_data, idx_filter, idx_res, filter_rot):
+ """Test filter_convolve."""
npt.assert_almost_equal(
- wavelet.filter_convolve(self.data2, self.data2, filter_rot=True),
- self.data4,
- err_msg='Inccorect filter comvolution.',
+ wavelet.filter_convolve(
+ data[idx_data], data[idx_filter], filter_rot=filter_rot
+ ),
+ data[idx_res],
+ err_msg="Inccorect filter comvolution.",
)
- def test_filter_convolve_stack(self):
+ def test_filter_convolve_stack(self, data):
"""Test filter_convolve_stack."""
npt.assert_almost_equal(
- wavelet.filter_convolve_stack(self.data1, self.data1),
- self.data5,
- err_msg='Inccorect filter stack comvolution.',
+ wavelet.filter_convolve_stack(data[0], data[0]),
+ data[4],
+ err_msg="Inccorect filter stack comvolution.",
)
diff --git a/requirements.txt b/requirements.txt
index 63a404ba..1f44de13 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
importlib_metadata>=3.7.0
numpy>=1.19.5
scipy>=1.5.4
-progressbar2>=3.53.1
+tqdm>=4.64.0
diff --git a/setup.cfg b/setup.cfg
index cabd35a0..100adb40 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -42,6 +42,8 @@ per-file-ignores =
modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410
#Todo: x is a too short name.
modopt/opt/algorithms/forward_backward.py: WPS111
+ #Todo: u,v , A is a too short name.
+ modopt/opt/algorithms/admm.py: WPS111, N803
#Todo: Check need for del statement
modopt/opt/algorithms/primal_dual.py: WPS111, WPS420
#multiline parameters bug with tuples
@@ -79,13 +81,17 @@ max-string-usages = 20
max-raises = 5
[tool:pytest]
+norecursedirs=tests/test_helpers
testpaths =
modopt
addopts =
--verbose
- --emoji
- --flake8
--cov=modopt
- --cov-report=term
+ --cov-report=term-missing
--cov-report=xml
--junitxml=pytest.xml
+ --pydocstyle
+
+[pydocstyle]
+convention=numpy
+add-ignore=D107
diff --git a/setup.py b/setup.py
index c93dd020..e6a8a9e6 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
# Set the package release version
major = 1
-minor = 6
+minor = 7
patch = 1
# Set the package details
@@ -20,7 +20,7 @@
license = 'MIT'
# Set the package classifiers
-python_versions_supported = ['3.6', '3.7', '3.8', '3.9']
+python_versions_supported = ['3.7', '3.8', '3.9', '3.10', '3.11']
os_platforms_supported = ['Unix', 'MacOS']
lc_str = 'License :: OSI Approved :: {0} License'
From 4dc7d10f7ba30bc046162eabbe3e50af892afaf8 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Sun, 4 Feb 2024 21:22:01 +0100
Subject: [PATCH 02/45] feat: add support for cupy in SparseThreshold.
Ideally we want to have such support everywhere.
---
modopt/opt/proximity.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py
index e8492367..fc81a753 100644
--- a/modopt/opt/proximity.py
+++ b/modopt/opt/proximity.py
@@ -22,6 +22,7 @@
else:
import_sklearn = True
+from modopt.base.backend import get_array_module
from modopt.base.transform import cube2matrix, matrix2cube
from modopt.base.types import check_callable
from modopt.interface.errors import warn
@@ -215,7 +216,10 @@ def _cost_method(self, *args, **kwargs):
Sparsity cost component
"""
- cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0])))
+ xp = get_array_module(args[0])
+ cost_val = xp.sum(xp.abs(self.weights * self._linear.op(args[0])))
+ if isinstance(cost_val, xp.ndarray):
+ cost_val = cost_val.item()
if 'verbose' in kwargs and kwargs['verbose']:
print(' - L1 NORM (X):', cost_val)
From 1961acc6346da0b15b1b6ed9544195d235ec5b6c Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 5 Feb 2024 09:54:48 +0100
Subject: [PATCH 03/45] feat: add cupy wavelet transform.
---
modopt/opt/linear/wavelet.py | 189 +++++++++++++++++++++++++++++++++++
1 file changed, 189 insertions(+)
diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py
index 6e22a2b0..6d72db0d 100644
--- a/modopt/opt/linear/wavelet.py
+++ b/modopt/opt/linear/wavelet.py
@@ -16,6 +16,14 @@
except ImportError:
pywt_available = False
+ptwt_available = True
+try:
+ import ptwt
+ import torch
+ import cupy as cp
+except:
+ ptwt_available = False
+
class WaveletConvolve(LinearParent):
"""Wavelet Convolution Class.
@@ -214,3 +222,184 @@ def _adj_op(self, coeffs):
wavelet=self.wavelet,
mode=self.mode,
)
+
+
+class TorchWaveletTransform:
+ """Wavelet transform using pytorch."""
+
+ wavedec3_keys = ["aad", "ada", "add", "daa", "dad", "dda", "ddd"]
+
+ def __init__(
+ self,
+ shape: tuple[int, ...],
+ wavelet: str,
+ level: int,
+ mode: str,
+ ):
+ self.wavelet = wavelet
+ self.level = level
+ self.shape = shape
+ self.mode = mode
+
+ def op(self, data: torch.Tensor) -> list[torch.Tensor]:
+ """Apply the wavelet decomposition on.
+
+ Parameters
+ ----------
+ data: torch.Tensor
+ 2D or 3D, real or complex data with last axes matching shape of
+ the operator.
+
+ Returns
+ -------
+ list[torch.Tensor]
+ list of tensor each containing the data of a subband.
+ """
+ if data.shape == self.shape:
+ data = data[None, ...] # add a batch dimension
+
+ if len(self.shape) == 2:
+ if torch.is_complex(data):
+ # 2D Complex
+ data_ = torch.view_as_real(data)
+ coeffs_ = ptwt.wavedec2(
+ data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2)
+ )
+ self.coeffs_shape = [coeffs_[0].shape]
+ self.coeffs_shape += [tuple(cc.shape for cc in c) for c in coeffs_]
+ # flatten list of tuple of tensors to a list of tensors
+ coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [
+ torch.view_as_complex(cc.contiguous())
+ for c in coeffs_[1:]
+ for cc in c
+ ]
+
+ return coeffs
+ # 2D Real
+ coeffs_ = ptwt.wavedec2(
+ data, self.wavelet, level=self.level, mode=self.mode, axes=(-2, -1)
+ )
+ return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c]
+
+ if torch.is_complex(data):
+ # 3D Complex
+ data_ = torch.view_as_real(data)
+ coeffs_ = ptwt.wavedec3(
+ data_,
+ self.wavelet,
+ level=self.level,
+ mode=self.mode,
+ axes=(-4, -3, -2),
+ )
+ # flatten list of tuple of tensors to a list of tensors
+ coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [
+ torch.view_as_complex(cc.contiguous())
+ for c in coeffs_[1:]
+ for cc in c.values()
+ ]
+
+ return coeffs
+ # 3D Real
+ coeffs_ = ptwt.wavedec3(
+ data, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2, -1)
+ )
+ return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()]
+
+ def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor:
+ """Apply the wavelet recomposition.
+
+ Parameters
+ ----------
+ coeffs: list[torch.Tensor]
+ list of tensor each containing the data of a subband.
+
+ Returns
+ -------
+ data: torch.Tensor
+ 2D or 3D, real or complex data with last axes matching shape of the
+ operator.
+
+ """
+ if len(self.shape) == 2:
+ if torch.is_complex(coeffs[0]):
+ ## 2D Complex ##
+ # list of tensor to list of tuple of tensor
+ coeffs = [torch.view_as_real(coeffs[0])] + [
+ tuple(torch.view_as_real(coeffs[i + k]) for k in range(3))
+ for i in range(1, len(coeffs) - 2, 3)
+ ]
+ data = ptwt.waverec2(coeffs, wavelet=self.wavelet, axes=(-3, -2))
+ return torch.view_as_complex(data.contiguous())
+ ## 2D Real ##
+ coeffs_ = [coeffs[0]] + [
+ tuple(coeffs[i + k] for k in range(3))
+ for i in range(1, len(coeffs) - 2, 3)
+ ]
+ data = ptwt.waverec2(coeffs_, wavelet=self.wavelet, axes=(-2, -1))
+ return data
+
+ if torch.is_complex(coeffs[0]):
+ ## 3D Complex ##
+ # list of tensor to list of tuple of tensor
+ coeffs = [torch.view_as_real(coeffs[0])] + [
+ {
+ v: torch.view_as_real(coeffs[i + k])
+ for k, v in enumerate(self.wavedec3_keys)
+ }
+ for i in range(1, len(coeffs) - 6, 7)
+ ]
+ data = ptwt.waverec3(coeffs, wavelet=self.wavelet, axes=(-4, -3, -2))
+ return torch.view_as_complex(data.contiguous())
+ ## 3D Real ##
+ coeffs_ = [coeffs[0]] + [
+ {v: coeffs[i + k] for k, v in enumerate(self.wavedec3_keys)}
+ for i in range(1, len(coeffs) - 6, 7)
+ ]
+ data = ptwt.waverec3(coeffs_, wavelet=self.wavelet, axes=(-3, -2, -1))
+ return data
+
+
+class CupyWaveletTransform:
+ """Wrapper around torch wavelet transform to be compatible with the Modopt API."""
+
+ def __init__(
+ self,
+ shape: tuple[int, ...],
+ wavelet: str,
+ level: int,
+ mode: str,
+ ):
+ self.wavelet = wavelet
+ self.level = level
+ self.shape = shape
+ self.mode = mode
+
+ self.operator = TorchWaveletTransform(shape, wavelet, level, mode)
+
+ def op(self, data: cp.array) -> cp.ndarray:
+ """Apply Forward Wavelet transform on cupy array."""
+ data_ = torch.as_tensor(data)
+ tensor_list = self.operator.op(data_)
+ # flatten the list of tensor to a cupy array
+ # this requires an on device copy...
+ self.coeffs_shape = [c.shape for c in tensor_list]
+ n_tot_coeffs = np.sum([np.prod(s) for s in self.coeffs_shape])
+ ret = cp.zeros(n_tot_coeffs, dtype=np.complex64) # FIXME get dtype from torch
+ start = 0
+ for t in tensor_list:
+ stop = start + np.prod(t.shape)
+ ret[start:stop] = cp.asarray(t.flatten())
+ start = stop
+
+ return ret
+
+ def adj_op(self, data: cp.ndarray) -> cp.ndarray:
+ """Apply Adjoint Wavelet transform on cupy array."""
+ start = 0
+ tensor_list = [None] * len(self.coeffs_shape)
+ for i, s in enumerate(self.coeffs_shape):
+ stop = start + np.prod(s)
+ tensor_list[i] = torch.as_tensor(data[start:stop].reshape(s), device="cuda")
+ start = stop
+ ret_tensor = self.operator.adj_op(tensor_list)
+ return cp.from_dlpack(ret_tensor)
From fc36139eb02fa3d777971b1101eaf8c1a6b937de Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 5 Feb 2024 10:30:33 +0100
Subject: [PATCH 04/45] feat: use compute_backend to dispatch.
---
modopt/opt/linear/wavelet.py | 77 ++++++++++++++++++++++++++++++++++--
modopt/tests/test_opt.py | 16 ++++++--
2 files changed, 86 insertions(+), 7 deletions(-)
diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py
index 6d72db0d..c497e59e 100644
--- a/modopt/opt/linear/wavelet.py
+++ b/modopt/opt/linear/wavelet.py
@@ -21,7 +21,7 @@
import ptwt
import torch
import cupy as cp
-except:
+except ImportError:
ptwt_available = False
@@ -62,10 +62,53 @@ def __init__(self, filters, method='scipy'):
+
class WaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
+ This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU using Pytorch).
+
+ Parameters
+ ----------
+ wavelet_name: str
+ the wavelet name to be used during the decomposition.
+ shape: tuple[int,...]
+ Shape of the input data. The shape should be a tuple of length 2 or 3.
+ It should not contain coils or batch dimensions.
+ nb_scales: int, default 4
+ the number of scales in the decomposition.
+ mode: str, default "zero"
+ Boundary Condition mode
+ compute_backend: str, "numpy" or "cupy", default "numpy"
+ Backend library to use. "cupy" also requires a working installation of PyTorch and pytorch wavelets.
+
+ **kwargs: extra kwargs for Pywavelet or Pytorch Wavelet
+ """
+ def __init__(self,
+ wavelet_name,
+ shape,
+ level=4,
+ mode="symmetric",
+ compute_backend="numpy",
+ **kwargs):
+
+ if compute_backend == "cupy" and ptwt_available:
+ self.operator = CupyWaveletTransform(wavelet_name, shape, level, mode)
+ elif compute_backend == "numpy" and pywt_available:
+ self.operator = CPUWaveletTransform(wavelet_name = wavelet_name,shape= shape,nb_scales=level, **kwargs)
+ else:
+ raise ValueError(f"Compute Backend {compute_backend} not available")
+
+
+ self.op = self.operator.op
+ self.adj_op = self.operator.adj_op
+
+
+class CPUWaveletTransform(LinearParent):
+ """
+ 2D and 3D wavelet transform class.
+
This is a light wrapper around PyWavelet, with multicoil support.
Parameters
@@ -359,7 +402,7 @@ def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor:
return data
-class CupyWaveletTransform:
+class CupyWaveletTransform(LinearParent):
"""Wrapper around torch wavelet transform to be compatible with the Modopt API."""
def __init__(
@@ -377,7 +420,20 @@ def __init__(
self.operator = TorchWaveletTransform(shape, wavelet, level, mode)
def op(self, data: cp.array) -> cp.ndarray:
- """Apply Forward Wavelet transform on cupy array."""
+ """Define the wavelet operator.
+
+ This method returns the input data convolved with the wavelet filter.
+
+ Parameters
+ ----------
+ data: cp.ndarray
+ input 2D data array.
+
+ Returns
+ -------
+ coeffs: ndarray
+ the wavelet coefficients.
+ """
data_ = torch.as_tensor(data)
tensor_list = self.operator.op(data_)
# flatten the list of tensor to a cupy array
@@ -394,7 +450,20 @@ def op(self, data: cp.array) -> cp.ndarray:
return ret
def adj_op(self, data: cp.ndarray) -> cp.ndarray:
- """Apply Adjoint Wavelet transform on cupy array."""
+ """Define the wavelet adjoint operator.
+
+ This method returns the reconstructed image.
+
+ Parameters
+ ----------
+ data: cp.ndarray
+ the wavelet coefficients.
+
+ Returns
+ -------
+ data: ndarray
+ the reconstructed data.
+ """
start = 0
tensor_list = [None] * len(self.coeffs_shape)
for i, s in enumerate(self.coeffs_shape):
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
index 4a82e33c..7c30186e 100644
--- a/modopt/tests/test_opt.py
+++ b/modopt/tests/test_opt.py
@@ -22,6 +22,13 @@
except ImportError:
SKLEARN_AVAILABLE = False
+PTWT_AVAILABLE = True
+try:
+ import ptwt
+ import cupy
+except ImportError:
+ PTWT_AVAILABLE = False
+
PYWT_AVAILABLE = True
try:
import pywt
@@ -174,8 +181,12 @@ def case_linear_wavelet_convolve(self):
return linop, data_op, data_adj_op, res_op, res_adj_op
- @pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")
- def case_linear_wavelet_transform(self):
+ @parametrize(
+ compute_backend=[
+ pytest.param("numpy", marks=pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")),
+ pytest.param("cupy", marks=pytest.mark.skipif(not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."))
+ ])
+ def case_linear_wavelet_transform(self, compute_backend="numpy"):
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
@@ -298,7 +309,6 @@ class ProxCases:
[11.67394789, 12.87497954, 14.07601119],
[15.27704284, 16.47807449, 17.67910614],
],
- ]
)
array233_3 = np.array(
[
From 62b519a009c2bea4cc35e1e6f036e1e12e77e863 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 5 Feb 2024 11:04:37 +0100
Subject: [PATCH 05/45] fix: pass parameters by name
---
modopt/opt/linear/wavelet.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py
index c497e59e..cb75ecff 100644
--- a/modopt/opt/linear/wavelet.py
+++ b/modopt/opt/linear/wavelet.py
@@ -94,9 +94,9 @@ def __init__(self,
**kwargs):
if compute_backend == "cupy" and ptwt_available:
- self.operator = CupyWaveletTransform(wavelet_name, shape, level, mode)
+ self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode)
elif compute_backend == "numpy" and pywt_available:
- self.operator = CPUWaveletTransform(wavelet_name = wavelet_name,shape= shape,nb_scales=level, **kwargs)
+ self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, nb_scales=level, **kwargs)
else:
raise ValueError(f"Compute Backend {compute_backend} not available")
@@ -417,7 +417,7 @@ def __init__(
self.shape = shape
self.mode = mode
- self.operator = TorchWaveletTransform(shape, wavelet, level, mode)
+ self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode)
def op(self, data: cp.array) -> cp.ndarray:
"""Define the wavelet operator.
From 80159e4855c4fdb72893c30e49af412cd5e383cb Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 9 Feb 2024 18:12:08 +0100
Subject: [PATCH 06/45] fix: provide a coeffs shape property.
---
modopt/opt/linear/wavelet.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py
index cb75ecff..3fa2d60d 100644
--- a/modopt/opt/linear/wavelet.py
+++ b/modopt/opt/linear/wavelet.py
@@ -104,6 +104,9 @@ def __init__(self,
self.op = self.operator.op
self.adj_op = self.operator.adj_op
+ @property
+ def coeffs_shape(self):
+ return self.operator.coeffs_shape
class CPUWaveletTransform(LinearParent):
"""
@@ -283,6 +286,7 @@ def __init__(
self.level = level
self.shape = shape
self.mode = mode
+ self.coeffs_shape = None # will be set after op.
def op(self, data: torch.Tensor) -> list[torch.Tensor]:
"""Apply the wavelet decomposition on.
@@ -308,8 +312,6 @@ def op(self, data: torch.Tensor) -> list[torch.Tensor]:
coeffs_ = ptwt.wavedec2(
data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2)
)
- self.coeffs_shape = [coeffs_[0].shape]
- self.coeffs_shape += [tuple(cc.shape for cc in c) for c in coeffs_]
# flatten list of tuple of tensors to a list of tensors
coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [
torch.view_as_complex(cc.contiguous())
@@ -418,6 +420,7 @@ def __init__(
self.mode = mode
self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode)
+ self.coeffs_shape = None # will be set after op
def op(self, data: cp.array) -> cp.ndarray:
"""Define the wavelet operator.
From bf9367470a204a820d9d9913947b9201e71601f6 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Tue, 13 Feb 2024 18:28:56 +0100
Subject: [PATCH 07/45] fix: update name.
---
modopt/opt/linear/wavelet.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py
index 3fa2d60d..5feead66 100644
--- a/modopt/opt/linear/wavelet.py
+++ b/modopt/opt/linear/wavelet.py
@@ -96,7 +96,7 @@ def __init__(self,
if compute_backend == "cupy" and ptwt_available:
self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode)
elif compute_backend == "numpy" and pywt_available:
- self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, nb_scales=level, **kwargs)
+ self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, level=level, **kwargs)
else:
raise ValueError(f"Compute Backend {compute_backend} not available")
From 3385679bbb26c0ed92452d39652c54e6f315188d Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Tue, 13 Feb 2024 18:31:48 +0100
Subject: [PATCH 08/45] refactor: remove add_args_kwargs.
This is 1) expensive 2) dangerous 3) source of bug.
---
modopt/base/types.py | 6 -----
modopt/base/wrappers.py | 49 -----------------------------------------
2 files changed, 55 deletions(-)
delete mode 100644 modopt/base/wrappers.py
diff --git a/modopt/base/types.py b/modopt/base/types.py
index 88051675..e9212e12 100644
--- a/modopt/base/types.py
+++ b/modopt/base/types.py
@@ -9,8 +9,6 @@
"""
import numpy as np
-
-from modopt.base.wrappers import add_args_kwargs
from modopt.interface.errors import warn
@@ -45,10 +43,6 @@ def check_callable(input_obj, add_agrs=True):
"""
if not callable(input_obj):
raise TypeError('The input object must be a callable function.')
-
- if add_agrs:
- input_obj = add_args_kwargs(input_obj)
-
return input_obj
diff --git a/modopt/base/wrappers.py b/modopt/base/wrappers.py
deleted file mode 100644
index baedb891..00000000
--- a/modopt/base/wrappers.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""WRAPPERS.
-
-This module contains wrappers for adding additional features to functions.
-
-:Author: Samuel Farrens
-
-"""
-
-from functools import wraps
-from inspect import getfullargspec as argspec
-
-
-def add_args_kwargs(func):
- """Add args and kwargs.
-
- This wrapper adds support for additional arguments and keyword arguments to
- any callable function.
-
- Parameters
- ----------
- func : callable
- Callable function
-
- Returns
- -------
- callable
- wrapper
-
- """
- @wraps(func)
- def wrapper(*args, **kwargs):
-
- props = argspec(func)
-
- # if 'args' not in props:
- if isinstance(props[1], type(None)):
- args = args[:len(props[0])]
-
- if (
- (not isinstance(props[2], type(None)))
- or (not isinstance(props[3], type(None)))
- ):
- return func(*args, **kwargs)
-
- return func(*args)
-
- return wrapper
From 1bf9617f946f8b85cad904d0e07004084078ad62 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Wed, 14 Feb 2024 14:04:18 +0100
Subject: [PATCH 09/45] remove wrapper module.
---
modopt/base/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modopt/base/__init__.py b/modopt/base/__init__.py
index 1c0c8b2c..88424bae 100644
--- a/modopt/base/__init__.py
+++ b/modopt/base/__init__.py
@@ -9,4 +9,4 @@
"""
-__all__ = ['np_adjust', 'transform', 'types', 'wrappers', 'observable']
+__all__ = ['np_adjust', 'transform', 'types', 'observable']
From 557023e20ea4cb1cc084f1a65c2652e5b4ceed13 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 17:56:46 +0100
Subject: [PATCH 10/45] cleanup
---
modopt/base/types.py | 14 +-------------
1 file changed, 1 insertion(+), 13 deletions(-)
diff --git a/modopt/base/types.py b/modopt/base/types.py
index e9212e12..16e06f15 100644
--- a/modopt/base/types.py
+++ b/modopt/base/types.py
@@ -12,7 +12,7 @@
from modopt.interface.errors import warn
-def check_callable(input_obj, add_agrs=True):
+def check_callable(input_obj):
"""Check input object is callable.
This method checks if the input operator is a callable funciton and
@@ -23,23 +23,11 @@ def check_callable(input_obj, add_agrs=True):
----------
input_obj : callable
Callable function
- add_agrs : bool, optional
- Option to add support for agrs and kwargs (default is ``True``)
-
- Returns
- -------
- function
- Function wrapped by ``add_args_kwargs``
Raises
------
TypeError
For invalid input type
-
- See Also
- --------
- modopt.base.wrappers.add_args_kwargs : wrapper used
-
"""
if not callable(input_obj):
raise TypeError('The input object must be a callable function.')
From 753499946b4383c0b4041ac1e0e51d38cab7f0d9 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Wed, 14 Feb 2024 18:59:32 +0100
Subject: [PATCH 11/45] use a pyproject.toml file.
---
.github/workflows/ci-build.yml | 2 +-
.pylintrc | 2 -
.pyup.yml | 14 -----
MANIFEST.in | 5 --
develop.txt | 12 -----
pyproject.toml | 56 ++++++++++++++++++++
requirements.txt | 4 --
setup.cfg | 97 ----------------------------------
setup.py | 73 -------------------------
9 files changed, 57 insertions(+), 208 deletions(-)
delete mode 100644 .pylintrc
delete mode 100644 .pyup.yml
delete mode 100644 MANIFEST.in
delete mode 100644 develop.txt
create mode 100644 pyproject.toml
delete mode 100644 requirements.txt
delete mode 100644 setup.cfg
delete mode 100644 setup.py
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index c4ba28a0..88129d45 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -44,7 +44,7 @@ jobs:
python -m pip install -r develop.txt
python -m pip install -r docs/requirements.txt
python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
- python -m pip install tensorflow>=2.4.1
+ python -m pip install tensorflow>=2.4.1 torch
python -m pip install twine
python -m pip install .
diff --git a/.pylintrc b/.pylintrc
deleted file mode 100644
index 3ac9aef9..00000000
--- a/.pylintrc
+++ /dev/null
@@ -1,2 +0,0 @@
-[MASTER]
-ignore-patterns=**/docs/**/*.py
diff --git a/.pyup.yml b/.pyup.yml
deleted file mode 100644
index 8fdac7ff..00000000
--- a/.pyup.yml
+++ /dev/null
@@ -1,14 +0,0 @@
-# autogenerated pyup.io config file
-# see https://pyup.io/docs/configuration/ for all available options
-
-schedule: ''
-update: all
-label_prs: update
-assignees: sfarrens
-requirements:
- - requirements.txt:
- pin: False
- - develop.txt:
- pin: False
- - docs/requirements.txt:
- pin: True
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 9a2f374e..00000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,5 +0,0 @@
-include requirements.txt
-include develop.txt
-include docs/requirements.txt
-include README.rst
-include LICENSE.txt
diff --git a/develop.txt b/develop.txt
deleted file mode 100644
index 6ff665eb..00000000
--- a/develop.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-coverage>=5.5
-pytest>=6.2.2
-pytest-raises>=0.10
-pytest-cases>= 3.6
-pytest-xdist>= 3.0.1
-pytest-cov>=2.11.1
-pytest-emoji>=0.2.0
-pydocstyle==6.1.1
-pytest-pydocstyle>=2.2.0
-black
-isort
-pytest-black
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..71bdce82
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,56 @@
+[project]
+name="modopt"
+description = 'Modular Optimisation tools for solving inverse problems.'
+version = "1.7.1"
+requires-python= ">=3.8"
+
+authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"},
+{name="Chaithya GR", email="chaithyagr@gmail.com"},
+{name="Pierre-Antoine Comby", email="pierre-antoine.comby@cea.fr"}
+]
+readme="README.md"
+license={file="LICENCE.txt"}
+
+dependencies = ["numpy", "scipy", "tqdm"]
+
+[project.optional-dependencies]
+gpu=["torch", "ptwt"]
+doc=["myst-parser==0.16.1",
+"nbsphinx==0.8.7",
+"nbsphinx-link==1.3.0",
+"sphinx-gallery==0.11.1",
+"sphinxawesome-theme==3.2.1",
+"sphinxcontrib-bibtex"]
+dev=["black", "pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar", "ruff"]
+
+[build-system]
+requires=["setuptools", "setuptools-scm[toml]", "wheel"]
+
+[tool.setuptools]
+packages=["modopt"]
+
+[tool.coverage.run]
+omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"]
+
+[tool.coverage.report]
+precision = 2
+exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
+
+[tool.black]
+
+[tool.ruff]
+
+src=["modopt"]
+select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
+
+[tool.ruff.pydocstyle]
+convention="numpy"
+
+[tool.isort]
+profile="black"
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+norecursedirs = ["tests/helpers"]
+testpaths=["modopt"]
+addopts = ["--verbose", "--cov=modopt", "--cov-report=term-missing", "--cov-report=xml", "--junitxml=pytest.xml"]
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 1f44de13..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-importlib_metadata>=3.7.0
-numpy>=1.19.5
-scipy>=1.5.4
-tqdm>=4.64.0
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 100adb40..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,97 +0,0 @@
-[aliases]
-test=pytest
-
-[metadata]
-description_file = README.rst
-
-[darglint]
-docstring_style = numpy
-strictness = short
-
-[flake8]
-ignore =
- D107, #Justification: Don't need docstring for __init__ in numpydoc style
- RST304, #Justification: Need to use :cite: role for citations
- RST210, #Justification: RST210, RST213 Inconsistent with numpydoc
- RST213, # documentation for handling *args and **kwargs
- W503, #Justification: Have to choose one multiline operator format
- WPS202, #Todo: Rethink module size, possibly split large modules
- WPS337, #Todo: Consider simplifying multiline conditions.
- WPS338, #Todo: Consider changing method order
- WPS403, #Todo: Rethink no cover lines
- WPS421, #Todo: Review need for print statements
- WPS432, #Justification: Mathematical codes require "magic numbers"
- WPS433, #Todo: Rethink conditional imports
- WPS463, #Todo: Rename get_ methods
- WPS615, #Todo: Rename get_ methods
-per-file-ignores =
- #Justification: Needed for keeping package version and current API
- *__init__.py*: F401,F403,WPS347,WPS410,WPS412
- #Todo: Rethink conditional imports
- #Todo: How can we bypass mutable constants?
- modopt/base/backend.py: WPS229, WPS420, WPS407
- #Todo: Rethink conditional imports
- modopt/base/observable.py: WPS420,WPS604
- #Todo: Check string for log formatting
- modopt/interface/log.py: WPS323
- #Todo: Rethink conditional imports
- modopt/math/convolve.py: WPS301,WPS420
- #Todo: Rethink conditional imports
- modopt/math/matrix.py: WPS420
- #Todo: import has bad parenthesis
- modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410
- #Todo: x is a too short name.
- modopt/opt/algorithms/forward_backward.py: WPS111
- #Todo: u,v , A is a too short name.
- modopt/opt/algorithms/admm.py: WPS111, N803
- #Todo: Check need for del statement
- modopt/opt/algorithms/primal_dual.py: WPS111, WPS420
- #multiline parameters bug with tuples
- modopt/opt/algorithms/gradient_descent.py: WPS111, WPS420, WPS317
- #Todo: Consider changing costObj name
- modopt/opt/cost.py: N801,
- #Todo:
- # - Rethink subscript slice assignment
- # - Reduce complexity of KSupportNorm
- # - Check bitwise operations
- modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508
- #Todo: Consider changing cwbReweight name
- modopt/opt/reweight.py: N801
- #Justification: Needed to import matplotlib.pyplot
- modopt/plot/cost_plot.py: N802,WPS301
- #Todo: Investigate possible bug in find_n_pc function
- #Todo: Investigate darglint error
- modopt/signal/svd.py: WPS345, DAR000
- #Todo: Check security of using system executable call
- modopt/signal/wavelet.py: S404,S603
- #Todo: Clean up tests
- modopt/tests/*.py: E731,F401,WPS301,WPS420,WPS425,WPS437,WPS604
- #Todo: Import has bad parenthesis
- modopt/tests/test_base.py: WPS318,WPS319,E501,WPS301
-#WPS Settings
-max-arguments = 25
-max-attributes = 40
-max-cognitive-score = 20
-max-function-expressions = 20
-max-line-complexity = 30
-max-local-variables = 10
-max-methods = 20
-max-module-expressions = 20
-max-string-usages = 20
-max-raises = 5
-
-[tool:pytest]
-norecursedirs=tests/test_helpers
-testpaths =
- modopt
-addopts =
- --verbose
- --cov=modopt
- --cov-report=term-missing
- --cov-report=xml
- --junitxml=pytest.xml
- --pydocstyle
-
-[pydocstyle]
-convention=numpy
-add-ignore=D107
diff --git a/setup.py b/setup.py
deleted file mode 100644
index e6a8a9e6..00000000
--- a/setup.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#! /usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from setuptools import setup, find_packages
-import os
-
-# Set the package release version
-major = 1
-minor = 7
-patch = 1
-
-# Set the package details
-name = 'modopt'
-version = '.'.join(str(value) for value in (major, minor, patch))
-author = 'Samuel Farrens'
-email = 'samuel.farrens@cea.fr'
-gh_user = 'cea-cosmic'
-url = 'https://github.com/{0}/{1}'.format(gh_user, name)
-description = 'Modular Optimisation tools for soliving inverse problems.'
-license = 'MIT'
-
-# Set the package classifiers
-python_versions_supported = ['3.7', '3.8', '3.9', '3.10', '3.11']
-os_platforms_supported = ['Unix', 'MacOS']
-
-lc_str = 'License :: OSI Approved :: {0} License'
-ln_str = 'Programming Language :: Python'
-py_str = 'Programming Language :: Python :: {0}'
-os_str = 'Operating System :: {0}'
-
-classifiers = (
- [lc_str.format(license)]
- + [ln_str]
- + [py_str.format(ver) for ver in python_versions_supported]
- + [os_str.format(ops) for ops in os_platforms_supported]
-)
-
-# Source package description from README.md
-this_directory = os.path.abspath(os.path.dirname(__file__))
-with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f:
- long_description = f.read()
-
-# Source package requirements from requirements.txt
-with open('requirements.txt') as open_file:
- install_requires = open_file.read()
-
-# Source test requirements from develop.txt
-with open('develop.txt') as open_file:
- tests_require = open_file.read()
-
-# Source doc requirements from docs/requirements.txt
-with open('docs/requirements.txt') as open_file:
- docs_require = open_file.read()
-
-
-setup(
- name=name,
- author=author,
- author_email=email,
- version=version,
- license=license,
- url=url,
- description=description,
- long_description=long_description,
- long_description_content_type='text/markdown',
- packages=find_packages(),
- install_requires=install_requires,
- python_requires='>=3.6',
- setup_requires=['pytest-runner'],
- tests_require=tests_require,
- extras_require={'develop': tests_require + docs_require},
- classifiers=classifiers,
-)
From 02a16d4588cada4edde0a4616b5e27d11e8489dd Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Wed, 14 Feb 2024 19:05:05 +0100
Subject: [PATCH 12/45] feat: add a style checking CI.
---
.github/workflows/style.yml | 38 +++++++++++++++++++++++++++++++++++++
1 file changed, 38 insertions(+)
create mode 100644 .github/workflows/style.yml
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
new file mode 100644
index 00000000..04ce6da6
--- /dev/null
+++ b/.github/workflows/style.yml
@@ -0,0 +1,38 @@
+name: Style checking
+
+on:
+ push:
+ branches: [ "master", "main", "develop" ]
+ pull_request:
+ branches: [ "master", "main", "develop" ]
+
+ workflow_dispatch:
+
+env:
+ PYTHON_VERSION: "3.10"
+
+jobs:
+ linter-check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ - name: Set up Python ${{ env.PYTHON_VERSION }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ env.PYTHON_VERSION }}
+ cache: pip
+
+ - name: Install Python deps
+ shell: bash
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e .[test,dev]
+
+ - name: Black Check
+ shell: bash
+ run: black . --diff --color --check
+
+ - name: ruff Check
+ shell: bash
+ run: ruff check
From d9aa9643815e4415686bac98d92edc9f38fc9ca6 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 21:18:17 +0100
Subject: [PATCH 13/45] fix: missing bracket.
---
modopt/tests/test_opt.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
index 7c30186e..1c2e7824 100644
--- a/modopt/tests/test_opt.py
+++ b/modopt/tests/test_opt.py
@@ -299,6 +299,7 @@ class ProxCases:
array233 = np.arange(18).reshape(2, 3, 3).astype(float)
array233_2 = np.array(
+ [
[
[2.73843189, 3.14594066, 3.55344943],
[3.9609582, 4.36846698, 4.77597575],
From 4d37bef1744a920fb17df8a8ad457fcb312537c1 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 22:37:45 +0100
Subject: [PATCH 14/45] run black.
---
docs/source/conf.py | 169 +++++++++---------
modopt/__init__.py | 8 +-
modopt/base/__init__.py | 2 +-
modopt/base/backend.py | 58 +++---
modopt/base/np_adjust.py | 6 +-
modopt/base/observable.py | 10 +-
modopt/base/transform.py | 43 ++---
modopt/base/types.py | 27 ++-
modopt/examples/conftest.py | 5 +-
.../example_lasso_forward_backward.py | 4 +-
modopt/interface/__init__.py | 2 +-
modopt/interface/errors.py | 28 ++-
modopt/interface/log.py | 16 +-
modopt/math/__init__.py | 2 +-
modopt/math/convolve.py | 32 ++--
modopt/math/matrix.py | 31 ++--
modopt/math/metrics.py | 14 +-
modopt/opt/__init__.py | 2 +-
modopt/opt/algorithms/__init__.py | 25 +--
modopt/opt/algorithms/admm.py | 11 +-
modopt/opt/algorithms/base.py | 55 +++---
modopt/opt/algorithms/forward_backward.py | 161 ++++++++---------
modopt/opt/algorithms/gradient_descent.py | 23 ++-
modopt/opt/algorithms/primal_dual.py | 44 +++--
modopt/opt/cost.py | 27 ++-
modopt/opt/gradient.py | 4 +-
modopt/opt/linear/base.py | 17 +-
modopt/opt/linear/wavelet.py | 29 +--
modopt/opt/proximity.py | 142 ++++++++-------
modopt/opt/reweight.py | 4 +-
modopt/plot/__init__.py | 2 +-
modopt/plot/cost_plot.py | 16 +-
modopt/signal/__init__.py | 2 +-
modopt/signal/filter.py | 2 +-
modopt/signal/noise.py | 17 +-
modopt/signal/positivity.py | 8 +-
modopt/signal/svd.py | 60 +++----
modopt/signal/validation.py | 4 +-
modopt/signal/wavelet.py | 53 +++---
modopt/tests/test_algorithms.py | 8 +-
modopt/tests/test_base.py | 1 +
modopt/tests/test_helpers/utils.py | 1 +
modopt/tests/test_math.py | 1 +
modopt/tests/test_opt.py | 8 +-
modopt/tests/test_signal.py | 7 +-
45 files changed, 589 insertions(+), 602 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 46564b9f..987576a9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -9,56 +9,56 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('../..'))
+sys.path.insert(0, os.path.abspath("../.."))
# -- General configuration ------------------------------------------------
# General information about the project.
-project = 'modopt'
+project = "modopt"
mdata = metadata(project)
-author = mdata['Author']
-version = mdata['Version']
-copyright = '2020, {}'.format(author)
-gh_user = 'sfarrens'
+author = mdata["Author"]
+version = mdata["Version"]
+copyright = "2020, {}".format(author)
+gh_user = "sfarrens"
# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = '3.3'
+needs_sphinx = "3.3"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.autosummary',
- 'sphinx.ext.coverage',
- 'sphinx.ext.doctest',
- 'sphinx.ext.ifconfig',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.napoleon',
- 'sphinx.ext.todo',
- 'sphinx.ext.viewcode',
- 'sphinxawesome_theme',
- 'sphinxcontrib.bibtex',
- 'myst_parser',
- 'nbsphinx',
- 'nbsphinx_link',
- 'numpydoc',
- "sphinx_gallery.gen_gallery"
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.coverage",
+ "sphinx.ext.doctest",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.todo",
+ "sphinx.ext.viewcode",
+ "sphinxawesome_theme",
+ "sphinxcontrib.bibtex",
+ "myst_parser",
+ "nbsphinx",
+ "nbsphinx_link",
+ "numpydoc",
+ "sphinx_gallery.gen_gallery",
]
# Include module names for objects
add_module_names = False
# Set class documentation standard.
-autoclass_content = 'class'
+autoclass_content = "class"
# Audodoc options
autodoc_default_options = {
- 'member-order': 'bysource',
- 'private-members': True,
- 'show-inheritance': True
+ "member-order": "bysource",
+ "private-members": True,
+ "show-inheritance": True,
}
# Generate summaries
@@ -69,17 +69,17 @@
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
-source_suffix = ['.rst', '.md']
+source_suffix = [".rst", ".md"]
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
show_authors = True
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'default'
+pygments_style = "default"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -88,7 +88,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = 'sphinxawesome_theme'
+html_theme = "sphinxawesome_theme"
# html_theme = 'sphinx_book_theme'
# Theme options are theme-specific and customize the look and feel of a theme
@@ -101,11 +101,10 @@
"breadcrumbs_separator": "/",
"show_prev_next": True,
"show_scrolltop": True,
-
}
html_collapsible_definitions = True
html_awesome_headerlinks = True
-html_logo = 'modopt_logo.jpg'
+html_logo = "modopt_logo.jpg"
html_permalinks_icon = (
'
'''
+ r""" """
+ r""""""
)
nbsphinx_prolog = nb_header_pt1 + nb_header_pt2
@@ -240,28 +237,28 @@ def add_notebooks(nb_path='../../notebooks'):
# Refer to the package libraries for type definitions
intersphinx_mapping = {
- 'python': ('http://docs.python.org/3', None),
- 'numpy': ('https://numpy.org/doc/stable/', None),
- 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None),
- 'progressbar': ('https://progressbar-2.readthedocs.io/en/latest/', None),
- 'matplotlib': ('https://matplotlib.org', None),
- 'astropy': ('http://docs.astropy.org/en/latest/', None),
- 'cupy': ('https://docs-cupy.chainer.org/en/stable/', None),
- 'torch': ('https://pytorch.org/docs/stable/', None),
- 'sklearn': (
- 'http://scikit-learn.org/stable',
- (None, './_intersphinx/sklearn-objects.inv')
+ "python": ("http://docs.python.org/3", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+ "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None),
+ "matplotlib": ("https://matplotlib.org", None),
+ "astropy": ("http://docs.astropy.org/en/latest/", None),
+ "cupy": ("https://docs-cupy.chainer.org/en/stable/", None),
+ "torch": ("https://pytorch.org/docs/stable/", None),
+ "sklearn": (
+ "http://scikit-learn.org/stable",
+ (None, "./_intersphinx/sklearn-objects.inv"),
),
- 'tensorflow': (
- 'https://www.tensorflow.org/api_docs/python',
+ "tensorflow": (
+ "https://www.tensorflow.org/api_docs/python",
(
- 'https://github.com/GPflow/tensorflow-intersphinx/'
- + 'raw/master/tf2_py_objects.inv')
- )
-
+ "https://github.com/GPflow/tensorflow-intersphinx/"
+ + "raw/master/tf2_py_objects.inv"
+ ),
+ ),
}
# -- BibTeX Setting ----------------------------------------------
-bibtex_bibfiles = ['refs.bib', 'my_ref.bib']
-bibtex_default_style = 'alpha'
+bibtex_bibfiles = ["refs.bib", "my_ref.bib"]
+bibtex_default_style = "alpha"
diff --git a/modopt/__init__.py b/modopt/__init__.py
index 2c06c1db..d446e15d 100644
--- a/modopt/__init__.py
+++ b/modopt/__init__.py
@@ -13,12 +13,12 @@
from modopt.base import *
try:
- _version = version('modopt')
+ _version = version("modopt")
except Exception: # pragma: no cover
- _version = 'Unkown'
+ _version = "Unkown"
warn(
- 'Could not extract package metadata. Make sure the package is '
- + 'correctly installed.',
+ "Could not extract package metadata. Make sure the package is "
+ + "correctly installed.",
)
__version__ = _version
diff --git a/modopt/base/__init__.py b/modopt/base/__init__.py
index 88424bae..d75ff315 100644
--- a/modopt/base/__init__.py
+++ b/modopt/base/__init__.py
@@ -9,4 +9,4 @@
"""
-__all__ = ['np_adjust', 'transform', 'types', 'observable']
+__all__ = ["np_adjust", "transform", "types", "observable"]
diff --git a/modopt/base/backend.py b/modopt/base/backend.py
index 1f4e9a72..fd933ebb 100644
--- a/modopt/base/backend.py
+++ b/modopt/base/backend.py
@@ -26,22 +26,24 @@
# Handle the compatibility with variable
LIBRARIES = {
- 'cupy': None,
- 'tensorflow': None,
- 'numpy': np,
+ "cupy": None,
+ "tensorflow": None,
+ "numpy": np,
}
-if util.find_spec('cupy') is not None:
+if util.find_spec("cupy") is not None:
try:
import cupy as cp
- LIBRARIES['cupy'] = cp
+
+ LIBRARIES["cupy"] = cp
except ImportError:
pass
-if util.find_spec('tensorflow') is not None:
+if util.find_spec("tensorflow") is not None:
try:
from tensorflow.experimental import numpy as tnp
- LIBRARIES['tensorflow'] = tnp
+
+ LIBRARIES["tensorflow"] = tnp
except ImportError:
pass
@@ -66,12 +68,12 @@ def get_backend(backend):
"""
if backend not in LIBRARIES.keys() or LIBRARIES[backend] is None:
msg = (
- '{0} backend not possible, please ensure that '
- + 'the optional libraries are installed.\n'
- + 'Reverting to numpy.'
+ "{0} backend not possible, please ensure that "
+ + "the optional libraries are installed.\n"
+ + "Reverting to numpy."
)
warn(msg.format(backend))
- backend = 'numpy'
+ backend = "numpy"
return LIBRARIES[backend], backend
@@ -92,16 +94,16 @@ def get_array_module(input_data):
The numpy or cupy module
"""
- if LIBRARIES['tensorflow'] is not None:
- if isinstance(input_data, LIBRARIES['tensorflow'].ndarray):
- return LIBRARIES['tensorflow']
- if LIBRARIES['cupy'] is not None:
- if isinstance(input_data, LIBRARIES['cupy'].ndarray):
- return LIBRARIES['cupy']
+ if LIBRARIES["tensorflow"] is not None:
+ if isinstance(input_data, LIBRARIES["tensorflow"].ndarray):
+ return LIBRARIES["tensorflow"]
+ if LIBRARIES["cupy"] is not None:
+ if isinstance(input_data, LIBRARIES["cupy"].ndarray):
+ return LIBRARIES["cupy"]
return np
-def change_backend(input_data, backend='cupy'):
+def change_backend(input_data, backend="cupy"):
"""Move data to device.
This method changes the backend of an array. This can be used to copy data
@@ -151,13 +153,13 @@ def move_to_cpu(input_data):
"""
xp = get_array_module(input_data)
- if xp == LIBRARIES['numpy']:
+ if xp == LIBRARIES["numpy"]:
return input_data
- elif xp == LIBRARIES['cupy']:
+ elif xp == LIBRARIES["cupy"]:
return input_data.get()
- elif xp == LIBRARIES['tensorflow']:
+ elif xp == LIBRARIES["tensorflow"]:
return input_data.data.numpy()
- raise ValueError('Cannot identify the array type.')
+ raise ValueError("Cannot identify the array type.")
def convert_to_tensor(input_data):
@@ -184,9 +186,9 @@ def convert_to_tensor(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Torch package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
xp = get_array_module(input_data)
@@ -220,9 +222,9 @@ def convert_to_cupy_array(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Torch package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
if input_data.is_cuda:
diff --git a/modopt/base/np_adjust.py b/modopt/base/np_adjust.py
index 6d290e43..31a785f5 100644
--- a/modopt/base/np_adjust.py
+++ b/modopt/base/np_adjust.py
@@ -154,8 +154,8 @@ def pad2d(input_data, padding):
padding = np.array(padding)
elif not isinstance(padding, np.ndarray):
raise ValueError(
- 'Padding must be an integer or a tuple (or list, np.ndarray) '
- + 'of itegers',
+ "Padding must be an integer or a tuple (or list, np.ndarray) "
+ + "of itegers",
)
if padding.size == 1:
@@ -164,7 +164,7 @@ def pad2d(input_data, padding):
pad_x = (padding[0], padding[0])
pad_y = (padding[1], padding[1])
- return np.pad(input_data, (pad_x, pad_y), 'constant')
+ return np.pad(input_data, (pad_x, pad_y), "constant")
def ftr(input_data):
diff --git a/modopt/base/observable.py b/modopt/base/observable.py
index 6471ba58..2f69a1a7 100644
--- a/modopt/base/observable.py
+++ b/modopt/base/observable.py
@@ -264,9 +264,7 @@ def is_converge(self):
mid_idx = -(self.wind // 2)
old_mean = np.array(self.list_cv_values[start_idx:mid_idx]).mean()
current_mean = np.array(self.list_cv_values[mid_idx:]).mean()
- normalize_residual_metrics = (
- np.abs(old_mean - current_mean) / np.abs(old_mean)
- )
+ normalize_residual_metrics = np.abs(old_mean - current_mean) / np.abs(old_mean)
self.converge_flag = normalize_residual_metrics < self.eps
def retrieve_metrics(self):
@@ -287,7 +285,7 @@ def retrieve_metrics(self):
time_val -= time_val[0]
return {
- 'time': time_val,
- 'index': self.list_iters,
- 'values': self.list_cv_values,
+ "time": time_val,
+ "index": self.list_iters,
+ "values": self.list_cv_values,
}
diff --git a/modopt/base/transform.py b/modopt/base/transform.py
index 07ce846f..fedd5efb 100644
--- a/modopt/base/transform.py
+++ b/modopt/base/transform.py
@@ -53,18 +53,17 @@ def cube2map(data_cube, layout):
"""
if data_cube.ndim != 3:
- raise ValueError('The input data must have 3 dimensions.')
+ raise ValueError("The input data must have 3 dimensions.")
if data_cube.shape[0] != np.prod(layout):
raise ValueError(
- 'The desired layout must match the number of input '
- + 'data layers.',
+ "The desired layout must match the number of input " + "data layers.",
)
- res = ([
+ res = [
np.hstack(data_cube[slice(layout[1] * elem, layout[1] * (elem + 1))])
for elem in range(layout[0])
- ])
+ ]
return np.vstack(res)
@@ -118,20 +117,24 @@ def map2cube(data_map, layout):
"""
if np.all(np.array(data_map.shape) % np.array(layout)):
raise ValueError(
- 'The desired layout must be a multiple of the number '
- + 'pixels in the data map.',
+ "The desired layout must be a multiple of the number "
+ + "pixels in the data map.",
)
d_shape = np.array(data_map.shape) // np.array(layout)
- return np.array([
- data_map[(
- slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
- slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
- )]
- for i_elem in range(layout[0])
- for j_elem in range(layout[1])
- ])
+ return np.array(
+ [
+ data_map[
+ (
+ slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
+ slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
+ )
+ ]
+ for i_elem in range(layout[0])
+ for j_elem in range(layout[1])
+ ]
+ )
def map2matrix(data_map, layout):
@@ -186,9 +189,9 @@ def map2matrix(data_map, layout):
image_shape * (i_elem % layout[1] + 1),
)
data_matrix.append(
- (
- data_map[lower[0]:upper[0], lower[1]:upper[1]]
- ).reshape(image_shape ** 2),
+ (data_map[lower[0] : upper[0], lower[1] : upper[1]]).reshape(
+ image_shape**2
+ ),
)
return np.array(data_matrix).T
@@ -232,7 +235,7 @@ def matrix2map(data_matrix, map_shape):
# Get the shape and layout of the images
image_shape = np.sqrt(data_matrix.shape[0]).astype(int)
- layout = np.array(map_shape // np.repeat(image_shape, 2), dtype='int')
+ layout = np.array(map_shape // np.repeat(image_shape, 2), dtype="int")
# Map objects from matrix
data_map = np.zeros(map_shape)
@@ -248,7 +251,7 @@ def matrix2map(data_matrix, map_shape):
image_shape * (i_elem // layout[1] + 1),
image_shape * (i_elem % layout[1] + 1),
)
- data_map[lower[0]:upper[0], lower[1]:upper[1]] = temp[:, :, i_elem]
+ data_map[lower[0] : upper[0], lower[1] : upper[1]] = temp[:, :, i_elem]
return data_map.astype(int)
diff --git a/modopt/base/types.py b/modopt/base/types.py
index 16e06f15..7ea805ad 100644
--- a/modopt/base/types.py
+++ b/modopt/base/types.py
@@ -30,7 +30,7 @@ def check_callable(input_obj):
For invalid input type
"""
if not callable(input_obj):
- raise TypeError('The input object must be a callable function.')
+ raise TypeError("The input object must be a callable function.")
return input_obj
@@ -71,14 +71,13 @@ def check_float(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, int):
input_obj = float(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=float)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.floating))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.floating)
):
input_obj = input_obj.astype(float)
@@ -121,14 +120,13 @@ def check_int(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, float):
input_obj = int(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=int)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.integer))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.integer)
):
input_obj = input_obj.astype(int)
@@ -160,19 +158,18 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
"""
if not isinstance(input_obj, np.ndarray):
- raise TypeError('Input is not a numpy array.')
+ raise TypeError("Input is not a numpy array.")
- if (
- (not isinstance(dtype, type(None)))
- and (not np.issubdtype(input_obj.dtype, dtype))
+ if (not isinstance(dtype, type(None))) and (
+ not np.issubdtype(input_obj.dtype, dtype)
):
raise (
TypeError(
- 'The numpy array elements are not of type: {0}'.format(dtype),
+ "The numpy array elements are not of type: {0}".format(dtype),
),
)
if not writeable and verbose and input_obj.flags.writeable:
- warn('Making input data immutable.')
+ warn("Making input data immutable.")
input_obj.flags.writeable = writeable
diff --git a/modopt/examples/conftest.py b/modopt/examples/conftest.py
index 73358679..f3ed371b 100644
--- a/modopt/examples/conftest.py
+++ b/modopt/examples/conftest.py
@@ -11,10 +11,12 @@
https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test
"""
+
from pathlib import Path
import runpy
import pytest
+
def pytest_collect_file(path, parent):
"""Pytest hook.
@@ -22,7 +24,7 @@ def pytest_collect_file(path, parent):
The new node needs to have the specified parent as parent.
"""
p = Path(path)
- if p.suffix == '.py' and 'example' in p.name:
+ if p.suffix == ".py" and "example" in p.name:
return Script.from_parent(parent, path=p, name=p.name)
@@ -33,6 +35,7 @@ def collect(self):
"""Collect the script as its own item."""
yield ScriptItem.from_parent(self, name=self.name)
+
class ScriptItem(pytest.Item):
"""Item script collected by pytest."""
diff --git a/modopt/examples/example_lasso_forward_backward.py b/modopt/examples/example_lasso_forward_backward.py
index 7f820000..c28b0499 100644
--- a/modopt/examples/example_lasso_forward_backward.py
+++ b/modopt/examples/example_lasso_forward_backward.py
@@ -76,7 +76,7 @@
prox=prox_op,
cost=cost_op_fista,
metric_call_period=1,
- auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
+ auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
)
fb_fista.iterate()
@@ -115,7 +115,7 @@
prox=prox_op,
cost=cost_op_pogm,
metric_call_period=1,
- auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
+ auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
)
fb_pogm.iterate()
diff --git a/modopt/interface/__init__.py b/modopt/interface/__init__.py
index f9439747..55904ca1 100644
--- a/modopt/interface/__init__.py
+++ b/modopt/interface/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['errors', 'log']
+__all__ = ["errors", "log"]
diff --git a/modopt/interface/errors.py b/modopt/interface/errors.py
index 0fbe7e71..eb4aa4ca 100644
--- a/modopt/interface/errors.py
+++ b/modopt/interface/errors.py
@@ -34,12 +34,12 @@ def warn(warn_string, log=None):
"""
if import_fail:
- warn_txt = 'WARNING'
+ warn_txt = "WARNING"
else:
- warn_txt = colored('WARNING', 'yellow')
+ warn_txt = colored("WARNING", "yellow")
# Print warning to stdout.
- sys.stderr.write('{0}: {1}\n'.format(warn_txt, warn_string))
+ sys.stderr.write("{0}: {1}\n".format(warn_txt, warn_string))
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
@@ -61,17 +61,17 @@ def catch_error(exception, log=None):
"""
if import_fail:
- err_txt = 'ERROR'
+ err_txt = "ERROR"
else:
- err_txt = colored('ERROR', 'red')
+ err_txt = colored("ERROR", "red")
# Print exception to stdout.
- stream_txt = '{0}: {1}\n'.format(err_txt, exception)
+ stream_txt = "{0}: {1}\n".format(err_txt, exception)
sys.stderr.write(stream_txt)
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- log_txt = 'ERROR: {0}\n'.format(exception)
+ log_txt = "ERROR: {0}\n".format(exception)
log.exception(log_txt)
@@ -91,11 +91,11 @@ def file_name_error(file_name):
If file name not specified or file not found
"""
- if file_name == '' or file_name[0][0] == '-':
- raise IOError('Input file name not specified.')
+ if file_name == "" or file_name[0][0] == "-":
+ raise IOError("Input file name not specified.")
elif not os.path.isfile(file_name):
- raise IOError('Input file name {0} not found!'.format(file_name))
+ raise IOError("Input file name {0} not found!".format(file_name))
def is_exe(fpath):
@@ -136,7 +136,7 @@ def is_executable(exe_name):
"""
if not isinstance(exe_name, str):
- raise TypeError('Executable name must be a string.')
+ raise TypeError("Executable name must be a string.")
fpath, fname = os.path.split(exe_name)
@@ -146,11 +146,9 @@ def is_executable(exe_name):
else:
res = any(
is_exe(os.path.join(path, exe_name))
- for path in os.environ['PATH'].split(os.pathsep)
+ for path in os.environ["PATH"].split(os.pathsep)
)
if not res:
- message = (
- '{0} does not appear to be a valid executable on this system.'
- )
+ message = "{0} does not appear to be a valid executable on this system."
raise IOError(message.format(exe_name))
diff --git a/modopt/interface/log.py b/modopt/interface/log.py
index 3b2fa77a..a02428d9 100644
--- a/modopt/interface/log.py
+++ b/modopt/interface/log.py
@@ -30,22 +30,22 @@ def set_up_log(filename, verbose=True):
"""
# Add file extension.
- filename = '{0}.log'.format(filename)
+ filename = "{0}.log".format(filename)
if verbose:
- print('Preparing log file:', filename)
+ print("Preparing log file:", filename)
# Capture warnings.
logging.captureWarnings(True)
# Set output format.
formatter = logging.Formatter(
- fmt='%(asctime)s %(message)s',
- datefmt='%d/%m/%Y %H:%M:%S',
+ fmt="%(asctime)s %(message)s",
+ datefmt="%d/%m/%Y %H:%M:%S",
)
# Create file handler.
- fh = logging.FileHandler(filename=filename, mode='w')
+ fh = logging.FileHandler(filename=filename, mode="w")
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
@@ -55,7 +55,7 @@ def set_up_log(filename, verbose=True):
log.addHandler(fh)
# Send opening message.
- log.info('The log file has been set-up.')
+ log.info("The log file has been set-up.")
return log
@@ -74,10 +74,10 @@ def close_log(log, verbose=True):
"""
if verbose:
- print('Closing log file:', log.name)
+ print("Closing log file:", log.name)
# Send closing message.
- log.info('The log file has been closed.')
+ log.info("The log file has been closed.")
# Remove all handlers from log.
for log_handler in log.handlers:
diff --git a/modopt/math/__init__.py b/modopt/math/__init__.py
index a22c0c98..8e92aa50 100644
--- a/modopt/math/__init__.py
+++ b/modopt/math/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['convolve', 'matrix', 'stats', 'metrics']
+__all__ = ["convolve", "matrix", "stats", "metrics"]
diff --git a/modopt/math/convolve.py b/modopt/math/convolve.py
index a4322ff2..528b2338 100644
--- a/modopt/math/convolve.py
+++ b/modopt/math/convolve.py
@@ -18,7 +18,7 @@
from astropy.convolution import convolve_fft
except ImportError: # pragma: no cover
import_astropy = False
- warn('astropy not found, will default to scipy for convolution')
+ warn("astropy not found, will default to scipy for convolution")
else:
import_astropy = True
try:
@@ -30,7 +30,7 @@
warn('Using pyFFTW "monkey patch" for scipy.fftpack')
-def convolve(input_data, kernel, method='scipy'):
+def convolve(input_data, kernel, method="scipy"):
"""Convolve data with kernel.
This method convolves the input data with a given kernel using FFT and
@@ -80,29 +80,29 @@ def convolve(input_data, kernel, method='scipy'):
"""
if input_data.ndim != kernel.ndim:
- raise ValueError('Data and kernel must have the same dimensions.')
+ raise ValueError("Data and kernel must have the same dimensions.")
- if method not in {'astropy', 'scipy'}:
+ if method not in {"astropy", "scipy"}:
raise ValueError('Invalid method. Options are "astropy" or "scipy".')
if not import_astropy: # pragma: no cover
- method = 'scipy'
+ method = "scipy"
- if method == 'astropy':
+ if method == "astropy":
return convolve_fft(
input_data,
kernel,
- boundary='wrap',
+ boundary="wrap",
crop=False,
- nan_treatment='fill',
+ nan_treatment="fill",
normalize_kernel=False,
)
- elif method == 'scipy':
- return scipy.signal.fftconvolve(input_data, kernel, mode='same')
+ elif method == "scipy":
+ return scipy.signal.fftconvolve(input_data, kernel, mode="same")
-def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
+def convolve_stack(input_data, kernel, rot_kernel=False, method="scipy"):
"""Convolve stack of data with stack of kernels.
This method convolves the input data with a given kernel using FFT and
@@ -156,7 +156,9 @@ def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
if rot_kernel:
kernel = rotate_stack(kernel)
- return np.array([
- convolve(data_i, kernel_i, method=method)
- for data_i, kernel_i in zip(input_data, kernel)
- ])
+ return np.array(
+ [
+ convolve(data_i, kernel_i, method=method)
+ for data_i, kernel_i in zip(input_data, kernel)
+ ]
+ )
diff --git a/modopt/math/matrix.py b/modopt/math/matrix.py
index 8361531d..6ddb3f2f 100644
--- a/modopt/math/matrix.py
+++ b/modopt/math/matrix.py
@@ -15,7 +15,7 @@
from modopt.base.backend import get_array_module, get_backend
-def gram_schmidt(matrix, return_opt='orthonormal'):
+def gram_schmidt(matrix, return_opt="orthonormal"):
r"""Gram-Schmit.
This method orthonormalizes the row vectors of the input matrix.
@@ -55,7 +55,7 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
"""
- if return_opt not in {'orthonormal', 'orthogonal', 'both'}:
+ if return_opt not in {"orthonormal", "orthogonal", "both"}:
raise ValueError(
'Invalid return_opt, options are: "orthonormal", "orthogonal" or '
+ '"both"',
@@ -77,11 +77,11 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
u_vec = np.array(u_vec)
e_vec = np.array(e_vec)
- if return_opt == 'orthonormal':
+ if return_opt == "orthonormal":
return e_vec
- elif return_opt == 'orthogonal':
+ elif return_opt == "orthogonal":
return u_vec
- elif return_opt == 'both':
+ elif return_opt == "both":
return u_vec, e_vec
@@ -201,7 +201,7 @@ def rot_matrix(angle):
return np.around(
np.array(
[[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]],
- dtype='float',
+ dtype="float",
),
10,
)
@@ -243,16 +243,15 @@ def rotate(matrix, angle):
shape = np.array(matrix.shape)
if shape[0] != shape[1]:
- raise ValueError('Input matrix must be square.')
+ raise ValueError("Input matrix must be square.")
shift = (shape - 1) // 2
index = (
- np.array(list(product(*np.array([np.arange(sval) for sval in shape]))))
- - shift
+ np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - shift
)
- new_index = np.array(np.dot(index, rot_matrix(angle)), dtype='int') + shift
+ new_index = np.array(np.dot(index, rot_matrix(angle)), dtype="int") + shift
new_index[new_index >= shape[0]] -= shape[0]
return matrix[tuple(zip(new_index.T))].reshape(shape.T)
@@ -301,7 +300,7 @@ def __init__(
data_shape,
data_type=float,
auto_run=True,
- compute_backend='numpy',
+ compute_backend="numpy",
verbose=False,
):
@@ -363,18 +362,14 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
x_new /= x_new_norm
- if (xp.abs(x_new_norm - x_old_norm) < tolerance):
- message = (
- ' - Power Method converged after {0} iterations!'
- )
+ if xp.abs(x_new_norm - x_old_norm) < tolerance:
+ message = " - Power Method converged after {0} iterations!"
if self._verbose:
print(message.format(i_elem + 1))
break
elif i_elem == max_iter - 1 and self._verbose:
- message = (
- ' - Power Method did not converge after {0} iterations!'
- )
+ message = " - Power Method did not converge after {0} iterations!"
print(message.format(max_iter))
xp.copyto(x_old, x_new)
diff --git a/modopt/math/metrics.py b/modopt/math/metrics.py
index 21952624..93f7ce06 100644
--- a/modopt/math/metrics.py
+++ b/modopt/math/metrics.py
@@ -71,15 +71,13 @@ def _preprocess_input(test, ref, mask=None):
The SNR
"""
- test = np.abs(np.copy(test)).astype('float64')
- ref = np.abs(np.copy(ref)).astype('float64')
+ test = np.abs(np.copy(test)).astype("float64")
+ ref = np.abs(np.copy(ref)).astype("float64")
test = min_max_normalize(test)
ref = min_max_normalize(ref)
if (not isinstance(mask, np.ndarray)) and (mask is not None):
- message = (
- 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
- )
+ message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
raise ValueError(message.format(mask))
if mask is None:
@@ -119,9 +117,9 @@ def ssim(test, ref, mask=None):
"""
if not import_skimage: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Image package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Image package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
test, ref, mask = _preprocess_input(test, ref, mask)
diff --git a/modopt/opt/__init__.py b/modopt/opt/__init__.py
index 2fd3d747..8b285bee 100644
--- a/modopt/opt/__init__.py
+++ b/modopt/opt/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['cost', 'gradient', 'linear', 'algorithms', 'proximity', 'reweight']
+__all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"]
diff --git a/modopt/opt/algorithms/__init__.py b/modopt/opt/algorithms/__init__.py
index d4e7082b..6a29325a 100644
--- a/modopt/opt/algorithms/__init__.py
+++ b/modopt/opt/algorithms/__init__.py
@@ -46,15 +46,20 @@
"""
from modopt.opt.algorithms.base import SetUp
-from modopt.opt.algorithms.forward_backward import (FISTA, POGM,
- ForwardBackward,
- GenForwardBackward)
-from modopt.opt.algorithms.gradient_descent import (AdaGenericGradOpt,
- ADAMGradOpt,
- GenericGradOpt,
- MomentumGradOpt,
- RMSpropGradOpt,
- SAGAOptGradOpt,
- VanillaGenericGradOpt)
+from modopt.opt.algorithms.forward_backward import (
+ FISTA,
+ POGM,
+ ForwardBackward,
+ GenForwardBackward,
+)
+from modopt.opt.algorithms.gradient_descent import (
+ AdaGenericGradOpt,
+ ADAMGradOpt,
+ GenericGradOpt,
+ MomentumGradOpt,
+ RMSpropGradOpt,
+ SAGAOptGradOpt,
+ VanillaGenericGradOpt,
+)
from modopt.opt.algorithms.primal_dual import Condat
from modopt.opt.algorithms.admm import ADMM, FastADMM
diff --git a/modopt/opt/algorithms/admm.py b/modopt/opt/algorithms/admm.py
index b881b770..4fd8074e 100644
--- a/modopt/opt/algorithms/admm.py
+++ b/modopt/opt/algorithms/admm.py
@@ -1,4 +1,5 @@
"""ADMM Algorithms."""
+
import numpy as np
from modopt.base.backend import get_array_module
@@ -188,7 +189,7 @@ def iterate(self, max_iter=150):
self.retrieve_outputs()
# rename outputs as attributes
self.u_final = self._u_new
- self.x_final = self.u_final # for backward compatibility
+ self.x_final = self.u_final # for backward compatibility
self.v_final = self._v_new
def get_notify_observers_kwargs(self):
@@ -203,9 +204,9 @@ def get_notify_observers_kwargs(self):
The mapping between the iterated variables
"""
return {
- 'x_new': self._u_new,
- 'v_new': self._v_new,
- 'idx': self.idx,
+ "x_new": self._u_new,
+ "v_new": self._v_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -215,7 +216,7 @@ def retrieve_outputs(self):
y_final, metrics.
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/algorithms/base.py b/modopt/opt/algorithms/base.py
index c5a4b101..e2b9017d 100644
--- a/modopt/opt/algorithms/base.py
+++ b/modopt/opt/algorithms/base.py
@@ -69,7 +69,7 @@ def __init__(
verbose=False,
progress=True,
step_size=None,
- compute_backend='numpy',
+ compute_backend="numpy",
**dummy_kwargs,
):
self.idx = 0
@@ -79,26 +79,26 @@ def __init__(
self.metrics = metrics
self.step_size = step_size
self._op_parents = (
- 'GradParent',
- 'ProximityParent',
- 'LinearParent',
- 'costObj',
+ "GradParent",
+ "ProximityParent",
+ "LinearParent",
+ "costObj",
)
self.metric_call_period = metric_call_period
# Declaration of observers for metrics
- super().__init__(['cv_metrics'])
+ super().__init__(["cv_metrics"])
for name, dic in self.metrics.items():
observer = MetricObserver(
name,
- dic['metric'],
- dic['mapping'],
- dic['cst_kwargs'],
- dic['early_stopping'],
+ dic["metric"],
+ dic["mapping"],
+ dic["cst_kwargs"],
+ dic["early_stopping"],
)
- self.add_observer('cv_metrics', observer)
+ self.add_observer("cv_metrics", observer)
xp, compute_backend = backend.get_backend(compute_backend)
self.xp = xp
@@ -118,7 +118,7 @@ def metrics(self, metrics):
self._metrics = metrics
else:
raise TypeError(
- 'Metrics must be a dictionary, not {0}.'.format(type(metrics)),
+ "Metrics must be a dictionary, not {0}.".format(type(metrics)),
)
def any_convergence_flag(self):
@@ -132,9 +132,7 @@ def any_convergence_flag(self):
True if any convergence criteria met
"""
- return any(
- obs.converge_flag for obs in self._observers['cv_metrics']
- )
+ return any(obs.converge_flag for obs in self._observers["cv_metrics"])
def copy_data(self, input_data):
"""Copy Data.
@@ -152,10 +150,12 @@ def copy_data(self, input_data):
Copy of input data
"""
- return self.xp.copy(backend.change_backend(
- input_data,
- self.compute_backend,
- ))
+ return self.xp.copy(
+ backend.change_backend(
+ input_data,
+ self.compute_backend,
+ )
+ )
def _check_input_data(self, input_data):
"""Check input data type.
@@ -175,7 +175,7 @@ def _check_input_data(self, input_data):
"""
if not (isinstance(input_data, (self.xp.ndarray, np.ndarray))):
raise TypeError(
- 'Input data must be a numpy array or backend array',
+ "Input data must be a numpy array or backend array",
)
def _check_param(self, param_val):
@@ -195,7 +195,7 @@ def _check_param(self, param_val):
"""
if not isinstance(param_val, float):
- raise TypeError('Algorithm parameter must be a float value.')
+ raise TypeError("Algorithm parameter must be a float value.")
def _check_param_update(self, param_update):
"""Check algorithm parameter update methods.
@@ -213,14 +213,13 @@ def _check_param_update(self, param_update):
For invalid input type
"""
- param_conditions = (
- not isinstance(param_update, type(None))
- and not callable(param_update)
+ param_conditions = not isinstance(param_update, type(None)) and not callable(
+ param_update
)
if param_conditions:
raise TypeError(
- 'Algorithm parameter update must be a callabale function.',
+                "Algorithm parameter update must be a callable function.",
)
def _check_operator(self, operator):
@@ -239,7 +238,7 @@ def _check_operator(self, operator):
tree = [op_obj.__name__ for op_obj in getmro(operator.__class__)]
if not any(parent in tree for parent in self._op_parents):
- message = '{0} does not inherit an operator parent.'
+ message = "{0} does not inherit an operator parent."
warn(message.format(str(operator.__class__)))
def _compute_metrics(self):
@@ -250,7 +249,7 @@ def _compute_metrics(self):
"""
kwargs = self.get_notify_observers_kwargs()
- self.notify_observers('cv_metrics', **kwargs)
+ self.notify_observers("cv_metrics", **kwargs)
def _iterations(self, max_iter, progbar=None):
"""Iterate method.
@@ -285,7 +284,7 @@ def _iterations(self, max_iter, progbar=None):
if self.converge:
if self.verbose:
- print(' - Converged!')
+ print(" - Converged!")
break
if progbar:
diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py
index 702799c6..d34125fa 100644
--- a/modopt/opt/algorithms/forward_backward.py
+++ b/modopt/opt/algorithms/forward_backward.py
@@ -52,12 +52,12 @@ class FISTA(object):
"""
_restarting_strategies = (
- 'adaptive', # option 1 in alg 4
- 'adaptive-i',
- 'adaptive-1',
- 'adaptive-ii', # option 2 in alg 4
- 'adaptive-2',
- 'greedy', # alg 5
+ "adaptive", # option 1 in alg 4
+ "adaptive-i",
+ "adaptive-1",
+ "adaptive-ii", # option 2 in alg 4
+ "adaptive-2",
+ "greedy", # alg 5
None, # no restarting
)
@@ -75,24 +75,27 @@ def __init__(
):
if isinstance(a_cd, type(None)):
- self.mode = 'regular'
+ self.mode = "regular"
self.p_lazy = p_lazy
self.q_lazy = q_lazy
self.r_lazy = r_lazy
elif a_cd > 2:
- self.mode = 'CD'
+ self.mode = "CD"
self.a_cd = a_cd
self._n = 0
else:
raise ValueError(
- 'a_cd must either be None (for regular mode) or a number > 2',
+ "a_cd must either be None (for regular mode) or a number > 2",
)
if restart_strategy in self._restarting_strategies:
self._check_restart_params(
- restart_strategy, min_beta, s_greedy, xi_restart,
+ restart_strategy,
+ min_beta,
+ s_greedy,
+ xi_restart,
)
self.restart_strategy = restart_strategy
self.min_beta = min_beta
@@ -100,10 +103,10 @@ def __init__(
self.xi_restart = xi_restart
else:
- message = 'Restarting strategy must be one of {0}.'
+ message = "Restarting strategy must be one of {0}."
raise ValueError(
message.format(
- ', '.join(self._restarting_strategies),
+ ", ".join(self._restarting_strategies),
),
)
self._t_now = 1.0
@@ -155,22 +158,20 @@ def _check_restart_params(
if restart_strategy is None:
return True
- if self.mode != 'regular':
+ if self.mode != "regular":
raise ValueError(
- 'Restarting strategies can only be used with regular mode.',
+ "Restarting strategies can only be used with regular mode.",
)
- greedy_params_check = (
- min_beta is None or s_greedy is None or s_greedy <= 1
- )
+ greedy_params_check = min_beta is None or s_greedy is None or s_greedy <= 1
- if restart_strategy == 'greedy' and greedy_params_check:
+ if restart_strategy == "greedy" and greedy_params_check:
raise ValueError(
- 'You need a min_beta and an s_greedy > 1 for greedy restart.',
+ "You need a min_beta and an s_greedy > 1 for greedy restart.",
)
if xi_restart is None or xi_restart >= 1:
- raise ValueError('You need a xi_restart < 1 for restart.')
+ raise ValueError("You need a xi_restart < 1 for restart.")
return True
@@ -210,12 +211,12 @@ def is_restart(self, z_old, x_new, x_old):
criterion = xp.vdot(z_old - x_new, x_new - x_old) >= 0
if criterion:
- if 'adaptive' in self.restart_strategy:
+ if "adaptive" in self.restart_strategy:
self.r_lazy *= self.xi_restart
- if self.restart_strategy in {'adaptive-ii', 'adaptive-2'}:
+ if self.restart_strategy in {"adaptive-ii", "adaptive-2"}:
self._t_now = 1
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
cur_delta = xp.linalg.norm(x_new - x_old)
if self._delta0 is None:
self._delta0 = self.s_greedy * cur_delta
@@ -269,17 +270,17 @@ def update_lambda(self, *args, **kwargs):
Implements steps 3 and 4 from algoritm 10.7 in :cite:`bauschke2009`.
"""
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
return 2
# Steps 3 and 4 from alg.10.7.
self._t_prev = self._t_now
- if self.mode == 'regular':
- sqrt_part = self.r_lazy * self._t_prev ** 2 + self.q_lazy
+ if self.mode == "regular":
+ sqrt_part = self.r_lazy * self._t_prev**2 + self.q_lazy
self._t_now = self.p_lazy + np.sqrt(sqrt_part) * 0.5
- elif self.mode == 'CD':
+ elif self.mode == "CD":
self._t_now = (self._n + self.a_cd - 1) / self.a_cd
self._n += 1
@@ -344,11 +345,11 @@ def __init__(
x,
grad,
prox,
- cost='auto',
+ cost="auto",
beta_param=1.0,
lambda_param=1.0,
beta_update=None,
- lambda_update='fista',
+ lambda_update="fista",
auto_iterate=True,
metric_call_period=5,
metrics=None,
@@ -376,7 +377,7 @@ def __init__(
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -384,7 +385,7 @@ def __init__(
# Check if there is a linear op, needed for metrics in the FB algoritm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -400,7 +401,7 @@ def __init__(
# Set the algorithm parameter update methods
self._check_param_update(beta_update)
self._beta_update = beta_update
- if isinstance(lambda_update, str) and lambda_update == 'fista':
+ if isinstance(lambda_update, str) and lambda_update == "fista":
fista = FISTA(**kwargs)
self._lambda_update = fista.update_lambda
self._is_restart = fista.is_restart
@@ -462,9 +463,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def iterate(self, max_iter=150, progbar=None):
@@ -500,9 +500,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z_new,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -513,7 +513,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -577,7 +577,7 @@ def __init__(
x,
grad,
prox_list,
- cost='auto',
+ cost="auto",
gamma_param=1.0,
lambda_param=1.0,
gamma_update=None,
@@ -609,7 +609,7 @@ def __init__(
self._prox_list = self.xp.array(prox_list)
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad] + prox_list)
else:
self._cost_func = cost
@@ -617,7 +617,7 @@ def __init__(
# Check if there is a linear op, needed for metrics in the FB algoritm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -641,9 +641,7 @@ def __init__(
self._set_weights(weights)
# Set initial z
- self._z = self.xp.array([
- self._x_old for i in range(self._prox_list.size)
- ])
+ self._z = self.xp.array([self._x_old for i in range(self._prox_list.size)])
# Automatically run the algorithm
if auto_iterate:
@@ -673,25 +671,25 @@ def _set_weights(self, weights):
self._prox_list.size,
)
elif not isinstance(weights, (list, tuple, np.ndarray)):
- raise TypeError('Weights must be provided as a list.')
+ raise TypeError("Weights must be provided as a list.")
weights = self.xp.array(weights)
if not np.issubdtype(weights.dtype, np.floating):
- raise ValueError('Weights must be list of float values.')
+ raise ValueError("Weights must be list of float values.")
if weights.size != self._prox_list.size:
raise ValueError(
- 'The number of weights must match the number of proximity '
- + 'operators.',
+ "The number of weights must match the number of proximity "
+ + "operators.",
)
expected_weight_sum = 1.0
if self.xp.sum(weights) != expected_weight_sum:
raise ValueError(
- 'Proximity operator weights must sum to 1.0. Current sum of '
- + 'weights = {0}'.format(self.xp.sum(weights)),
+ "Proximity operator weights must sum to 1.0. Current sum of "
+ + "weights = {0}".format(self.xp.sum(weights)),
)
self._weights = weights
@@ -726,9 +724,7 @@ def _update(self):
# Update z values.
for i in range(self._prox_list.size):
- z_temp = (
- 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
- )
+ z_temp = 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
z_prox = self._prox_list[i].op(
z_temp,
extra_factor=self._gamma / self._weights[i],
@@ -784,9 +780,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -797,7 +793,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -871,7 +867,7 @@ def __init__(
z,
grad,
prox,
- cost='auto',
+ cost="auto",
linear=None,
beta_param=1.0,
sigma_bar=1.0,
@@ -905,7 +901,7 @@ def __init__(
self._grad = grad
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -918,7 +914,7 @@ def __init__(
for param_val in (beta_param, sigma_bar):
self._check_param(param_val)
if sigma_bar < 0 or sigma_bar > 1:
- raise ValueError('The sigma bar parameter needs to be in [0, 1]')
+ raise ValueError("The sigma bar parameter needs to be in [0, 1]")
self._beta = self.step_size or beta_param
self._sigma_bar = sigma_bar
@@ -944,18 +940,18 @@ def _update(self):
"""
# Step 4 from alg. 3
self._grad.get_grad(self._x_old)
- #self._u_new = self._x_old - self._beta * self._grad.grad
- self._u_new = -self._beta * self._grad.grad
+ # self._u_new = self._x_old - self._beta * self._grad.grad
+ self._u_new = -self._beta * self._grad.grad
self._u_new += self._x_old
# Step 5 from alg. 3
- self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2))
+ self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2))
# Step 6 from alg. 3
t_shifted_ratio = (self._t_old - 1) / self._t_new
sigma_t_ratio = self._sigma * self._t_old / self._t_new
beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi
- self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z)
+ self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z)
self._z += self._u_new
self._z += t_shifted_ratio * (self._u_new - self._u_old)
self._z += sigma_t_ratio * (self._u_new - self._x_old)
@@ -968,20 +964,18 @@ def _update(self):
# Restarting and gamma-Decreasing
# Step 9 from alg. 3
- #self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
- self._g_new = (self._z - self._x_new)
+ # self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
+ self._g_new = self._z - self._x_new
self._g_new /= self._xi
self._g_new += self._grad.grad
# Step 10 from alg 3.
- #self._y_new = self._x_old - self._beta * self._g_new
- self._y_new = - self._beta * self._g_new
+ # self._y_new = self._x_old - self._beta * self._g_new
+ self._y_new = -self._beta * self._g_new
self._y_new += self._x_old
# Step 11 from alg. 3
- restart_crit = (
- self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
- )
+ restart_crit = self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
if restart_crit:
self._t_new = 1
self._sigma = 1
@@ -999,9 +993,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def iterate(self, max_iter=150, progbar=None):
@@ -1037,14 +1030,14 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'u_new': self._u_new,
- 'x_new': self._linear.adj_op(self._x_new),
- 'y_new': self._y_new,
- 'z_new': self._z,
- 'xi': self._xi,
- 'sigma': self._sigma,
- 't': self._t_new,
- 'idx': self.idx,
+ "u_new": self._u_new,
+ "x_new": self._linear.adj_op(self._x_new),
+ "y_new": self._y_new,
+ "z_new": self._z,
+ "xi": self._xi,
+ "sigma": self._sigma,
+ "t": self._t_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -1055,6 +1048,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/algorithms/gradient_descent.py b/modopt/opt/algorithms/gradient_descent.py
index f3fe4b10..d3af1686 100644
--- a/modopt/opt/algorithms/gradient_descent.py
+++ b/modopt/opt/algorithms/gradient_descent.py
@@ -103,7 +103,7 @@ def __init__(
self._check_operator(operator)
self._grad = grad
self._prox = prox
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -157,9 +157,8 @@ def _update(self):
self._eta = self._eta_update(self._eta, self.idx)
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def _update_grad_dir(self, grad):
@@ -208,10 +207,10 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._x_new,
- 'dir_grad': self._dir_grad,
- 'speed_grad': self._speed_grad,
- 'idx': self.idx,
+ "x_new": self._x_new,
+ "dir_grad": self._dir_grad,
+ "speed_grad": self._speed_grad,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -222,7 +221,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -308,7 +307,7 @@ class RMSpropGradOpt(GenericGradOpt):
def __init__(self, *args, gamma=0.5, **kwargs):
super().__init__(*args, **kwargs)
if gamma < 0 or gamma > 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
self._check_param(gamma)
self._gamma = gamma
@@ -405,9 +404,9 @@ def __init__(self, *args, gamma=0.9, beta=0.9, **kwargs):
self._check_param(gamma)
self._check_param(beta)
if gamma < 0 or gamma >= 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
if beta < 0 or beta >= 1:
- raise ValueError('beta is outside of range [0,1]')
+ raise ValueError("beta is outside of range [0,1]")
self._gamma = gamma
self._beta = beta
self._beta_pow = 1
diff --git a/modopt/opt/algorithms/primal_dual.py b/modopt/opt/algorithms/primal_dual.py
index d5bdd431..179ddf95 100644
--- a/modopt/opt/algorithms/primal_dual.py
+++ b/modopt/opt/algorithms/primal_dual.py
@@ -81,7 +81,7 @@ def __init__(
prox,
prox_dual,
linear=None,
- cost='auto',
+ cost="auto",
reweight=None,
rho=0.5,
sigma=1.0,
@@ -123,12 +123,14 @@ def __init__(
self._linear = Identity()
else:
self._linear = linear
- if cost == 'auto':
- self._cost_func = costObj([
- self._grad,
- self._prox,
- self._prox_dual,
- ])
+ if cost == "auto":
+ self._cost_func = costObj(
+ [
+ self._grad,
+ self._prox,
+ self._prox_dual,
+ ]
+ )
else:
self._cost_func = cost
@@ -187,22 +189,17 @@ def _update(self):
self._grad.get_grad(self._x_old)
x_prox = self._prox.op(
- self._x_old - self._tau * self._grad.grad - self._tau
- * self._linear.adj_op(self._y_old),
+ self._x_old
+ - self._tau * self._grad.grad
+ - self._tau * self._linear.adj_op(self._y_old),
)
# Step 2 from eq.9.
- y_temp = (
- self._y_old + self._sigma
- * self._linear.op(2 * x_prox - self._x_old)
- )
+ y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old)
- y_prox = (
- y_temp - self._sigma
- * self._prox_dual.op(
- y_temp / self._sigma,
- extra_factor=(1.0 / self._sigma),
- )
+ y_prox = y_temp - self._sigma * self._prox_dual.op(
+ y_temp / self._sigma,
+ extra_factor=(1.0 / self._sigma),
)
# Step 3 from eq.9.
@@ -220,9 +217,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new, self._y_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new, self._y_new
)
def iterate(self, max_iter=150, n_rewightings=1, progbar=None):
@@ -267,7 +263,7 @@ def get_notify_observers_kwargs(self):
The mapping between the iterated variables
"""
- return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx}
+ return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx}
def retrieve_outputs(self):
"""Retrieve outputs.
@@ -277,6 +273,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/cost.py b/modopt/opt/cost.py
index 688a3959..4bead130 100644
--- a/modopt/opt/cost.py
+++ b/modopt/opt/cost.py
@@ -115,17 +115,17 @@ def _check_cost(self):
# The mean of the first half of the test list
t1 = xp.mean(
- xp.array(self._test_list[len(self._test_list) // 2:]),
+ xp.array(self._test_list[len(self._test_list) // 2 :]),
axis=0,
)
# The mean of the second half of the test list
t2 = xp.mean(
- xp.array(self._test_list[:len(self._test_list) // 2]),
+ xp.array(self._test_list[: len(self._test_list) // 2]),
axis=0,
)
# Calculate the change across the test list
if xp.around(t1, decimals=16):
- cost_diff = (xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1))
+ cost_diff = xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1)
else:
cost_diff = 0
@@ -133,9 +133,9 @@ def _check_cost(self):
self._test_list = []
if self._verbose:
- print(' - CONVERGENCE TEST - ')
- print(' - CHANGE IN COST:', cost_diff)
- print('')
+ print(" - CONVERGENCE TEST - ")
+ print(" - CHANGE IN COST:", cost_diff)
+ print("")
# Check for convergence
return cost_diff <= self._tolerance
@@ -176,8 +176,7 @@ def get_cost(self, *args, **kwargs):
"""
# Check if the cost should be calculated
test_conditions = (
- self._cost_interval is None
- or self._iteration % self._cost_interval
+ self._cost_interval is None or self._iteration % self._cost_interval
)
if test_conditions:
@@ -185,15 +184,15 @@ def get_cost(self, *args, **kwargs):
else:
if self._verbose:
- print(' - ITERATION:', self._iteration)
+ print(" - ITERATION:", self._iteration)
# Calculate the current cost
self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs)
self._cost_list.append(self.cost)
if self._verbose:
- print(' - COST:', self.cost)
- print('')
+ print(" - COST:", self.cost)
+ print("")
# Test for convergence
test_result = self._check_cost()
@@ -288,13 +287,11 @@ def _check_operators(self):
"""
if not isinstance(self._operators, (list, tuple, np.ndarray)):
- message = (
- 'Input operators must be provided as a list, not {0}'
- )
+ message = "Input operators must be provided as a list, not {0}"
raise TypeError(message.format(type(self._operators)))
for op in self._operators:
- if not hasattr(op, 'cost'):
+ if not hasattr(op, "cost"):
raise ValueError('Operators must contain "cost" method.')
op.cost = check_callable(op.cost)
diff --git a/modopt/opt/gradient.py b/modopt/opt/gradient.py
index caa8fa9d..8d7bacc7 100644
--- a/modopt/opt/gradient.py
+++ b/modopt/opt/gradient.py
@@ -289,7 +289,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - DATA FIDELITY (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - DATA FIDELITY (X):", cost_val)
return cost_val
diff --git a/modopt/opt/linear/base.py b/modopt/opt/linear/base.py
index e347970d..79f05d85 100644
--- a/modopt/opt/linear/base.py
+++ b/modopt/opt/linear/base.py
@@ -5,6 +5,7 @@
from modopt.base.types import check_callable
from modopt.base.backend import get_array_module
+
class LinearParent(object):
"""Linear Operator Parent Class.
@@ -69,7 +70,7 @@ def __init__(self):
self.op = lambda input_data: input_data
self.adj_op = self.op
- self.cost= lambda *args, **kwargs: 0
+ self.cost = lambda *args, **kwargs: 0
class MatrixOperator(LinearParent):
@@ -159,14 +160,13 @@ def _check_type(self, input_val):
"""
if not isinstance(input_val, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, input must be a list, tuple or numpy '
- + 'array.',
+ "Invalid input type, input must be a list, tuple or numpy " + "array.",
)
input_val = np.array(input_val)
if not input_val.size:
- raise ValueError('Input list is empty.')
+ raise ValueError("Input list is empty.")
return input_val
@@ -200,10 +200,10 @@ def _check_inputs(self, operators, weights):
for operator in operators:
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'adj_op'):
+ if not hasattr(operator, "adj_op"):
raise ValueError('Operators must contain "adj_op" method.')
operator.op = check_callable(operator.op)
@@ -214,12 +214,11 @@ def _check_inputs(self, operators, weights):
if weights.size != operators.size:
raise ValueError(
- 'The number of weights must match the number of '
- + 'operators.',
+ "The number of weights must match the number of " + "operators.",
)
if not np.issubdtype(weights.dtype, np.floating):
- raise TypeError('The weights must be a list of float values.')
+ raise TypeError("The weights must be a list of float values.")
return operators, weights
diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py
index 5feead66..e1608ff4 100644
--- a/modopt/opt/linear/wavelet.py
+++ b/modopt/opt/linear/wavelet.py
@@ -45,7 +45,7 @@ class WaveletConvolve(LinearParent):
"""
- def __init__(self, filters, method='scipy'):
+ def __init__(self, filters, method="scipy"):
self._filters = check_float(filters)
self.op = lambda input_data: filter_convolve_stack(
@@ -61,8 +61,6 @@ def __init__(self, filters, method='scipy'):
)
-
-
class WaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
@@ -85,22 +83,28 @@ class WaveletTransform(LinearParent):
**kwargs: extra kwargs for Pywavelet or Pytorch Wavelet
"""
- def __init__(self,
+
+ def __init__(
+ self,
wavelet_name,
shape,
level=4,
mode="symmetric",
compute_backend="numpy",
- **kwargs):
+ **kwargs,
+ ):
if compute_backend == "cupy" and ptwt_available:
- self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode)
+ self.operator = CupyWaveletTransform(
+ wavelet=wavelet_name, shape=shape, level=level, mode=mode
+ )
elif compute_backend == "numpy" and pywt_available:
- self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, level=level, **kwargs)
+ self.operator = CPUWaveletTransform(
+ wavelet_name=wavelet_name, shape=shape, level=level, **kwargs
+ )
else:
raise ValueError(f"Compute Backend {compute_backend} not available")
-
self.op = self.operator.op
self.adj_op = self.operator.adj_op
@@ -108,6 +112,7 @@ def __init__(self,
def coeffs_shape(self):
return self.operator.coeffs_shape
+
class CPUWaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
@@ -286,7 +291,7 @@ def __init__(
self.level = level
self.shape = shape
self.mode = mode
- self.coeffs_shape = None # will be set after op.
+ self.coeffs_shape = None # will be set after op.
def op(self, data: torch.Tensor) -> list[torch.Tensor]:
"""Apply the wavelet decomposition on.
@@ -419,8 +424,10 @@ def __init__(
self.shape = shape
self.mode = mode
- self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode)
- self.coeffs_shape = None # will be set after op
+ self.operator = TorchWaveletTransform(
+ shape=shape, wavelet=wavelet, level=level, mode=mode
+ )
+ self.coeffs_shape = None # will be set after op
def op(self, data: cp.array) -> cp.ndarray:
"""Define the wavelet operator.
diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py
index fc81a753..b562f77a 100644
--- a/modopt/opt/proximity.py
+++ b/modopt/opt/proximity.py
@@ -140,8 +140,8 @@ def _cost_method(self, *args, **kwargs):
``0.0``
"""
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - Min (X):', np.min(args[0]))
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - Min (X):", np.min(args[0]))
return 0
@@ -167,7 +167,7 @@ class SparseThreshold(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
@@ -221,8 +221,8 @@ def _cost_method(self, *args, **kwargs):
if isinstance(cost_val, xp.ndarray):
cost_val = cost_val.item()
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L1 NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - L1 NORM (X):", cost_val)
return cost_val
@@ -273,8 +273,8 @@ class LowRankMatrix(ProximityParent):
def __init__(
self,
threshold,
- thresh_type='soft',
- lowr_type='standard',
+ thresh_type="soft",
+ lowr_type="standard",
initial_rank=None,
operator=None,
):
@@ -315,13 +315,13 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
"""
# Update threshold with extra factor.
threshold = self.thresh * extra_factor
- if self.lowr_type == 'standard' and self.rank is None and rank is None:
+ if self.lowr_type == "standard" and self.rank is None and rank is None:
data_matrix = svd_thresh(
cube2matrix(input_data),
threshold,
thresh_type=self.thresh_type,
)
- elif self.lowr_type == 'standard':
+ elif self.lowr_type == "standard":
data_matrix, update_rank = svd_thresh_coef_fast(
cube2matrix(input_data),
threshold,
@@ -331,7 +331,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
)
self.rank = update_rank # save for future use
- elif self.lowr_type == 'ngole':
+ elif self.lowr_type == "ngole":
data_matrix = svd_thresh_coef(
cube2matrix(input_data),
self.operator,
@@ -339,7 +339,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
thresh_type=self.thresh_type,
)
else:
- raise ValueError('lowr_type should be standard or ngole')
+ raise ValueError("lowr_type should be standard or ngole")
# Return updated data.
return matrix2cube(data_matrix, input_data.shape[1:])
@@ -365,8 +365,8 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = self.thresh * nuclear_norm(cube2matrix(args[0]))
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - NUCLEAR NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - NUCLEAR NORM (X):", cost_val)
return cost_val
@@ -506,19 +506,19 @@ def _check_operators(self, operators):
"""
if not isinstance(operators, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, operators must be a list, tuple or '
- + 'numpy array.',
+ "Invalid input type, operators must be a list, tuple or "
+ + "numpy array.",
)
operators = np.array(operators)
if not operators.size:
- raise ValueError('Operator list is empty.')
+ raise ValueError("Operator list is empty.")
for operator in operators:
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'cost'):
+ if not hasattr(operator, "cost"):
raise ValueError('Operators must contain "cost" method.')
operator.op = check_callable(operator.op)
operator.cost = check_callable(operator.cost)
@@ -573,10 +573,12 @@ def _cost_method(self, *args, **kwargs):
Combinded cost components
"""
- return np.sum([
- operator.cost(input_data)
- for operator, input_data in zip(self.operators, args[0])
- ])
+ return np.sum(
+ [
+ operator.cost(input_data)
+ for operator, input_data in zip(self.operators, args[0])
+ ]
+ )
class OrderedWeightedL1Norm(ProximityParent):
@@ -617,16 +619,16 @@ class OrderedWeightedL1Norm(ProximityParent):
def __init__(self, weights):
if not import_sklearn: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Learn package not found see '
- + 'documentation for details: '
- + 'https://cea-cosmic.github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Learn package not found see "
+ + "documentation for details: "
+ + "https://cea-cosmic.github.io/ModOpt/#optional-packages",
)
if np.max(np.diff(weights)) > 0:
- raise ValueError('Weights must be non increasing')
+ raise ValueError("Weights must be non increasing")
self.weights = weights.flatten()
if (self.weights < 0).any():
raise ValueError(
- 'The weight values must be provided in descending order',
+ "The weight values must be provided in descending order",
)
self.op = self._op_method
self.cost = self._cost_method
@@ -664,7 +666,9 @@ def _op_method(self, input_data, extra_factor=1.0):
# Projection onto the monotone non-negative cone using
# isotonic_regression
data_abs = isotonic_regression(
- data_abs - threshold, y_min=0, increasing=False,
+ data_abs - threshold,
+ y_min=0,
+ increasing=False,
)
# Unsorting the data
@@ -672,7 +676,7 @@ def _op_method(self, input_data, extra_factor=1.0):
data_abs_unsorted[data_abs_sort_idx] = data_abs
# Putting the sign back
- with np.errstate(invalid='ignore'):
+ with np.errstate(invalid="ignore"):
sign_data = data_squeezed / np.abs(data_squeezed)
# Removing NAN caused by the sign
@@ -702,8 +706,8 @@ def _cost_method(self, *args, **kwargs):
self.weights * np.sort(np.squeeze(np.abs(args[0]))[::-1]),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - OWL NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - OWL NORM (X):", cost_val)
return cost_val
@@ -734,7 +738,7 @@ class Ridge(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
@@ -786,8 +790,8 @@ def _cost_method(self, *args, **kwargs):
np.abs(self.weights * self._linear.op(args[0]) ** 2),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L2 NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - L2 NORM (X):", cost_val)
return cost_val
@@ -848,8 +852,8 @@ def _op_method(self, input_data, extra_factor=1.0):
"""
soft_threshold = self.beta * extra_factor
- normalization = (self.alpha * 2 * extra_factor + 1)
- return thresh(input_data, soft_threshold, 'soft') / normalization
+ normalization = self.alpha * 2 * extra_factor + 1
+ return thresh(input_data, soft_threshold, "soft") / normalization
def _cost_method(self, *args, **kwargs):
"""Calculate Ridge component of the cost.
@@ -875,8 +879,8 @@ def _cost_method(self, *args, **kwargs):
+ np.abs(self.beta * self._linear.op(args[0])),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - ELASTIC NET (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - ELASTIC NET (X):", cost_val)
return cost_val
@@ -942,7 +946,7 @@ def k_value(self):
def k_value(self, k_val):
if k_val < 1:
raise ValueError(
- 'The k parameter should be greater or equal than 1',
+ "The k parameter should be greater or equal than 1",
)
self._k_value = k_val
@@ -987,7 +991,7 @@ def _compute_theta(self, input_data, alpha, extra_factor=1.0):
alpha_beta = alpha_input - self.beta * extra_factor
theta = alpha_beta * ((alpha_beta <= 1) & (alpha_beta >= 0))
theta = np.nan_to_num(theta)
- theta += (alpha_input > (self.beta * extra_factor + 1))
+ theta += alpha_input > (self.beta * extra_factor + 1)
return theta
def _interpolate(self, alpha0, alpha1, sum0, sum1):
@@ -1096,11 +1100,11 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
extra_factor,
).sum()
- if (np.abs(sum0 - self._k_value) <= tolerance):
+ if np.abs(sum0 - self._k_value) <= tolerance:
found = True
midpoint = first_idx
- if (np.abs(sum1 - self._k_value) <= tolerance):
+ if np.abs(sum1 - self._k_value) <= tolerance:
found = True
midpoint = last_idx - 1
# -1 because output is index such that
@@ -1145,13 +1149,17 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
if found:
return (
- midpoint, alpha[midpoint], alpha[midpoint + 1], sum0, sum1,
+ midpoint,
+ alpha[midpoint],
+ alpha[midpoint + 1],
+ sum0,
+ sum1,
)
raise ValueError(
- 'Cannot find the coordinate of alpha (i) such '
- + 'that sum(theta(alpha[i])) =< k and '
- + 'sum(theta(alpha[i+1])) >= k ',
+ "Cannot find the coordinate of alpha (i) such "
+ + "that sum(theta(alpha[i])) =< k and "
+ + "sum(theta(alpha[i+1])) >= k ",
)
def _find_alpha(self, input_data, extra_factor=1.0):
@@ -1177,13 +1185,11 @@ def _find_alpha(self, input_data, extra_factor=1.0):
# Computes the alpha^i points line 1 in Algorithm 1.
alpha = np.zeros((data_size * 2))
data_abs = np.abs(input_data)
- alpha[:data_size] = (
- (self.beta * extra_factor)
- / (data_abs + sys.float_info.epsilon)
+ alpha[:data_size] = (self.beta * extra_factor) / (
+ data_abs + sys.float_info.epsilon
)
- alpha[data_size:] = (
- (self.beta * extra_factor + 1)
- / (data_abs + sys.float_info.epsilon)
+ alpha[data_size:] = (self.beta * extra_factor + 1) / (
+ data_abs + sys.float_info.epsilon
)
alpha = np.sort(np.unique(alpha))
@@ -1220,8 +1226,8 @@ def _op_method(self, input_data, extra_factor=1.0):
k_max = np.prod(data_shape)
if self._k_value > k_max:
warn(
- 'K value of the K-support norm is greater than the input '
- + 'dimension, its value will be set to {0}'.format(k_max),
+ "K value of the K-support norm is greater than the input "
+ + "dimension, its value will be set to {0}".format(k_max),
)
self._k_value = k_max
@@ -1233,8 +1239,7 @@ def _op_method(self, input_data, extra_factor=1.0):
# Computes line 5. in Algorithm 1.
rslt = np.nan_to_num(
- (input_data.flatten() * theta)
- / (theta + self.beta * extra_factor),
+ (input_data.flatten() * theta) / (theta + self.beta * extra_factor),
)
return rslt.reshape(data_shape)
@@ -1275,15 +1280,13 @@ def _find_q(self, sorted_data):
found = True
q_val = 0
- elif (
- (sorted_data[self._k_value - 1:].sum())
- <= sorted_data[self._k_value - 1]
- ):
+ elif (sorted_data[self._k_value - 1 :].sum()) <= sorted_data[self._k_value - 1]:
found = True
q_val = self._k_value - 1
while (
- not found and not cnt == self._k_value
+ not found
+ and not cnt == self._k_value
and (first_idx <= last_idx < self._k_value)
):
@@ -1291,9 +1294,7 @@ def _find_q(self, sorted_data):
cnt += 1
l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val)
- if (
- sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]
- ):
+ if sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]:
found = True
else:
@@ -1328,15 +1329,12 @@ def _cost_method(self, *args, **kwargs):
data_abs = data_abs[ix] # Sorted absolute value of the data
q_val = self._find_q(data_abs)
cost_val = (
- (
- np.sum(data_abs[:q_val] ** 2) * 0.5
- + np.sum(data_abs[q_val:]) ** 2
- / (self._k_value - q_val)
- ) * self.beta
- )
+ np.sum(data_abs[:q_val] ** 2) * 0.5
+ + np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val)
+ ) * self.beta
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - K-SUPPORT NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - K-SUPPORT NORM (X):", cost_val)
return cost_val
diff --git a/modopt/opt/reweight.py b/modopt/opt/reweight.py
index 8c4f2449..7ff9aac4 100644
--- a/modopt/opt/reweight.py
+++ b/modopt/opt/reweight.py
@@ -81,7 +81,7 @@ def reweight(self, input_data):
"""
if self.verbose:
- print(' - Reweighting: {0}'.format(self._rw_num))
+ print(" - Reweighting: {0}".format(self._rw_num))
self._rw_num += 1
@@ -89,7 +89,7 @@ def reweight(self, input_data):
if input_data.shape != self.weights.shape:
raise ValueError(
- 'Input data must have the same shape as the initial weights.',
+ "Input data must have the same shape as the initial weights.",
)
thresh_weights = self.thresh_factor * self.original_weights
diff --git a/modopt/plot/__init__.py b/modopt/plot/__init__.py
index 28d60be6..da6e096c 100644
--- a/modopt/plot/__init__.py
+++ b/modopt/plot/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['cost_plot']
+__all__ = ["cost_plot"]
diff --git a/modopt/plot/cost_plot.py b/modopt/plot/cost_plot.py
index aa855eaa..36958450 100644
--- a/modopt/plot/cost_plot.py
+++ b/modopt/plot/cost_plot.py
@@ -37,20 +37,20 @@ def plotCost(cost_list, output=None):
"""
if import_fail:
- raise ImportError('Matplotlib package not found')
+ raise ImportError("Matplotlib package not found")
else:
if isinstance(output, type(None)):
- file_name = 'cost_function.png'
+ file_name = "cost_function.png"
else:
- file_name = '{0}_cost_function.png'.format(output)
+ file_name = "{0}_cost_function.png".format(output)
plt.figure()
- plt.plot(np.log10(cost_list), 'r-')
- plt.title('Cost Function')
- plt.xlabel('Iteration')
- plt.ylabel(r'$\log_{10}$ Cost')
+ plt.plot(np.log10(cost_list), "r-")
+ plt.title("Cost Function")
+ plt.xlabel("Iteration")
+ plt.ylabel(r"$\log_{10}$ Cost")
plt.savefig(file_name)
plt.close()
- print(' - Saving cost function data to:', file_name)
+ print(" - Saving cost function data to:", file_name)
diff --git a/modopt/signal/__init__.py b/modopt/signal/__init__.py
index dbc6d053..09b2d2c4 100644
--- a/modopt/signal/__init__.py
+++ b/modopt/signal/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['filter', 'noise', 'positivity', 'svd', 'validation', 'wavelet']
+__all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"]
diff --git a/modopt/signal/filter.py b/modopt/signal/filter.py
index 84dd8160..2c7d8626 100644
--- a/modopt/signal/filter.py
+++ b/modopt/signal/filter.py
@@ -81,7 +81,7 @@ def mex_hat(data_point, sigma):
sigma = check_float(sigma)
xs = (data_point / sigma) ** 2
- factor = 2 * (3 * sigma) ** -0.5 * np.pi ** -0.25
+ factor = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25
return factor * (1 - xs) * np.exp(-0.5 * xs)
diff --git a/modopt/signal/noise.py b/modopt/signal/noise.py
index a59d5553..fadf5308 100644
--- a/modopt/signal/noise.py
+++ b/modopt/signal/noise.py
@@ -15,7 +15,7 @@
from modopt.base.backend import get_array_module
-def add_noise(input_data, sigma=1.0, noise_type='gauss'):
+def add_noise(input_data, sigma=1.0, noise_type="gauss"):
"""Add noise to data.
This method adds Gaussian or Poisson noise to the input data.
@@ -70,7 +70,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
"""
input_data = np.array(input_data)
- if noise_type not in {'gauss', 'poisson'}:
+ if noise_type not in {"gauss", "poisson"}:
raise ValueError(
'Invalid noise type. Options are "gauss" or "poisson"',
)
@@ -78,14 +78,13 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
if isinstance(sigma, (list, tuple, np.ndarray)):
if len(sigma) != input_data.shape[0]:
raise ValueError(
- 'Number of sigma values must match first dimension of input '
- + 'data',
+ "Number of sigma values must match first dimension of input " + "data",
)
- if noise_type == 'gauss':
+ if noise_type == "gauss":
random = np.random.randn(*input_data.shape)
- elif noise_type == 'poisson':
+ elif noise_type == "poisson":
random = np.random.poisson(np.abs(input_data))
if isinstance(sigma, (int, float)):
@@ -96,7 +95,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
return input_data + noise
-def thresh(input_data, threshold, threshold_type='hard'):
+def thresh(input_data, threshold, threshold_type="hard"):
r"""Threshold data.
This method perfoms hard or soft thresholding on the input data.
@@ -169,12 +168,12 @@ def thresh(input_data, threshold, threshold_type='hard'):
input_data = xp.array(input_data)
- if threshold_type not in {'hard', 'soft'}:
+ if threshold_type not in {"hard", "soft"}:
raise ValueError(
'Invalid threshold type. Options are "hard" or "soft"',
)
- if threshold_type == 'soft':
+ if threshold_type == "soft":
denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data))
max_value = xp.maximum((1.0 - threshold / denominator), 0)
diff --git a/modopt/signal/positivity.py b/modopt/signal/positivity.py
index c19ba62c..5c4b795b 100644
--- a/modopt/signal/positivity.py
+++ b/modopt/signal/positivity.py
@@ -47,7 +47,7 @@ def pos_recursive(input_data):
Positive coefficients
"""
- if input_data.dtype == 'O':
+ if input_data.dtype == "O":
res = np.array([pos_recursive(elem) for elem in input_data], dtype="object")
else:
@@ -97,15 +97,15 @@ def positive(input_data, ragged=False):
"""
if not isinstance(input_data, (int, float, list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid data type, input must be `int`, `float`, `list`, '
- + '`tuple` or `np.ndarray`.',
+ "Invalid data type, input must be `int`, `float`, `list`, "
+ + "`tuple` or `np.ndarray`.",
)
if isinstance(input_data, (int, float)):
return pos_thresh(input_data)
if ragged:
- input_data = np.array(input_data, dtype='object')
+ input_data = np.array(input_data, dtype="object")
else:
input_data = np.array(input_data)
diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py
index f3d40a51..cc204817 100644
--- a/modopt/signal/svd.py
+++ b/modopt/signal/svd.py
@@ -52,8 +52,8 @@ def find_n_pc(u_vec, factor=0.5):
"""
if np.sqrt(u_vec.shape[0]) % 1:
raise ValueError(
- 'Invalid left singular vector. The size of the first '
- + 'dimenion of ``u_vec`` must be perfect square.',
+ "Invalid left singular vector. The size of the first "
+ + "dimenion of ``u_vec`` must be perfect square.",
)
# Get the shape of the array
@@ -69,13 +69,12 @@ def find_n_pc(u_vec, factor=0.5):
]
# Return the required number of principal components.
- return np.sum([
- (
- u_val[tuple(zip(array_shape // 2))] ** 2 <= factor
- * np.sum(u_val ** 2),
- )
- for u_val in u_auto
- ])
+ return np.sum(
+ [
+ (u_val[tuple(zip(array_shape // 2))] ** 2 <= factor * np.sum(u_val**2),)
+ for u_val in u_auto
+ ]
+ )
def calculate_svd(input_data):
@@ -101,17 +100,17 @@ def calculate_svd(input_data):
"""
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise TypeError('Input data must be a 2D np.ndarray.')
+ raise TypeError("Input data must be a 2D np.ndarray.")
return svd(
input_data,
check_finite=False,
- lapack_driver='gesvd',
+ lapack_driver="gesvd",
full_matrices=False,
)
-def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
+def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"):
"""Threshold the singular values.
This method thresholds the input data using singular value decomposition.
@@ -156,16 +155,11 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
"""
less_than_zero = isinstance(n_pc, int) and n_pc <= 0
- str_not_all = isinstance(n_pc, str) and n_pc != 'all'
+ str_not_all = isinstance(n_pc, str) and n_pc != "all"
- if (
- (not isinstance(n_pc, (int, str, type(None))))
- or less_than_zero
- or str_not_all
- ):
+ if (not isinstance(n_pc, (int, str, type(None)))) or less_than_zero or str_not_all:
raise ValueError(
- 'Invalid value for "n_pc", specify a positive integer value or '
- + '"all"',
+ 'Invalid value for "n_pc", specify a positive integer value or ' + '"all"',
)
# Get SVD of input data.
@@ -176,15 +170,14 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
# Find the required number of principal components if not specified.
if isinstance(n_pc, type(None)):
n_pc = find_n_pc(u_vec, factor=0.1)
- print('xxxx', n_pc, u_vec)
+ print("xxxx", n_pc, u_vec)
# If the number of PCs is too large use all of the singular values.
- if (
- (isinstance(n_pc, int) and n_pc >= s_values.size)
- or (isinstance(n_pc, str) and n_pc == 'all')
+ if (isinstance(n_pc, int) and n_pc >= s_values.size) or (
+ isinstance(n_pc, str) and n_pc == "all"
):
n_pc = s_values.size
- warn('Using all singular values.')
+ warn("Using all singular values.")
threshold = s_values[n_pc - 1]
@@ -192,7 +185,7 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
s_new = thresh(s_values, threshold, thresh_type)
if np.all(s_new == s_values):
- warn('No change to singular values.')
+ warn("No change to singular values.")
# Diagonalize the svd
s_new = np.diag(s_new)
@@ -206,7 +199,7 @@ def svd_thresh_coef_fast(
threshold,
n_vals=-1,
extra_vals=5,
- thresh_type='hard',
+ thresh_type="hard",
):
"""Threshold the singular values coefficients.
@@ -241,7 +234,7 @@ def svd_thresh_coef_fast(
ok = False
while not ok:
(u_vec, s_values, v_vec) = svds(input_data, k=n_vals)
- ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1)
+ ok = s_values[0] <= threshold or n_vals == min(input_data.shape) - 1
n_vals = min(n_vals + extra_vals, *input_data.shape)
s_values = thresh(
@@ -259,7 +252,7 @@ def svd_thresh_coef_fast(
)
-def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
+def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"):
"""Threshold the singular values coefficients.
This method thresholds the input data using singular value decomposition.
@@ -287,7 +280,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
"""
if not callable(operator):
- raise TypeError('Operator must be a callable function.')
+ raise TypeError("Operator must be a callable function.")
# Get SVD of data matrix
u_vec, s_values, v_vec = calculate_svd(input_data)
@@ -302,10 +295,9 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)
# Compute threshold matrix.
- ti = np.array([
- np.linalg.norm(elem)
- for elem in operator(matrix2cube(u_vec, array_shape))
- ])
+ ti = np.array(
+ [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))]
+ )
threshold *= np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape)
# Threshold coefficients.
diff --git a/modopt/signal/validation.py b/modopt/signal/validation.py
index 422a987b..68c1e726 100644
--- a/modopt/signal/validation.py
+++ b/modopt/signal/validation.py
@@ -54,7 +54,7 @@ def transpose_test(
"""
if not callable(operator) or not callable(operator_t):
- raise TypeError('The input operators must be callable functions.')
+ raise TypeError("The input operators must be callable functions.")
if isinstance(y_shape, type(None)):
y_shape = x_shape
@@ -73,4 +73,4 @@ def transpose_test(
x_mty = np.sum(np.multiply(x_val, operator_t(y_val, y_args)))
# Test the difference between the two.
- print(' - | - | =', np.abs(mx_y - x_mty))
+ print(" - | - | =", np.abs(mx_y - x_mty))
diff --git a/modopt/signal/wavelet.py b/modopt/signal/wavelet.py
index bc4ffc70..72d608e7 100644
--- a/modopt/signal/wavelet.py
+++ b/modopt/signal/wavelet.py
@@ -58,20 +58,20 @@ def execute(command_line):
"""
if not isinstance(command_line, str):
- raise TypeError('Command line must be a string.')
+ raise TypeError("Command line must be a string.")
command = command_line.split()
process = sp.Popen(command, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = process.communicate()
- return stdout.decode('utf-8'), stderr.decode('utf-8')
+ return stdout.decode("utf-8"), stderr.decode("utf-8")
def call_mr_transform(
input_data,
- opt='',
- path='./',
+ opt="",
+ path="./",
remove_files=True,
): # pragma: no cover
"""Call ``mr_transform``.
@@ -127,26 +127,23 @@ def call_mr_transform(
"""
if not import_astropy:
- raise ImportError('Astropy package not found.')
+ raise ImportError("Astropy package not found.")
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise ValueError('Input data must be a 2D numpy array.')
+ raise ValueError("Input data must be a 2D numpy array.")
- executable = 'mr_transform'
+ executable = "mr_transform"
# Make sure mr_transform is installed.
is_executable(executable)
# Create a unique string using the current date and time.
- unique_string = (
- datetime.now().strftime('%Y.%m.%d_%H.%M.%S')
- + str(getrandbits(128))
- )
+ unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128))
# Set the ouput file names.
- file_name = '{0}mr_temp_{1}'.format(path, unique_string)
- file_fits = '{0}.fits'.format(file_name)
- file_mr = '{0}.mr'.format(file_name)
+ file_name = "{0}mr_temp_{1}".format(path, unique_string)
+ file_fits = "{0}.fits".format(file_name)
+ file_mr = "{0}.mr".format(file_name)
# Write the input data to a fits file.
fits.writeto(file_fits, input_data)
@@ -155,15 +152,15 @@ def call_mr_transform(
opt = opt.split()
# Prepare command and execute it
- command_line = ' '.join([executable] + opt + [file_fits, file_mr])
+ command_line = " ".join([executable] + opt + [file_fits, file_mr])
stdout, _ = execute(command_line)
# Check for errors
- if any(word in stdout for word in ('bad', 'Error', 'Sorry')):
+ if any(word in stdout for word in ("bad", "Error", "Sorry")):
remove(file_fits)
message = '{0} raised following exception: "{1}"'
raise RuntimeError(
- message.format(executable, stdout.rstrip('\n')),
+ message.format(executable, stdout.rstrip("\n")),
)
# Retrieve wavelet transformed data.
@@ -198,12 +195,12 @@ def trim_filter(filter_array):
min_idx = np.min(non_zero_indices, axis=-1)
max_idx = np.max(non_zero_indices, axis=-1)
- return filter_array[min_idx[0]:max_idx[0] + 1, min_idx[1]:max_idx[1] + 1]
+ return filter_array[min_idx[0] : max_idx[0] + 1, min_idx[1] : max_idx[1] + 1]
def get_mr_filters(
data_shape,
- opt='',
+ opt="",
coarse=False,
trim=False,
): # pragma: no cover
@@ -256,7 +253,7 @@ def get_mr_filters(
return mr_filters[:-1]
-def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
+def filter_convolve(input_data, filters, filter_rot=False, method="scipy"):
"""Filter convolve.
This method convolves the input image with the wavelet filters.
@@ -315,16 +312,14 @@ def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
axis=0,
)
- return np.array([
- convolve(input_data, filt, method=method) for filt in filters
- ])
+ return np.array([convolve(input_data, filt, method=method) for filt in filters])
def filter_convolve_stack(
input_data,
filters,
filter_rot=False,
- method='scipy',
+ method="scipy",
):
"""Filter convolve.
@@ -366,7 +361,9 @@ def filter_convolve_stack(
"""
# Return the convolved data cube.
- return np.array([
- filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
- for elem in input_data
- ])
+ return np.array(
+ [
+ filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
+ for elem in input_data
+ ]
+ )
diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py
index 5671b8e3..ca0cd666 100644
--- a/modopt/tests/test_algorithms.py
+++ b/modopt/tests/test_algorithms.py
@@ -111,8 +111,7 @@ class AlgoCases:
]
)
def case_forward_backward(self, kwargs, idty, use_metrics):
- """Forward Backward case.
- """
+ """Forward Backward case."""
update_kwargs = build_kwargs(kwargs, use_metrics)
algo = algorithms.ForwardBackward(
self.data1,
@@ -242,9 +241,11 @@ def case_grad(self, GradDescent, use_metrics, idty):
)
algo.iterate()
return algo, update_kwargs
- @parametrize(admm=[algorithms.ADMM,algorithms.FastADMM])
+
+ @parametrize(admm=[algorithms.ADMM, algorithms.FastADMM])
def case_admm(self, admm, use_metrics, idty):
"""ADMM setup."""
+
def optim1(init, obs):
return obs
@@ -265,6 +266,7 @@ def optim2(init, obs):
algo.iterate()
return algo, update_kwargs
+
@parametrize_with_cases("algo, kwargs", cases=AlgoCases)
def test_algo(algo, kwargs):
"""Test algorithms."""
diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py
index e32ff94b..298253d6 100644
--- a/modopt/tests/test_base.py
+++ b/modopt/tests/test_base.py
@@ -5,6 +5,7 @@
Samuel Farrens
Pierre-Antoine Comby
"""
+
import numpy as np
import numpy.testing as npt
import pytest
diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py
index d8227640..895b2371 100644
--- a/modopt/tests/test_helpers/utils.py
+++ b/modopt/tests/test_helpers/utils.py
@@ -4,6 +4,7 @@
:Author: Pierre-Antoine Comby
"""
+
import pytest
diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py
index ea177b15..e7536b03 100644
--- a/modopt/tests/test_math.py
+++ b/modopt/tests/test_math.py
@@ -6,6 +6,7 @@
Samuel Farrens
Pierre-Antoine Comby
"""
+
import pytest
from test_helpers import failparam, skipparam
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
index 1c2e7824..e77074ab 100644
--- a/modopt/tests/test_opt.py
+++ b/modopt/tests/test_opt.py
@@ -193,9 +193,13 @@ def case_linear_wavelet_transform(self, compute_backend="numpy"):
level=2,
)
data_op = np.arange(64).reshape(8, 8).astype(float)
- res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2))
+ res_op, slices, shapes = pywt.ravel_coeffs(
+ pywt.wavedecn(data_op, "haar", level=2)
+ )
data_adj_op = linop.op(data_op)
- res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar")
+ res_adj_op = pywt.waverecn(
+ pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar"
+ )
return linop, data_op, data_adj_op, res_op, res_adj_op
@parametrize(weights=[[1.0, 1.0], None])
diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py
index 202e541b..b3787fc6 100644
--- a/modopt/tests/test_signal.py
+++ b/modopt/tests/test_signal.py
@@ -17,6 +17,7 @@
class TestFilter:
"""Test filter module"""
+
@pytest.mark.parametrize(
("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)]
)
@@ -24,7 +25,6 @@ def test_gaussian_filter(self, norm, result):
"""Test gaussian filter."""
npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result)
-
def test_mex_hat(self):
"""Test mexican hat filter."""
npt.assert_almost_equal(
@@ -32,7 +32,6 @@ def test_mex_hat(self):
-0.35213905225713371,
)
-
def test_mex_hat_dir(self):
"""Test directional mexican hat filter."""
npt.assert_almost_equal(
@@ -86,13 +85,16 @@ def test_thresh(self, threshold_type, result):
noise.thresh(self.data1, 5, threshold_type=threshold_type), result
)
+
class TestPositivity:
"""Test positivity module."""
+
data1 = np.arange(9).reshape(3, 3).astype(float)
data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
data5 = np.array(
[[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
)
+
@pytest.mark.parametrize(
("value", "expected"),
[
@@ -231,6 +233,7 @@ def test_svd_thresh_coef(self, data, operator):
# TODO test_svd_thresh_coef_fast
+
class TestValidation:
"""Test validation Module."""
From 067c40c83972990099fa95c2d9d5d7338dc5d3fe Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 21:22:47 +0100
Subject: [PATCH 15/45] run ruff --fix --unsafe-fixes .
---
docs/source/conf.py | 23 +++++++++----------
modopt/__init__.py | 1 -
modopt/base/__init__.py | 1 -
modopt/base/backend.py | 1 -
modopt/base/np_adjust.py | 1 -
modopt/base/observable.py | 7 +++---
modopt/base/transform.py | 5 ++--
modopt/base/types.py | 3 +--
.../example_lasso_forward_backward.py | 1 -
modopt/interface/__init__.py | 1 -
modopt/interface/errors.py | 13 +++++------
modopt/interface/log.py | 3 +--
modopt/math/__init__.py | 1 -
modopt/math/convolve.py | 1 -
modopt/math/matrix.py | 3 +--
modopt/math/metrics.py | 3 +--
modopt/math/stats.py | 1 -
modopt/opt/__init__.py | 1 -
modopt/opt/algorithms/__init__.py | 19 ---------------
modopt/opt/algorithms/base.py | 3 +--
modopt/opt/algorithms/forward_backward.py | 9 ++++----
modopt/opt/algorithms/gradient_descent.py | 1 -
modopt/opt/algorithms/primal_dual.py | 1 -
modopt/opt/gradient.py | 5 ++--
modopt/opt/linear/base.py | 2 +-
modopt/opt/proximity.py | 21 ++++++++---------
modopt/opt/reweight.py | 5 ++--
modopt/plot/__init__.py | 1 -
modopt/plot/cost_plot.py | 3 +--
modopt/signal/__init__.py | 1 -
modopt/signal/filter.py | 1 -
modopt/signal/noise.py | 2 --
modopt/signal/positivity.py | 1 -
modopt/signal/svd.py | 1 -
modopt/signal/wavelet.py | 9 ++++----
modopt/tests/test_algorithms.py | 8 +------
modopt/tests/test_helpers/__init__.py | 1 -
modopt/tests/test_opt.py | 12 ++++++----
modopt/tests/test_signal.py | 2 +-
39 files changed, 61 insertions(+), 117 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 987576a9..cd39ee08 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Python Template sphinx config
# Import relevant modules
@@ -19,7 +18,7 @@
mdata = metadata(project)
author = mdata["Author"]
version = mdata["Version"]
-copyright = "2020, {}".format(author)
+copyright = f"2020, {author}"
gh_user = "sfarrens"
# If your documentation needs a minimal Sphinx version, state it here.
@@ -117,7 +116,7 @@
)
# The name for this set of Sphinx documents. If None, it defaults to
# " v documentation".
-html_title = "{0} v{1}".format(project, version)
+html_title = f"{project} v{version}"
# A shorter title for the navigation bar. Default is the same as html_title.
# html_short_title = None
@@ -174,20 +173,20 @@ def add_notebooks(nb_path="../../notebooks"):
nb_name = nb.rstrip(nb_ext)
nb_link_file_name = nb_name + ".nblink"
- print("Writing {0}".format(nb_link_file_name))
+ print(f"Writing {nb_link_file_name}")
with open(nb_link_file_name, "w") as nb_link_file:
nb_link_file.write(nb_link_format.format(nb_path, nb))
- print("Looking for {0} in {1}".format(nb_name, nb_rst_file_name))
- with open(nb_rst_file_name, "r") as nb_rst_file:
+ print(f"Looking for {nb_name} in {nb_rst_file_name}")
+ with open(nb_rst_file_name) as nb_rst_file:
check_name = nb_name not in nb_rst_file.read()
if check_name:
- print("Adding {0} to {1}".format(nb_name, nb_rst_file_name))
+ print(f"Adding {nb_name} to {nb_rst_file_name}")
with open(nb_rst_file_name, "a") as nb_rst_file:
if list_pos == 0:
nb_rst_file.write("\n")
- nb_rst_file.write(" {0}\n".format(nb_name))
+ nb_rst_file.write(f" {nb_name}\n")
return nbs
@@ -220,14 +219,14 @@ def add_notebooks(nb_path="../../notebooks"):
"""
nb_header_pt2 = (
r""" """
r""""""
)
diff --git a/modopt/__init__.py b/modopt/__init__.py
index d446e15d..958f3ace 100644
--- a/modopt/__init__.py
+++ b/modopt/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""MODOPT PACKAGE.
diff --git a/modopt/base/__init__.py b/modopt/base/__init__.py
index d75ff315..e7df6c37 100644
--- a/modopt/base/__init__.py
+++ b/modopt/base/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""BASE ROUTINES.
diff --git a/modopt/base/backend.py b/modopt/base/backend.py
index fd933ebb..b4987942 100644
--- a/modopt/base/backend.py
+++ b/modopt/base/backend.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""BACKEND MODULE.
diff --git a/modopt/base/np_adjust.py b/modopt/base/np_adjust.py
index 31a785f5..586a1ee0 100644
--- a/modopt/base/np_adjust.py
+++ b/modopt/base/np_adjust.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""NUMPY ADJUSTMENT ROUTINES.
diff --git a/modopt/base/observable.py b/modopt/base/observable.py
index 2f69a1a7..69c6b238 100644
--- a/modopt/base/observable.py
+++ b/modopt/base/observable.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Observable.
@@ -13,13 +12,13 @@
import numpy as np
-class SignalObject(object):
+class SignalObject:
"""Dummy class for signals."""
pass
-class Observable(object):
+class Observable:
"""Base class for observable classes.
This class defines a simple interface to add or remove observers
@@ -177,7 +176,7 @@ def _remove_observer(self, signal, observer):
self._observers[signal].remove(observer)
-class MetricObserver(object):
+class MetricObserver:
"""Metric observer.
Wrapper of the metric to the observer object notify by the Observable
diff --git a/modopt/base/transform.py b/modopt/base/transform.py
index fedd5efb..1dc9039a 100644
--- a/modopt/base/transform.py
+++ b/modopt/base/transform.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""DATA TRANSFORM ROUTINES.
@@ -288,7 +287,7 @@ def cube2matrix(data_cube):
"""
return data_cube.reshape(
- [data_cube.shape[0]] + [np.prod(data_cube.shape[1:])],
+ [data_cube.shape[0], np.prod(data_cube.shape[1:])],
).T
@@ -333,4 +332,4 @@ def matrix2cube(data_matrix, im_shape):
cube2matrix : complimentary function
"""
- return data_matrix.T.reshape([data_matrix.shape[1]] + list(im_shape))
+ return data_matrix.T.reshape([data_matrix.shape[1], *list(im_shape)])
diff --git a/modopt/base/types.py b/modopt/base/types.py
index 7ea805ad..5ed24ec3 100644
--- a/modopt/base/types.py
+++ b/modopt/base/types.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""TYPE HANDLING ROUTINES.
@@ -165,7 +164,7 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
):
raise (
TypeError(
- "The numpy array elements are not of type: {0}".format(dtype),
+ f"The numpy array elements are not of type: {dtype}",
),
)
diff --git a/modopt/examples/example_lasso_forward_backward.py b/modopt/examples/example_lasso_forward_backward.py
index c28b0499..7e650e05 100644
--- a/modopt/examples/example_lasso_forward_backward.py
+++ b/modopt/examples/example_lasso_forward_backward.py
@@ -1,4 +1,3 @@
-# noqa: D205
"""
Solving the LASSO Problem with the Forward Backward Algorithm.
==============================================================
diff --git a/modopt/interface/__init__.py b/modopt/interface/__init__.py
index 55904ca1..529816ee 100644
--- a/modopt/interface/__init__.py
+++ b/modopt/interface/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""INTERFACE ROUTINES.
diff --git a/modopt/interface/errors.py b/modopt/interface/errors.py
index eb4aa4ca..5c84ad0e 100644
--- a/modopt/interface/errors.py
+++ b/modopt/interface/errors.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""ERROR HANDLING ROUTINES.
@@ -39,7 +38,7 @@ def warn(warn_string, log=None):
warn_txt = colored("WARNING", "yellow")
# Print warning to stdout.
- sys.stderr.write("{0}: {1}\n".format(warn_txt, warn_string))
+ sys.stderr.write(f"{warn_txt}: {warn_string}\n")
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
@@ -66,12 +65,12 @@ def catch_error(exception, log=None):
err_txt = colored("ERROR", "red")
# Print exception to stdout.
- stream_txt = "{0}: {1}\n".format(err_txt, exception)
+ stream_txt = f"{err_txt}: {exception}\n"
sys.stderr.write(stream_txt)
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- log_txt = "ERROR: {0}\n".format(exception)
+ log_txt = f"ERROR: {exception}\n"
log.exception(log_txt)
@@ -92,10 +91,10 @@ def file_name_error(file_name):
"""
if file_name == "" or file_name[0][0] == "-":
- raise IOError("Input file name not specified.")
+ raise OSError("Input file name not specified.")
elif not os.path.isfile(file_name):
- raise IOError("Input file name {0} not found!".format(file_name))
+ raise OSError(f"Input file name {file_name} not found!")
def is_exe(fpath):
@@ -151,4 +150,4 @@ def is_executable(exe_name):
if not res:
message = "{0} does not appear to be a valid executable on this system."
- raise IOError(message.format(exe_name))
+ raise OSError(message.format(exe_name))
diff --git a/modopt/interface/log.py b/modopt/interface/log.py
index a02428d9..d3e0d8e9 100644
--- a/modopt/interface/log.py
+++ b/modopt/interface/log.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""LOGGING ROUTINES.
@@ -30,7 +29,7 @@ def set_up_log(filename, verbose=True):
"""
# Add file extension.
- filename = "{0}.log".format(filename)
+ filename = f"{filename}.log"
if verbose:
print("Preparing log file:", filename)
diff --git a/modopt/math/__init__.py b/modopt/math/__init__.py
index 8e92aa50..0423a333 100644
--- a/modopt/math/__init__.py
+++ b/modopt/math/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""MATHEMATICS ROUTINES.
diff --git a/modopt/math/convolve.py b/modopt/math/convolve.py
index 528b2338..ac1cf84c 100644
--- a/modopt/math/convolve.py
+++ b/modopt/math/convolve.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""CONVOLUTION ROUTINES.
diff --git a/modopt/math/matrix.py b/modopt/math/matrix.py
index 6ddb3f2f..a2419a6c 100644
--- a/modopt/math/matrix.py
+++ b/modopt/math/matrix.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""MATRIX ROUTINES.
@@ -257,7 +256,7 @@ def rotate(matrix, angle):
return matrix[tuple(zip(new_index.T))].reshape(shape.T)
-class PowerMethod(object):
+class PowerMethod:
"""Power method class.
This method performs implements power method to calculate the spectral
diff --git a/modopt/math/metrics.py b/modopt/math/metrics.py
index 93f7ce06..8f797f02 100644
--- a/modopt/math/metrics.py
+++ b/modopt/math/metrics.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""METRICS.
@@ -268,6 +267,6 @@ def nrmse(test, ref, mask=None):
ref = mask * ref
num = np.sqrt(mse(test, ref))
- deno = np.sqrt(np.mean((np.square(test))))
+ deno = np.sqrt(np.mean(np.square(test)))
return num / deno
diff --git a/modopt/math/stats.py b/modopt/math/stats.py
index 59bf6759..b3ee0d8b 100644
--- a/modopt/math/stats.py
+++ b/modopt/math/stats.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""STATISTICS ROUTINES.
diff --git a/modopt/opt/__init__.py b/modopt/opt/__init__.py
index 8b285bee..86564f90 100644
--- a/modopt/opt/__init__.py
+++ b/modopt/opt/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""OPTIMISATION PROBLEM MODULES.
diff --git a/modopt/opt/algorithms/__init__.py b/modopt/opt/algorithms/__init__.py
index 6a29325a..ce6c5e56 100644
--- a/modopt/opt/algorithms/__init__.py
+++ b/modopt/opt/algorithms/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
r"""OPTIMISATION ALGORITHMS.
This module contains class implementations of various optimisation algoritms.
@@ -45,21 +44,3 @@
"""
-from modopt.opt.algorithms.base import SetUp
-from modopt.opt.algorithms.forward_backward import (
- FISTA,
- POGM,
- ForwardBackward,
- GenForwardBackward,
-)
-from modopt.opt.algorithms.gradient_descent import (
- AdaGenericGradOpt,
- ADAMGradOpt,
- GenericGradOpt,
- MomentumGradOpt,
- RMSpropGradOpt,
- SAGAOptGradOpt,
- VanillaGenericGradOpt,
-)
-from modopt.opt.algorithms.primal_dual import Condat
-from modopt.opt.algorithms.admm import ADMM, FastADMM
diff --git a/modopt/opt/algorithms/base.py b/modopt/opt/algorithms/base.py
index e2b9017d..dbb73be0 100644
--- a/modopt/opt/algorithms/base.py
+++ b/modopt/opt/algorithms/base.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Base SetUp for optimisation algorithms."""
from inspect import getmro
@@ -118,7 +117,7 @@ def metrics(self, metrics):
self._metrics = metrics
else:
raise TypeError(
- "Metrics must be a dictionary, not {0}.".format(type(metrics)),
+ f"Metrics must be a dictionary, not {type(metrics)}.",
)
def any_convergence_flag(self):
diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py
index d34125fa..4c1cb35c 100644
--- a/modopt/opt/algorithms/forward_backward.py
+++ b/modopt/opt/algorithms/forward_backward.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Forward-Backward Algorithms."""
import numpy as np
@@ -9,7 +8,7 @@
from modopt.opt.linear import Identity
-class FISTA(object):
+class FISTA:
r"""FISTA.
This class is inherited by optimisation classes to speed up convergence
@@ -602,7 +601,7 @@ def __init__(
self._x_old = self.xp.copy(x)
# Set the algorithm operators
- for operator in [grad, cost] + prox_list:
+ for operator in [grad, cost, *prox_list]:
self._check_operator(operator)
self._grad = grad
@@ -610,7 +609,7 @@ def __init__(
self._linear = linear
if cost == "auto":
- self._cost_func = costObj([self._grad] + prox_list)
+ self._cost_func = costObj([self._grad, *prox_list])
else:
self._cost_func = cost
@@ -689,7 +688,7 @@ def _set_weights(self, weights):
if self.xp.sum(weights) != expected_weight_sum:
raise ValueError(
"Proximity operator weights must sum to 1.0. Current sum of "
- + "weights = {0}".format(self.xp.sum(weights)),
+ + f"weights = {self.xp.sum(weights)}",
)
self._weights = weights
diff --git a/modopt/opt/algorithms/gradient_descent.py b/modopt/opt/algorithms/gradient_descent.py
index d3af1686..0960be5a 100644
--- a/modopt/opt/algorithms/gradient_descent.py
+++ b/modopt/opt/algorithms/gradient_descent.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Gradient Descent Algorithms."""
import numpy as np
diff --git a/modopt/opt/algorithms/primal_dual.py b/modopt/opt/algorithms/primal_dual.py
index 179ddf95..24908993 100644
--- a/modopt/opt/algorithms/primal_dual.py
+++ b/modopt/opt/algorithms/primal_dual.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Primal-Dual Algorithms."""
from modopt.opt.algorithms.base import SetUp
diff --git a/modopt/opt/gradient.py b/modopt/opt/gradient.py
index 8d7bacc7..bd214f21 100644
--- a/modopt/opt/gradient.py
+++ b/modopt/opt/gradient.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""GRADIENT CLASSES.
@@ -14,7 +13,7 @@
from modopt.base.types import check_callable, check_float, check_npndarray
-class GradParent(object):
+class GradParent:
"""Gradient Parent Class.
This class defines the basic methods that will be inherited by specific
@@ -289,7 +288,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - DATA FIDELITY (X):", cost_val)
return cost_val
diff --git a/modopt/opt/linear/base.py b/modopt/opt/linear/base.py
index 79f05d85..9fa35187 100644
--- a/modopt/opt/linear/base.py
+++ b/modopt/opt/linear/base.py
@@ -6,7 +6,7 @@
from modopt.base.backend import get_array_module
-class LinearParent(object):
+class LinearParent:
"""Linear Operator Parent Class.
This class sets the structure for defining linear operator instances.
diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py
index b562f77a..91a99f2a 100644
--- a/modopt/opt/proximity.py
+++ b/modopt/opt/proximity.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""PROXIMITY OPERATORS.
@@ -32,7 +31,7 @@
from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast
-class ProximityParent(object):
+class ProximityParent:
"""Proximity Operator Parent Class.
This class sets the structure for defining proximity operator instances.
@@ -140,7 +139,7 @@ def _cost_method(self, *args, **kwargs):
``0.0``
"""
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - Min (X):", np.min(args[0]))
return 0
@@ -221,7 +220,7 @@ def _cost_method(self, *args, **kwargs):
if isinstance(cost_val, xp.ndarray):
cost_val = cost_val.item()
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - L1 NORM (X):", cost_val)
return cost_val
@@ -365,7 +364,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = self.thresh * nuclear_norm(cube2matrix(args[0]))
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - NUCLEAR NORM (X):", cost_val)
return cost_val
@@ -706,7 +705,7 @@ def _cost_method(self, *args, **kwargs):
self.weights * np.sort(np.squeeze(np.abs(args[0]))[::-1]),
)
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - OWL NORM (X):", cost_val)
return cost_val
@@ -790,7 +789,7 @@ def _cost_method(self, *args, **kwargs):
np.abs(self.weights * self._linear.op(args[0]) ** 2),
)
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - L2 NORM (X):", cost_val)
return cost_val
@@ -879,7 +878,7 @@ def _cost_method(self, *args, **kwargs):
+ np.abs(self.beta * self._linear.op(args[0])),
)
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - ELASTIC NET (X):", cost_val)
return cost_val
@@ -1183,7 +1182,7 @@ def _find_alpha(self, input_data, extra_factor=1.0):
data_size = input_data.shape[0]
# Computes the alpha^i points line 1 in Algorithm 1.
- alpha = np.zeros((data_size * 2))
+ alpha = np.zeros(data_size * 2)
data_abs = np.abs(input_data)
alpha[:data_size] = (self.beta * extra_factor) / (
data_abs + sys.float_info.epsilon
@@ -1227,7 +1226,7 @@ def _op_method(self, input_data, extra_factor=1.0):
if self._k_value > k_max:
warn(
"K value of the K-support norm is greater than the input "
- + "dimension, its value will be set to {0}".format(k_max),
+ + f"dimension, its value will be set to {k_max}",
)
self._k_value = k_max
@@ -1333,7 +1332,7 @@ def _cost_method(self, *args, **kwargs):
+ np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val)
) * self.beta
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - K-SUPPORT NORM (X):", cost_val)
return cost_val
diff --git a/modopt/opt/reweight.py b/modopt/opt/reweight.py
index 7ff9aac4..8d120101 100644
--- a/modopt/opt/reweight.py
+++ b/modopt/opt/reweight.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""REWEIGHTING CLASSES.
@@ -13,7 +12,7 @@
from modopt.base.types import check_float
-class cwbReweight(object):
+class cwbReweight:
"""Candes, Wakin and Boyd reweighting class.
This class implements the reweighting scheme described in
@@ -81,7 +80,7 @@ def reweight(self, input_data):
"""
if self.verbose:
- print(" - Reweighting: {0}".format(self._rw_num))
+ print(f" - Reweighting: {self._rw_num}")
self._rw_num += 1
diff --git a/modopt/plot/__init__.py b/modopt/plot/__init__.py
index da6e096c..f6b39978 100644
--- a/modopt/plot/__init__.py
+++ b/modopt/plot/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""PLOTTING ROUTINES.
diff --git a/modopt/plot/cost_plot.py b/modopt/plot/cost_plot.py
index 36958450..2274f35d 100644
--- a/modopt/plot/cost_plot.py
+++ b/modopt/plot/cost_plot.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""PLOTTING ROUTINES.
@@ -43,7 +42,7 @@ def plotCost(cost_list, output=None):
if isinstance(output, type(None)):
file_name = "cost_function.png"
else:
- file_name = "{0}_cost_function.png".format(output)
+ file_name = f"{output}_cost_function.png"
plt.figure()
plt.plot(np.log10(cost_list), "r-")
diff --git a/modopt/signal/__init__.py b/modopt/signal/__init__.py
index 09b2d2c4..2aee1987 100644
--- a/modopt/signal/__init__.py
+++ b/modopt/signal/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""SIGNAL PROCESSING ROUTINES.
diff --git a/modopt/signal/filter.py b/modopt/signal/filter.py
index 2c7d8626..0e50d28f 100644
--- a/modopt/signal/filter.py
+++ b/modopt/signal/filter.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""FILTER ROUTINES.
diff --git a/modopt/signal/noise.py b/modopt/signal/noise.py
index fadf5308..b43a0b61 100644
--- a/modopt/signal/noise.py
+++ b/modopt/signal/noise.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""NOISE ROUTINES.
@@ -8,7 +7,6 @@
"""
-from builtins import zip
import numpy as np
diff --git a/modopt/signal/positivity.py b/modopt/signal/positivity.py
index 5c4b795b..f3f312d3 100644
--- a/modopt/signal/positivity.py
+++ b/modopt/signal/positivity.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""POSITIVITY.
diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py
index cc204817..dd080306 100644
--- a/modopt/signal/svd.py
+++ b/modopt/signal/svd.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""SVD ROUTINES.
diff --git a/modopt/signal/wavelet.py b/modopt/signal/wavelet.py
index 72d608e7..d624db3a 100644
--- a/modopt/signal/wavelet.py
+++ b/modopt/signal/wavelet.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""WAVELET MODULE.
@@ -141,9 +140,9 @@ def call_mr_transform(
unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128))
# Set the ouput file names.
- file_name = "{0}mr_temp_{1}".format(path, unique_string)
- file_fits = "{0}.fits".format(file_name)
- file_mr = "{0}.mr".format(file_name)
+ file_name = f"{path}mr_temp_{unique_string}"
+ file_fits = f"{file_name}.fits"
+ file_mr = f"{file_name}.mr"
# Write the input data to a fits file.
fits.writeto(file_fits, input_data)
@@ -152,7 +151,7 @@ def call_mr_transform(
opt = opt.split()
# Prepare command and execute it
- command_line = " ".join([executable] + opt + [file_fits, file_mr])
+ command_line = " ".join([executable, *opt, file_fits, file_mr])
stdout, _ = execute(command_line)
# Check for errors
diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py
index ca0cd666..c1e676a5 100644
--- a/modopt/tests/test_algorithms.py
+++ b/modopt/tests/test_algorithms.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""UNIT TESTS FOR Algorithms.
@@ -11,18 +10,13 @@
import numpy as np
import numpy.testing as npt
-import pytest
from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
from pytest_cases import (
- case,
fixture,
- fixture_ref,
- lazy_value,
parametrize,
parametrize_with_cases,
)
-from test_helpers import Dummy
SKLEARN_AVAILABLE = True
try:
@@ -80,7 +74,7 @@ def build_kwargs(kwargs, use_metrics):
@parametrize(use_metrics=[True, False])
class AlgoCases:
- """Cases for algorithms.
+ r"""Cases for algorithms.
Most of the test solves the trivial problem
diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py
index 3886b877..e69de29b 100644
--- a/modopt/tests/test_helpers/__init__.py
+++ b/modopt/tests/test_helpers/__init__.py
@@ -1 +0,0 @@
-from .utils import failparam, skipparam, Dummy
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
index e77074ab..7dc27871 100644
--- a/modopt/tests/test_opt.py
+++ b/modopt/tests/test_opt.py
@@ -37,10 +37,14 @@
PYWT_AVAILABLE = False
# Basic functions to be used as operators or as dummy functions
-func_identity = lambda x_val: x_val
-func_double = lambda x_val: x_val * 2
-func_sq = lambda x_val: x_val**2
-func_cube = lambda x_val: x_val**3
+def func_identity(x_val):
+ return x_val
+def func_double(x_val):
+ return x_val * 2
+def func_sq(x_val):
+ return x_val ** 2
+def func_cube(x_val):
+ return x_val ** 3
@case(tags="cost")
diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py
index b3787fc6..cdd95277 100644
--- a/modopt/tests/test_signal.py
+++ b/modopt/tests/test_signal.py
@@ -16,7 +16,7 @@
class TestFilter:
- """Test filter module"""
+ """Test filter module."""
@pytest.mark.parametrize(
("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)]
From 29833bc0f5f6cb85b7add3af7ad947d79440ed6b Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 21:25:05 +0100
Subject: [PATCH 16/45] move to a src layout.
---
{modopt/examples => examples}/README.rst | 0
{modopt/examples => examples}/__init__.py | 0
{modopt/examples => examples}/conftest.py | 0
.../example_lasso_forward_backward.py | 0
pyproject.toml | 11 +++--------
{modopt => src/modopt}/__init__.py | 0
{modopt => src/modopt}/base/__init__.py | 0
{modopt => src/modopt}/base/backend.py | 0
{modopt => src/modopt}/base/np_adjust.py | 0
{modopt => src/modopt}/base/observable.py | 0
{modopt => src/modopt}/base/transform.py | 0
{modopt => src/modopt}/base/types.py | 0
{modopt => src/modopt}/interface/__init__.py | 0
{modopt => src/modopt}/interface/errors.py | 0
{modopt => src/modopt}/interface/log.py | 0
{modopt => src/modopt}/math/__init__.py | 0
{modopt => src/modopt}/math/convolve.py | 0
{modopt => src/modopt}/math/matrix.py | 0
{modopt => src/modopt}/math/metrics.py | 0
{modopt => src/modopt}/math/stats.py | 0
{modopt => src/modopt}/opt/__init__.py | 0
{modopt => src/modopt}/opt/algorithms/__init__.py | 0
{modopt => src/modopt}/opt/algorithms/admm.py | 0
{modopt => src/modopt}/opt/algorithms/base.py | 0
.../modopt}/opt/algorithms/forward_backward.py | 0
.../modopt}/opt/algorithms/gradient_descent.py | 0
{modopt => src/modopt}/opt/algorithms/primal_dual.py | 0
{modopt => src/modopt}/opt/cost.py | 0
{modopt => src/modopt}/opt/gradient.py | 0
{modopt => src/modopt}/opt/linear/__init__.py | 0
{modopt => src/modopt}/opt/linear/base.py | 0
{modopt => src/modopt}/opt/linear/wavelet.py | 0
{modopt => src/modopt}/opt/proximity.py | 0
{modopt => src/modopt}/opt/reweight.py | 0
{modopt => src/modopt}/plot/__init__.py | 0
{modopt => src/modopt}/plot/cost_plot.py | 0
{modopt => src/modopt}/signal/__init__.py | 0
{modopt => src/modopt}/signal/filter.py | 0
{modopt => src/modopt}/signal/noise.py | 0
{modopt => src/modopt}/signal/positivity.py | 0
{modopt => src/modopt}/signal/svd.py | 0
{modopt => src/modopt}/signal/validation.py | 0
{modopt => src/modopt}/signal/wavelet.py | 0
{modopt/tests => tests}/test_algorithms.py | 0
{modopt/tests => tests}/test_base.py | 0
{modopt/tests => tests}/test_helpers/__init__.py | 0
{modopt/tests => tests}/test_helpers/utils.py | 0
{modopt/tests => tests}/test_math.py | 0
{modopt/tests => tests}/test_opt.py | 0
{modopt/tests => tests}/test_signal.py | 0
50 files changed, 3 insertions(+), 8 deletions(-)
rename {modopt/examples => examples}/README.rst (100%)
rename {modopt/examples => examples}/__init__.py (100%)
rename {modopt/examples => examples}/conftest.py (100%)
rename {modopt/examples => examples}/example_lasso_forward_backward.py (100%)
rename {modopt => src/modopt}/__init__.py (100%)
rename {modopt => src/modopt}/base/__init__.py (100%)
rename {modopt => src/modopt}/base/backend.py (100%)
rename {modopt => src/modopt}/base/np_adjust.py (100%)
rename {modopt => src/modopt}/base/observable.py (100%)
rename {modopt => src/modopt}/base/transform.py (100%)
rename {modopt => src/modopt}/base/types.py (100%)
rename {modopt => src/modopt}/interface/__init__.py (100%)
rename {modopt => src/modopt}/interface/errors.py (100%)
rename {modopt => src/modopt}/interface/log.py (100%)
rename {modopt => src/modopt}/math/__init__.py (100%)
rename {modopt => src/modopt}/math/convolve.py (100%)
rename {modopt => src/modopt}/math/matrix.py (100%)
rename {modopt => src/modopt}/math/metrics.py (100%)
rename {modopt => src/modopt}/math/stats.py (100%)
rename {modopt => src/modopt}/opt/__init__.py (100%)
rename {modopt => src/modopt}/opt/algorithms/__init__.py (100%)
rename {modopt => src/modopt}/opt/algorithms/admm.py (100%)
rename {modopt => src/modopt}/opt/algorithms/base.py (100%)
rename {modopt => src/modopt}/opt/algorithms/forward_backward.py (100%)
rename {modopt => src/modopt}/opt/algorithms/gradient_descent.py (100%)
rename {modopt => src/modopt}/opt/algorithms/primal_dual.py (100%)
rename {modopt => src/modopt}/opt/cost.py (100%)
rename {modopt => src/modopt}/opt/gradient.py (100%)
rename {modopt => src/modopt}/opt/linear/__init__.py (100%)
rename {modopt => src/modopt}/opt/linear/base.py (100%)
rename {modopt => src/modopt}/opt/linear/wavelet.py (100%)
rename {modopt => src/modopt}/opt/proximity.py (100%)
rename {modopt => src/modopt}/opt/reweight.py (100%)
rename {modopt => src/modopt}/plot/__init__.py (100%)
rename {modopt => src/modopt}/plot/cost_plot.py (100%)
rename {modopt => src/modopt}/signal/__init__.py (100%)
rename {modopt => src/modopt}/signal/filter.py (100%)
rename {modopt => src/modopt}/signal/noise.py (100%)
rename {modopt => src/modopt}/signal/positivity.py (100%)
rename {modopt => src/modopt}/signal/svd.py (100%)
rename {modopt => src/modopt}/signal/validation.py (100%)
rename {modopt => src/modopt}/signal/wavelet.py (100%)
rename {modopt/tests => tests}/test_algorithms.py (100%)
rename {modopt/tests => tests}/test_base.py (100%)
rename {modopt/tests => tests}/test_helpers/__init__.py (100%)
rename {modopt/tests => tests}/test_helpers/utils.py (100%)
rename {modopt/tests => tests}/test_math.py (100%)
rename {modopt/tests => tests}/test_opt.py (100%)
rename {modopt/tests => tests}/test_signal.py (100%)
diff --git a/modopt/examples/README.rst b/examples/README.rst
similarity index 100%
rename from modopt/examples/README.rst
rename to examples/README.rst
diff --git a/modopt/examples/__init__.py b/examples/__init__.py
similarity index 100%
rename from modopt/examples/__init__.py
rename to examples/__init__.py
diff --git a/modopt/examples/conftest.py b/examples/conftest.py
similarity index 100%
rename from modopt/examples/conftest.py
rename to examples/conftest.py
diff --git a/modopt/examples/example_lasso_forward_backward.py b/examples/example_lasso_forward_backward.py
similarity index 100%
rename from modopt/examples/example_lasso_forward_backward.py
rename to examples/example_lasso_forward_backward.py
diff --git a/pyproject.toml b/pyproject.toml
index 71bdce82..37da48a3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,9 +26,6 @@ dev=["black", "pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar", "ruf
[build-system]
requires=["setuptools", "setuptools-scm[toml]", "wheel"]
-[tool.setuptools]
-packages=["modopt"]
-
[tool.coverage.run]
omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"]
@@ -38,12 +35,10 @@ exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
[tool.black]
-[tool.ruff]
-
-src=["modopt"]
+[lint]
select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
-[tool.ruff.pydocstyle]
+[lint.pydocstyle]
convention="numpy"
[tool.isort]
@@ -51,6 +46,6 @@ profile="black"
[tool.pytest.ini_options]
minversion = "6.0"
-norecursedirs = ["tests/helpers"]
+norecursedirs = ["tests/test_helpers"]
testpaths=["modopt"]
addopts = ["--verbose", "--cov=modopt", "--cov-report=term-missing", "--cov-report=xml", "--junitxml=pytest.xml"]
diff --git a/modopt/__init__.py b/src/modopt/__init__.py
similarity index 100%
rename from modopt/__init__.py
rename to src/modopt/__init__.py
diff --git a/modopt/base/__init__.py b/src/modopt/base/__init__.py
similarity index 100%
rename from modopt/base/__init__.py
rename to src/modopt/base/__init__.py
diff --git a/modopt/base/backend.py b/src/modopt/base/backend.py
similarity index 100%
rename from modopt/base/backend.py
rename to src/modopt/base/backend.py
diff --git a/modopt/base/np_adjust.py b/src/modopt/base/np_adjust.py
similarity index 100%
rename from modopt/base/np_adjust.py
rename to src/modopt/base/np_adjust.py
diff --git a/modopt/base/observable.py b/src/modopt/base/observable.py
similarity index 100%
rename from modopt/base/observable.py
rename to src/modopt/base/observable.py
diff --git a/modopt/base/transform.py b/src/modopt/base/transform.py
similarity index 100%
rename from modopt/base/transform.py
rename to src/modopt/base/transform.py
diff --git a/modopt/base/types.py b/src/modopt/base/types.py
similarity index 100%
rename from modopt/base/types.py
rename to src/modopt/base/types.py
diff --git a/modopt/interface/__init__.py b/src/modopt/interface/__init__.py
similarity index 100%
rename from modopt/interface/__init__.py
rename to src/modopt/interface/__init__.py
diff --git a/modopt/interface/errors.py b/src/modopt/interface/errors.py
similarity index 100%
rename from modopt/interface/errors.py
rename to src/modopt/interface/errors.py
diff --git a/modopt/interface/log.py b/src/modopt/interface/log.py
similarity index 100%
rename from modopt/interface/log.py
rename to src/modopt/interface/log.py
diff --git a/modopt/math/__init__.py b/src/modopt/math/__init__.py
similarity index 100%
rename from modopt/math/__init__.py
rename to src/modopt/math/__init__.py
diff --git a/modopt/math/convolve.py b/src/modopt/math/convolve.py
similarity index 100%
rename from modopt/math/convolve.py
rename to src/modopt/math/convolve.py
diff --git a/modopt/math/matrix.py b/src/modopt/math/matrix.py
similarity index 100%
rename from modopt/math/matrix.py
rename to src/modopt/math/matrix.py
diff --git a/modopt/math/metrics.py b/src/modopt/math/metrics.py
similarity index 100%
rename from modopt/math/metrics.py
rename to src/modopt/math/metrics.py
diff --git a/modopt/math/stats.py b/src/modopt/math/stats.py
similarity index 100%
rename from modopt/math/stats.py
rename to src/modopt/math/stats.py
diff --git a/modopt/opt/__init__.py b/src/modopt/opt/__init__.py
similarity index 100%
rename from modopt/opt/__init__.py
rename to src/modopt/opt/__init__.py
diff --git a/modopt/opt/algorithms/__init__.py b/src/modopt/opt/algorithms/__init__.py
similarity index 100%
rename from modopt/opt/algorithms/__init__.py
rename to src/modopt/opt/algorithms/__init__.py
diff --git a/modopt/opt/algorithms/admm.py b/src/modopt/opt/algorithms/admm.py
similarity index 100%
rename from modopt/opt/algorithms/admm.py
rename to src/modopt/opt/algorithms/admm.py
diff --git a/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py
similarity index 100%
rename from modopt/opt/algorithms/base.py
rename to src/modopt/opt/algorithms/base.py
diff --git a/modopt/opt/algorithms/forward_backward.py b/src/modopt/opt/algorithms/forward_backward.py
similarity index 100%
rename from modopt/opt/algorithms/forward_backward.py
rename to src/modopt/opt/algorithms/forward_backward.py
diff --git a/modopt/opt/algorithms/gradient_descent.py b/src/modopt/opt/algorithms/gradient_descent.py
similarity index 100%
rename from modopt/opt/algorithms/gradient_descent.py
rename to src/modopt/opt/algorithms/gradient_descent.py
diff --git a/modopt/opt/algorithms/primal_dual.py b/src/modopt/opt/algorithms/primal_dual.py
similarity index 100%
rename from modopt/opt/algorithms/primal_dual.py
rename to src/modopt/opt/algorithms/primal_dual.py
diff --git a/modopt/opt/cost.py b/src/modopt/opt/cost.py
similarity index 100%
rename from modopt/opt/cost.py
rename to src/modopt/opt/cost.py
diff --git a/modopt/opt/gradient.py b/src/modopt/opt/gradient.py
similarity index 100%
rename from modopt/opt/gradient.py
rename to src/modopt/opt/gradient.py
diff --git a/modopt/opt/linear/__init__.py b/src/modopt/opt/linear/__init__.py
similarity index 100%
rename from modopt/opt/linear/__init__.py
rename to src/modopt/opt/linear/__init__.py
diff --git a/modopt/opt/linear/base.py b/src/modopt/opt/linear/base.py
similarity index 100%
rename from modopt/opt/linear/base.py
rename to src/modopt/opt/linear/base.py
diff --git a/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
similarity index 100%
rename from modopt/opt/linear/wavelet.py
rename to src/modopt/opt/linear/wavelet.py
diff --git a/modopt/opt/proximity.py b/src/modopt/opt/proximity.py
similarity index 100%
rename from modopt/opt/proximity.py
rename to src/modopt/opt/proximity.py
diff --git a/modopt/opt/reweight.py b/src/modopt/opt/reweight.py
similarity index 100%
rename from modopt/opt/reweight.py
rename to src/modopt/opt/reweight.py
diff --git a/modopt/plot/__init__.py b/src/modopt/plot/__init__.py
similarity index 100%
rename from modopt/plot/__init__.py
rename to src/modopt/plot/__init__.py
diff --git a/modopt/plot/cost_plot.py b/src/modopt/plot/cost_plot.py
similarity index 100%
rename from modopt/plot/cost_plot.py
rename to src/modopt/plot/cost_plot.py
diff --git a/modopt/signal/__init__.py b/src/modopt/signal/__init__.py
similarity index 100%
rename from modopt/signal/__init__.py
rename to src/modopt/signal/__init__.py
diff --git a/modopt/signal/filter.py b/src/modopt/signal/filter.py
similarity index 100%
rename from modopt/signal/filter.py
rename to src/modopt/signal/filter.py
diff --git a/modopt/signal/noise.py b/src/modopt/signal/noise.py
similarity index 100%
rename from modopt/signal/noise.py
rename to src/modopt/signal/noise.py
diff --git a/modopt/signal/positivity.py b/src/modopt/signal/positivity.py
similarity index 100%
rename from modopt/signal/positivity.py
rename to src/modopt/signal/positivity.py
diff --git a/modopt/signal/svd.py b/src/modopt/signal/svd.py
similarity index 100%
rename from modopt/signal/svd.py
rename to src/modopt/signal/svd.py
diff --git a/modopt/signal/validation.py b/src/modopt/signal/validation.py
similarity index 100%
rename from modopt/signal/validation.py
rename to src/modopt/signal/validation.py
diff --git a/modopt/signal/wavelet.py b/src/modopt/signal/wavelet.py
similarity index 100%
rename from modopt/signal/wavelet.py
rename to src/modopt/signal/wavelet.py
diff --git a/modopt/tests/test_algorithms.py b/tests/test_algorithms.py
similarity index 100%
rename from modopt/tests/test_algorithms.py
rename to tests/test_algorithms.py
diff --git a/modopt/tests/test_base.py b/tests/test_base.py
similarity index 100%
rename from modopt/tests/test_base.py
rename to tests/test_base.py
diff --git a/modopt/tests/test_helpers/__init__.py b/tests/test_helpers/__init__.py
similarity index 100%
rename from modopt/tests/test_helpers/__init__.py
rename to tests/test_helpers/__init__.py
diff --git a/modopt/tests/test_helpers/utils.py b/tests/test_helpers/utils.py
similarity index 100%
rename from modopt/tests/test_helpers/utils.py
rename to tests/test_helpers/utils.py
diff --git a/modopt/tests/test_math.py b/tests/test_math.py
similarity index 100%
rename from modopt/tests/test_math.py
rename to tests/test_math.py
diff --git a/modopt/tests/test_opt.py b/tests/test_opt.py
similarity index 100%
rename from modopt/tests/test_opt.py
rename to tests/test_opt.py
diff --git a/modopt/tests/test_signal.py b/tests/test_signal.py
similarity index 100%
rename from modopt/tests/test_signal.py
rename to tests/test_signal.py
From 536366d621d7f71cfd5f1135a25329aff51a819b Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 22:47:11 +0100
Subject: [PATCH 17/45] fix: make majority of tests passing.
---
src/modopt/opt/algorithms/__init__.py | 29 +++++++++++++++++++++++++
src/modopt/opt/linear/wavelet.py | 8 +++----
tests/test_helpers/__init__.py | 5 +++++
tests/test_opt.py | 31 +++++++++++++++++++++------
4 files changed, 63 insertions(+), 10 deletions(-)
diff --git a/src/modopt/opt/algorithms/__init__.py b/src/modopt/opt/algorithms/__init__.py
index ce6c5e56..ff79502c 100644
--- a/src/modopt/opt/algorithms/__init__.py
+++ b/src/modopt/opt/algorithms/__init__.py
@@ -44,3 +44,32 @@
"""
+from .forward_backward import FISTA, ForwardBackward, GenForwardBackward, POGM
+from .primal_dual import Condat
+from .gradient_descent import (
+ ADAMGradOpt,
+ AdaGenericGradOpt,
+ GenericGradOpt,
+ MomentumGradOpt,
+ RMSpropGradOpt,
+ SAGAOptGradOpt,
+ VanillaGenericGradOpt,
+)
+from .admm import ADMM, FastADMM
+
+__all__ = [
+ "FISTA",
+ "ForwardBackward",
+ "GenForwardBackward",
+ "POGM",
+ "Condat",
+ "ADAMGradOpt",
+ "AdaGenericGradOpt",
+ "GenericGradOpt",
+ "MomentumGradOpt",
+ "RMSpropGradOpt",
+ "SAGAOptGradOpt",
+ "VanillaGenericGradOpt",
+ "ADMM",
+ "FastADMM",
+]
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index e1608ff4..8dc44fd3 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -293,7 +293,7 @@ def __init__(
self.mode = mode
self.coeffs_shape = None # will be set after op.
- def op(self, data: torch.Tensor) -> list[torch.Tensor]:
+ def op(self, data):
"""Apply the wavelet decomposition on.
Parameters
@@ -355,7 +355,7 @@ def op(self, data: torch.Tensor) -> list[torch.Tensor]:
)
return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()]
- def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor:
+ def adj_op(self, coeffs):
"""Apply the wavelet recomposition.
Parameters
@@ -429,7 +429,7 @@ def __init__(
)
self.coeffs_shape = None # will be set after op
- def op(self, data: cp.array) -> cp.ndarray:
+ def op(self, data):
"""Define the wavelet operator.
This method returns the input data convolved with the wavelet filter.
@@ -459,7 +459,7 @@ def op(self, data: cp.array) -> cp.ndarray:
return ret
- def adj_op(self, data: cp.ndarray) -> cp.ndarray:
+ def adj_op(self, data):
"""Define the wavelet adjoint operator.
This method returns the reconstructed image.
diff --git a/tests/test_helpers/__init__.py b/tests/test_helpers/__init__.py
index e69de29b..0ded847a 100644
--- a/tests/test_helpers/__init__.py
+++ b/tests/test_helpers/__init__.py
@@ -0,0 +1,5 @@
+"""Utilities for tests."""
+
+from .utils import Dummy, failparam, skipparam
+
+__all__ = ["Dummy", "failparam", "skipparam"]
diff --git a/tests/test_opt.py b/tests/test_opt.py
index 7dc27871..e31d3a49 100644
--- a/tests/test_opt.py
+++ b/tests/test_opt.py
@@ -36,15 +36,22 @@
except ImportError:
PYWT_AVAILABLE = False
+
# Basic functions to be used as operators or as dummy functions
def func_identity(x_val):
return x_val
+
+
def func_double(x_val):
return x_val * 2
+
+
def func_sq(x_val):
- return x_val ** 2
+ return x_val**2
+
+
def func_cube(x_val):
- return x_val ** 3
+ return x_val**3
@case(tags="cost")
@@ -187,10 +194,21 @@ def case_linear_wavelet_convolve(self):
@parametrize(
compute_backend=[
- pytest.param("numpy", marks=pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")),
- pytest.param("cupy", marks=pytest.mark.skipif(not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."))
- ])
- def case_linear_wavelet_transform(self, compute_backend="numpy"):
+ pytest.param(
+ "numpy",
+ marks=pytest.mark.skipif(
+ not PYWT_AVAILABLE, reason="PyWavelet not available."
+ ),
+ ),
+ pytest.param(
+ "cupy",
+ marks=pytest.mark.skipif(
+ not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."
+ ),
+ ),
+ ]
+ )
+ def case_linear_wavelet_transform(self, compute_backend):
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
@@ -318,6 +336,7 @@ class ProxCases:
[11.67394789, 12.87497954, 14.07601119],
[15.27704284, 16.47807449, 17.67910614],
],
+ ]
)
array233_3 = np.array(
[
From 0c3b13ce83e1d7d4147be1d050a4bb387b10310a Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 22:55:46 +0100
Subject: [PATCH 18/45] typo
---
src/modopt/base/np_adjust.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/src/modopt/base/np_adjust.py b/src/modopt/base/np_adjust.py
index 586a1ee0..10cb5c29 100644
--- a/src/modopt/base/np_adjust.py
+++ b/src/modopt/base/np_adjust.py
@@ -1,4 +1,3 @@
-
"""NUMPY ADJUSTMENT ROUTINES.
This module contains methods for adjusting the default output for certain
@@ -153,8 +152,7 @@ def pad2d(input_data, padding):
padding = np.array(padding)
elif not isinstance(padding, np.ndarray):
raise ValueError(
- "Padding must be an integer or a tuple (or list, np.ndarray) "
- + "of itegers",
+ "Padding must be an integer or a tuple (or list, np.ndarray) of integers",
)
if padding.size == 1:
From a79113b03ddaae8fa786ec1584270cf558b68cb8 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 23:04:13 +0100
Subject: [PATCH 19/45] fix tests setup
---
pyproject.toml | 3 +--
tests/test_helpers/utils.py | 4 ++--
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 37da48a3..dd08cbd0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,5 +47,4 @@ profile="black"
[tool.pytest.ini_options]
minversion = "6.0"
norecursedirs = ["tests/test_helpers"]
-testpaths=["modopt"]
-addopts = ["--verbose", "--cov=modopt", "--cov-report=term-missing", "--cov-report=xml", "--junitxml=pytest.xml"]
+addopts = ["--cov=modopt", "--cov-report=term-missing", "--cov-report=xml"]
diff --git a/tests/test_helpers/utils.py b/tests/test_helpers/utils.py
index 895b2371..049d257f 100644
--- a/tests/test_helpers/utils.py
+++ b/tests/test_helpers/utils.py
@@ -12,12 +12,12 @@ def failparam(*args, raises=None):
"""Return a pytest parameterization that should raise an error."""
if not issubclass(raises, Exception):
raise ValueError("raises should be an expected Exception.")
- return pytest.param(*args, marks=pytest.mark.raises(exception=raises))
+ return pytest.param(*args, marks=[pytest.mark.xfail(exception=raises)])
def skipparam(*args, cond=True, reason=""):
"""Return a pytest parameterization that should be skip if cond is valid."""
- return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason))
+ return pytest.param(*args, marks=[pytest.mark.skipif(cond, reason=reason)])
class Dummy:
From bcbe1f3ff469ca6cd12dc4e306b0ce2c97b513d0 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 23:13:04 +0100
Subject: [PATCH 20/45] slim down the CI.
---
.github/workflows/cd-build.yml | 9 ++---
.github/workflows/ci-build.yml | 72 ++++------------------------------
pyproject.toml | 3 +-
3 files changed, 13 insertions(+), 71 deletions(-)
diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml
index fca9feb1..ded7159d 100644
--- a/.github/workflows/cd-build.yml
+++ b/.github/workflows/cd-build.yml
@@ -27,19 +27,17 @@ jobs:
shell: bash -l {0}
run: |
python -m pip install --upgrade pip
- python -m pip install -r develop.txt
python -m pip install twine
- python -m pip install .
+ python -m pip install .[doc,test]
- name: Run Tests
shell: bash -l {0}
run: |
- python setup.py test
+ pytest
- name: Check distribution
shell: bash -l {0}
run: |
- python setup.py sdist
twine check dist/*
- name: Upload coverage to Codecov
@@ -69,8 +67,7 @@ jobs:
run: |
conda install -c conda-forge pandoc
python -m pip install --upgrade pip
- python -m pip install -r docs/requirements.txt
- python -m pip install .
+ python -m pip install .[doc]
- name: Build API documentation
shell: bash -l {0}
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index 88129d45..9d4226cb 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -16,42 +16,27 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python-version: ["3.10"]
+ python-version: ["3.8", "3.9", "3.10"]
steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
with:
- auto-update-conda: true
python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Check Conda
- shell: bash -l {0}
- run: |
- conda info
- conda list
- python --version
+ cache: pip
- name: Install Dependencies
shell: bash -l {0}
run: |
python --version
python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install -r docs/requirements.txt
- python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
+ python -m pip install .[test]
+ python -m pip install astropy scikit-image scikit-learn matplotlib
python -m pip install tensorflow>=2.4.1 torch
- python -m pip install twine
- python -m pip install .
- name: Run Tests
shell: bash -l {0}
run: |
- export PATH=/usr/share/miniconda/bin:$PATH
pytest -n 2
- name: Save Test Results
@@ -59,18 +44,12 @@ jobs:
uses: actions/upload-artifact@v2
with:
name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }}
- path: pytest.xml
-
- - name: Check Distribution
- shell: bash -l {0}
- run: |
- python setup.py sdist
- twine check dist/*
+ path: .coverage.xml
- name: Check API Documentation build
shell: bash -l {0}
run: |
- conda install -c conda-forge pandoc
+ pip install .[doc]
sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
sphinx-build -b doctest -E docs/source docs/_build
@@ -81,38 +60,3 @@ jobs:
file: coverage.xml
flags: unittests
- test-basic:
- name: Basic Test Suite
- runs-on: ${{ matrix.os }}
-
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, macos-latest]
- python-version: ["3.7", "3.8", "3.9"]
-
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
- with:
- auto-update-conda: true
- python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Install Dependencies
- shell: bash -l {0}
- run: |
- python --version
- python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
- python -m pip install .
-
- - name: Run Tests
- shell: bash -l {0}
- run: |
- export PATH=/usr/share/miniconda/bin:$PATH
- pytest -n 2
diff --git a/pyproject.toml b/pyproject.toml
index dd08cbd0..9c8f4bee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,8 @@ doc=["myst-parser==0.16.1",
"sphinx-gallery==0.11.1",
"sphinxawesome-theme==3.2.1",
"sphinxcontrib-bibtex"]
-dev=["black", "pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar", "ruff"]
+dev=["black", "ruff"]
+test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar"]
[build-system]
requires=["setuptools", "setuptools-scm[toml]", "wheel"]
From 5fb4cb22850de137022cc61d2268f79a264286a4 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Sun, 18 Feb 2024 16:15:46 +0100
Subject: [PATCH 21/45] black 24
---
src/modopt/__init__.py | 1 -
src/modopt/base/__init__.py | 1 -
src/modopt/base/backend.py | 1 -
src/modopt/base/observable.py | 1 -
src/modopt/base/transform.py | 1 -
src/modopt/base/types.py | 1 -
src/modopt/interface/__init__.py | 1 -
src/modopt/interface/errors.py | 1 -
src/modopt/interface/log.py | 1 -
src/modopt/math/__init__.py | 1 -
src/modopt/math/convolve.py | 1 -
src/modopt/math/matrix.py | 1 -
src/modopt/math/metrics.py | 1 -
src/modopt/math/stats.py | 1 -
src/modopt/opt/__init__.py | 1 -
src/modopt/opt/gradient.py | 1 -
src/modopt/opt/proximity.py | 1 -
src/modopt/opt/reweight.py | 1 -
src/modopt/plot/__init__.py | 1 -
src/modopt/plot/cost_plot.py | 1 -
src/modopt/signal/__init__.py | 1 -
src/modopt/signal/filter.py | 1 -
src/modopt/signal/noise.py | 2 --
src/modopt/signal/positivity.py | 1 -
src/modopt/signal/svd.py | 1 -
src/modopt/signal/wavelet.py | 1 -
tests/test_algorithms.py | 1 -
27 files changed, 28 deletions(-)
diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py
index 958f3ace..354c31d0 100644
--- a/src/modopt/__init__.py
+++ b/src/modopt/__init__.py
@@ -1,4 +1,3 @@
-
"""MODOPT PACKAGE.
ModOpt is a series of Modular Optimisation tools for solving inverse problems.
diff --git a/src/modopt/base/__init__.py b/src/modopt/base/__init__.py
index e7df6c37..c4c681d7 100644
--- a/src/modopt/base/__init__.py
+++ b/src/modopt/base/__init__.py
@@ -1,4 +1,3 @@
-
"""BASE ROUTINES.
This module contains submodules for basic operations such as type
diff --git a/src/modopt/base/backend.py b/src/modopt/base/backend.py
index b4987942..485f649a 100644
--- a/src/modopt/base/backend.py
+++ b/src/modopt/base/backend.py
@@ -1,4 +1,3 @@
-
"""BACKEND MODULE.
This module contains methods for GPU Compatiblity.
diff --git a/src/modopt/base/observable.py b/src/modopt/base/observable.py
index 69c6b238..15996dfa 100644
--- a/src/modopt/base/observable.py
+++ b/src/modopt/base/observable.py
@@ -1,4 +1,3 @@
-
"""Observable.
This module contains observable classes
diff --git a/src/modopt/base/transform.py b/src/modopt/base/transform.py
index 1dc9039a..25ed102a 100644
--- a/src/modopt/base/transform.py
+++ b/src/modopt/base/transform.py
@@ -1,4 +1,3 @@
-
"""DATA TRANSFORM ROUTINES.
This module contains methods for transforming data.
diff --git a/src/modopt/base/types.py b/src/modopt/base/types.py
index 5ed24ec3..9e9a15b9 100644
--- a/src/modopt/base/types.py
+++ b/src/modopt/base/types.py
@@ -1,4 +1,3 @@
-
"""TYPE HANDLING ROUTINES.
This module contains methods for handing object types.
diff --git a/src/modopt/interface/__init__.py b/src/modopt/interface/__init__.py
index 529816ee..a54f4bf5 100644
--- a/src/modopt/interface/__init__.py
+++ b/src/modopt/interface/__init__.py
@@ -1,4 +1,3 @@
-
"""INTERFACE ROUTINES.
This module contains submodules for error handling, logging and IO interaction.
diff --git a/src/modopt/interface/errors.py b/src/modopt/interface/errors.py
index 5c84ad0e..93e9ed1b 100644
--- a/src/modopt/interface/errors.py
+++ b/src/modopt/interface/errors.py
@@ -1,4 +1,3 @@
-
"""ERROR HANDLING ROUTINES.
This module contains methods for handing warnings and errors.
diff --git a/src/modopt/interface/log.py b/src/modopt/interface/log.py
index d3e0d8e9..50c316b7 100644
--- a/src/modopt/interface/log.py
+++ b/src/modopt/interface/log.py
@@ -1,4 +1,3 @@
-
"""LOGGING ROUTINES.
This module contains methods for handing logging.
diff --git a/src/modopt/math/__init__.py b/src/modopt/math/__init__.py
index 0423a333..d5ffc67a 100644
--- a/src/modopt/math/__init__.py
+++ b/src/modopt/math/__init__.py
@@ -1,4 +1,3 @@
-
"""MATHEMATICS ROUTINES.
This module contains submodules for mathematical applications.
diff --git a/src/modopt/math/convolve.py b/src/modopt/math/convolve.py
index ac1cf84c..21dc8b4e 100644
--- a/src/modopt/math/convolve.py
+++ b/src/modopt/math/convolve.py
@@ -1,4 +1,3 @@
-
"""CONVOLUTION ROUTINES.
This module contains methods for convolution.
diff --git a/src/modopt/math/matrix.py b/src/modopt/math/matrix.py
index a2419a6c..ef59f785 100644
--- a/src/modopt/math/matrix.py
+++ b/src/modopt/math/matrix.py
@@ -1,4 +1,3 @@
-
"""MATRIX ROUTINES.
This module contains methods for matrix operations.
diff --git a/src/modopt/math/metrics.py b/src/modopt/math/metrics.py
index 8f797f02..befd4fa4 100644
--- a/src/modopt/math/metrics.py
+++ b/src/modopt/math/metrics.py
@@ -1,4 +1,3 @@
-
"""METRICS.
This module contains classes of different metric functions for optimization.
diff --git a/src/modopt/math/stats.py b/src/modopt/math/stats.py
index b3ee0d8b..022e5f3c 100644
--- a/src/modopt/math/stats.py
+++ b/src/modopt/math/stats.py
@@ -1,4 +1,3 @@
-
"""STATISTICS ROUTINES.
This module contains methods for basic statistics.
diff --git a/src/modopt/opt/__init__.py b/src/modopt/opt/__init__.py
index 86564f90..62d1f388 100644
--- a/src/modopt/opt/__init__.py
+++ b/src/modopt/opt/__init__.py
@@ -1,4 +1,3 @@
-
"""OPTIMISATION PROBLEM MODULES.
This module contains submodules for solving optimisation problems.
diff --git a/src/modopt/opt/gradient.py b/src/modopt/opt/gradient.py
index bd214f21..3c5f0031 100644
--- a/src/modopt/opt/gradient.py
+++ b/src/modopt/opt/gradient.py
@@ -1,4 +1,3 @@
-
"""GRADIENT CLASSES.
This module contains classses for defining algorithm gradients.
diff --git a/src/modopt/opt/proximity.py b/src/modopt/opt/proximity.py
index 91a99f2a..10e69a98 100644
--- a/src/modopt/opt/proximity.py
+++ b/src/modopt/opt/proximity.py
@@ -1,4 +1,3 @@
-
"""PROXIMITY OPERATORS.
This module contains classes of proximity operators for optimisation.
diff --git a/src/modopt/opt/reweight.py b/src/modopt/opt/reweight.py
index 8d120101..b37fc6fb 100644
--- a/src/modopt/opt/reweight.py
+++ b/src/modopt/opt/reweight.py
@@ -1,4 +1,3 @@
-
"""REWEIGHTING CLASSES.
This module contains classes for reweighting optimisation implementations.
diff --git a/src/modopt/plot/__init__.py b/src/modopt/plot/__init__.py
index f6b39978..f31ed596 100644
--- a/src/modopt/plot/__init__.py
+++ b/src/modopt/plot/__init__.py
@@ -1,4 +1,3 @@
-
"""PLOTTING ROUTINES.
This module contains submodules for plotting applications.
diff --git a/src/modopt/plot/cost_plot.py b/src/modopt/plot/cost_plot.py
index 2274f35d..7fb7e39b 100644
--- a/src/modopt/plot/cost_plot.py
+++ b/src/modopt/plot/cost_plot.py
@@ -1,4 +1,3 @@
-
"""PLOTTING ROUTINES.
This module contains methods for making plots.
diff --git a/src/modopt/signal/__init__.py b/src/modopt/signal/__init__.py
index 2aee1987..6bf0912b 100644
--- a/src/modopt/signal/__init__.py
+++ b/src/modopt/signal/__init__.py
@@ -1,4 +1,3 @@
-
"""SIGNAL PROCESSING ROUTINES.
This module contains submodules for signal processing.
diff --git a/src/modopt/signal/filter.py b/src/modopt/signal/filter.py
index 0e50d28f..33c3c105 100644
--- a/src/modopt/signal/filter.py
+++ b/src/modopt/signal/filter.py
@@ -1,4 +1,3 @@
-
"""FILTER ROUTINES.
This module contains methods for distance measurements in cosmology.
diff --git a/src/modopt/signal/noise.py b/src/modopt/signal/noise.py
index b43a0b61..2594fc62 100644
--- a/src/modopt/signal/noise.py
+++ b/src/modopt/signal/noise.py
@@ -1,4 +1,3 @@
-
"""NOISE ROUTINES.
This module contains methods for adding and removing noise from data.
@@ -7,7 +6,6 @@
"""
-
import numpy as np
from modopt.base.backend import get_array_module
diff --git a/src/modopt/signal/positivity.py b/src/modopt/signal/positivity.py
index f3f312d3..8d7aa46c 100644
--- a/src/modopt/signal/positivity.py
+++ b/src/modopt/signal/positivity.py
@@ -1,4 +1,3 @@
-
"""POSITIVITY.
This module contains a function that retains only positive coefficients in
diff --git a/src/modopt/signal/svd.py b/src/modopt/signal/svd.py
index dd080306..cf147503 100644
--- a/src/modopt/signal/svd.py
+++ b/src/modopt/signal/svd.py
@@ -1,4 +1,3 @@
-
"""SVD ROUTINES.
This module contains methods for thresholding singular values.
diff --git a/src/modopt/signal/wavelet.py b/src/modopt/signal/wavelet.py
index d624db3a..b55b78d9 100644
--- a/src/modopt/signal/wavelet.py
+++ b/src/modopt/signal/wavelet.py
@@ -1,4 +1,3 @@
-
"""WAVELET MODULE.
This module contains methods for performing wavelet transformations using
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index c1e676a5..fe5f92a6 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -1,4 +1,3 @@
-
"""UNIT TESTS FOR Algorithms.
This module contains unit tests for the modopt.opt module.
From 44e7cf4a6c52ddf41fbfa2ac2aadcfc1d254d1ff Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 10:39:12 +0100
Subject: [PATCH 22/45] update ruff config.
---
pyproject.toml | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 9c8f4bee..835dae3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,10 +36,15 @@ exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
[tool.black]
-[lint]
+
+[tool.ruff]
+exclude = ["examples", "docs"]
+[tool.ruff.lint]
select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
-[lint.pydocstyle]
+ignore = ["F401"] # we like the try: import ... expect: ...
+
+[tool.ruff.lint.pydocstyle]
convention="numpy"
[tool.isort]
From 67342e79d4a71917fa8480b8919fd61eb66094bb Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 10:39:18 +0100
Subject: [PATCH 23/45] fix: F403
---
src/modopt/__init__.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py
index 354c31d0..e8ae1c3c 100644
--- a/src/modopt/__init__.py
+++ b/src/modopt/__init__.py
@@ -8,7 +8,9 @@
from importlib_metadata import version
-from modopt.base import *
+from modopt.base import np_adjust, transform, types, wrappers, observable
+
+__all__ = ["np_adjust", "transform", "types", "wrappers", "observable"]
try:
_version = version("modopt")
From 14542cc81e942d077a937afe101ba638cd8f46b2 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 10:41:54 +0100
Subject: [PATCH 24/45] fix: D1** errors.
---
src/modopt/opt/linear/wavelet.py | 1 +
tests/test_helpers/utils.py | 3 +++
tests/test_opt.py | 5 +++++
3 files changed, 9 insertions(+)
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index 8dc44fd3..9fb64b33 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -110,6 +110,7 @@ def __init__(
@property
def coeffs_shape(self):
+ """Get the coeffs shapes."""
return self.operator.coeffs_shape
diff --git a/tests/test_helpers/utils.py b/tests/test_helpers/utils.py
index 049d257f..41f948a6 100644
--- a/tests/test_helpers/utils.py
+++ b/tests/test_helpers/utils.py
@@ -1,5 +1,6 @@
"""
Some helper functions for the test parametrization.
+
They should be used inside ``@pytest.mark.parametrize`` call.
:Author: Pierre-Antoine Comby
@@ -21,4 +22,6 @@ def skipparam(*args, cond=True, reason=""):
class Dummy:
+ """Dummy Class."""
+
pass
diff --git a/tests/test_opt.py b/tests/test_opt.py
index e31d3a49..0a73e835 100644
--- a/tests/test_opt.py
+++ b/tests/test_opt.py
@@ -39,18 +39,22 @@
# Basic functions to be used as operators or as dummy functions
def func_identity(x_val):
+ """Return x."""
return x_val
def func_double(x_val):
+ """Double x."""
return x_val * 2
def func_sq(x_val):
+ """Square x."""
return x_val**2
def func_cube(x_val):
+ """Cube x."""
return x_val**3
@@ -209,6 +213,7 @@ def case_linear_wavelet_convolve(self):
]
)
def case_linear_wavelet_transform(self, compute_backend):
+ """Case linear wavelet operator."""
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
From 99208cd2362247a0a92ffb6ddcb823b80b8dbbea Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 10:51:44 +0100
Subject: [PATCH 25/45] fix: E501
---
src/modopt/math/stats.py | 4 ++--
src/modopt/opt/algorithms/admm.py | 10 ++++++----
src/modopt/opt/linear/wavelet.py | 8 +++++---
3 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/src/modopt/math/stats.py b/src/modopt/math/stats.py
index 022e5f3c..8583a8c3 100644
--- a/src/modopt/math/stats.py
+++ b/src/modopt/math/stats.py
@@ -29,8 +29,8 @@ def gaussian_kernel(data_shape, sigma, norm="max"):
Desiered shape of the kernel
sigma : float
Standard deviation of the kernel
- norm : {'max', 'sum'}, optional
- Normalisation of the kerenl (options are ``'max'`` or ``'sum'``, default is ``'max'``)
+ norm : {'max', 'sum'}, optional, default='max'
+ Normalisation of the kernel
Returns
-------
diff --git a/src/modopt/opt/algorithms/admm.py b/src/modopt/opt/algorithms/admm.py
index 4fd8074e..b2f45171 100644
--- a/src/modopt/opt/algorithms/admm.py
+++ b/src/modopt/opt/algorithms/admm.py
@@ -68,7 +68,8 @@ def _calc_cost(self, u, v, **kwargs):
class ADMM(SetUp):
r"""Fast ADMM Optimisation Algorihm.
- This class implement the ADMM algorithm described in :cite:`Goldstein2014` (Algorithm 1).
+ This class implement the ADMM algorithm described in :cite:`Goldstein2014`
+ (Algorithm 1).
Parameters
----------
@@ -86,7 +87,7 @@ class ADMM(SetUp):
Constraint vector
optimizers: tuple
2-tuple of callable, that are the optimizers for the u and v.
- Each callable should access an init and obs argument and returns an estimate for:
+ Each callable should access init and obs argument and returns an estimate for:
.. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2
.. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2
cost_funcs: tuple
@@ -243,7 +244,7 @@ class FastADMM(ADMM):
Constraint vector
optimizers: tuple
2-tuple of callable, that are the optimizers for the u and v.
- Each callable should access an init and obs argument and returns an estimate for:
+ Each callable should access init and obs argument and returns an estimate for:
.. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2
.. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2
cost_funcs: tuple
@@ -257,7 +258,8 @@ class FastADMM(ADMM):
Notes
-----
- This is an accelerated version of the ADMM algorithm. The convergence hypothesis are stronger than for the ADMM algorithm.
+ This is an accelerated version of the ADMM algorithm. The convergence hypothesis are
+ stronger than for the ADMM algorithm.
See Also
--------
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index 9fb64b33..ae92efa7 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -65,7 +65,7 @@ class WaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
- This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU using Pytorch).
+ This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU).
Parameters
----------
@@ -79,7 +79,8 @@ class WaveletTransform(LinearParent):
mode: str, default "zero"
Boundary Condition mode
compute_backend: str, "numpy" or "cupy", default "numpy"
- Backend library to use. "cupy" also requires a working installation of PyTorch and pytorch wavelets.
+ Backend library to use. "cupy" also requires a working installation of PyTorch
+ and PyTorch wavelets (ptwt).
**kwargs: extra kwargs for Pywavelet or Pytorch Wavelet
"""
@@ -165,7 +166,8 @@ def __init__(
):
if not pywt_available:
raise ImportError(
- "PyWavelet and/or joblib are not available. Please install it to use WaveletTransform."
+ "PyWavelet and/or joblib are not available. "
+ "Please install it to use WaveletTransform."
)
if wavelet_name not in pywt.wavelist(kind="all"):
raise ValueError(
From fdd6d1ee283bd2fd2c7c6eaa0793615225bd1524 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 11:29:24 +0100
Subject: [PATCH 26/45] fix: NPY002
random number generator stuff.
This adds the possibility to use an rng properly in ModOpt.
---
src/modopt/math/matrix.py | 7 ++++++-
src/modopt/signal/noise.py | 12 +++++++++---
src/modopt/signal/validation.py | 9 +++++++--
tests/test_algorithms.py | 5 ++++-
tests/test_math.py | 16 +++++-----------
tests/test_signal.py | 15 ++++++++-------
6 files changed, 39 insertions(+), 25 deletions(-)
diff --git a/src/modopt/math/matrix.py b/src/modopt/math/matrix.py
index ef59f785..e52cbfc5 100644
--- a/src/modopt/math/matrix.py
+++ b/src/modopt/math/matrix.py
@@ -274,6 +274,8 @@ class PowerMethod:
initialisation (default is ``True``)
verbose : bool, optional
Optional verbosity (default is ``False``)
+ rng: int, xp.random.Generator or None (default is ``None``)
+ Random number generator or seed.
Examples
--------
@@ -300,6 +302,7 @@ def __init__(
auto_run=True,
compute_backend="numpy",
verbose=False,
+ rng=None,
):
self._operator = operator
@@ -308,6 +311,7 @@ def __init__(
self._verbose = verbose
xp, compute_backend = get_backend(compute_backend)
self.xp = xp
+ self.rng = None
self.compute_backend = compute_backend
if auto_run:
self.get_spec_rad()
@@ -324,7 +328,8 @@ def _set_initial_x(self):
Random values of the same shape as the input data
"""
- return self.xp.random.random(self._data_shape).astype(self._data_type)
+ rng = self.xp.random.default_rng(self.rng)
+ return rng.random(self._data_shape).astype(self._data_type)
def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
"""Get spectral radius.
diff --git a/src/modopt/signal/noise.py b/src/modopt/signal/noise.py
index 2594fc62..28307f52 100644
--- a/src/modopt/signal/noise.py
+++ b/src/modopt/signal/noise.py
@@ -11,7 +11,7 @@
from modopt.base.backend import get_array_module
-def add_noise(input_data, sigma=1.0, noise_type="gauss"):
+def add_noise(input_data, sigma=1.0, noise_type="gauss", rng=None):
"""Add noise to data.
This method adds Gaussian or Poisson noise to the input data.
@@ -25,6 +25,9 @@ def add_noise(input_data, sigma=1.0, noise_type="gauss"):
default is ``1.0``)
noise_type : {'gauss', 'poisson'}
Type of noise to be added (default is ``'gauss'``)
+ rng: np.random.Generator or int
+ A Random number generator or a seed to initialize one.
+
Returns
-------
@@ -64,6 +67,9 @@ def add_noise(input_data, sigma=1.0, noise_type="gauss"):
array([ 3.24869073, -1.22351283, -1.0563435 , -2.14593724, 1.73081526])
"""
+ if not isinstance(rng, np.random.Generator):
+ rng = np.random.default_rng(rng)
+
input_data = np.array(input_data)
if noise_type not in {"gauss", "poisson"}:
@@ -78,10 +84,10 @@ def add_noise(input_data, sigma=1.0, noise_type="gauss"):
)
if noise_type == "gauss":
- random = np.random.randn(*input_data.shape)
+ random = rng.standard_normal(input_data.shape)
elif noise_type == "poisson":
- random = np.random.poisson(np.abs(input_data))
+ random = rng.poisson(np.abs(input_data))
if isinstance(sigma, (int, float)):
return input_data + sigma * random
diff --git a/src/modopt/signal/validation.py b/src/modopt/signal/validation.py
index 68c1e726..66485a54 100644
--- a/src/modopt/signal/validation.py
+++ b/src/modopt/signal/validation.py
@@ -16,6 +16,7 @@ def transpose_test(
x_args=None,
y_shape=None,
y_args=None,
+ rng=None,
):
"""Transpose test.
@@ -36,6 +37,8 @@ def transpose_test(
Shape of transpose operator input data (default is ``None``)
y_args : tuple, optional
Arguments to be passed to transpose operator (default is ``None``)
+ rng: np.random.Generator or int or None (default is ``None``)
+ Initialized random number generator or seed.
Raises
------
@@ -62,9 +65,11 @@ def transpose_test(
if isinstance(y_args, type(None)):
y_args = x_args
+ if not isinstance(rng, np.random.Generator):
+ rng = np.random.default_rng(rng)
# Generate random arrays.
- x_val = np.random.ranf(x_shape)
- y_val = np.random.ranf(y_shape)
+ x_val = rng.random(x_shape)
+ y_val = rng.random(y_shape)
# Calculate
mx_y = np.sum(np.multiply(operator(x_val, x_args), y_val))
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index fe5f92a6..63847764 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -24,6 +24,9 @@
SKLEARN_AVAILABLE = False
+rng = np.random.default_rng()
+
+
@fixture
def idty():
"""Identity function."""
@@ -84,7 +87,7 @@ class AlgoCases:
"""
data1 = np.arange(9).reshape(3, 3).astype(float)
- data2 = data1 + np.random.randn(*data1.shape) * 1e-6
+ data2 = data1 + rng.standard_normal(data1.shape) * 1e-6
max_iter = 20
@parametrize(
diff --git a/tests/test_math.py b/tests/test_math.py
index e7536b03..37b71b1e 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -29,6 +29,8 @@
else:
SKIMAGE_AVAILABLE = True
+rng = np.random.default_rng(1)
+
class TestConvolve:
"""Test convolve functions."""
@@ -136,18 +138,15 @@ class TestMatrix:
),
)
- @pytest.fixture
+ @pytest.fixture(scope="module")
def pm_instance(self, request):
"""Power Method instance."""
- np.random.seed(1)
pm = matrix.PowerMethod(
lambda x_val: x_val.dot(x_val.T),
self.array33.shape,
- auto_run=request.param,
verbose=True,
+ rng=np.random.default_rng(0),
)
- if not request.param:
- pm.get_spec_rad(max_iter=1)
return pm
@pytest.mark.parametrize(
@@ -195,12 +194,7 @@ def test_rotate(self):
npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2)
- @pytest.mark.parametrize(
- ("pm_instance", "value"),
- [(True, 1.0), (False, 0.8675467477372257)],
- indirect=["pm_instance"],
- )
- def test_power_method(self, pm_instance, value):
+ def test_power_method(self, pm_instance, value=1):
"""Test power method."""
npt.assert_almost_equal(pm_instance.spec_rad, value)
npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value)
diff --git a/tests/test_signal.py b/tests/test_signal.py
index cdd95277..6dbb0bba 100644
--- a/tests/test_signal.py
+++ b/tests/test_signal.py
@@ -45,13 +45,13 @@ class TestNoise:
data1 = np.arange(9).reshape(3, 3).astype(float)
data2 = np.array(
- [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]],
+ [[0, 3.0, 4.0], [6.0, 9.0, 8.0], [14.0, 14.0, 17.0]],
)
data3 = np.array(
[
- [1.62434536, 0.38824359, 1.47182825],
- [1.92703138, 4.86540763, 2.6984613],
- [7.74481176, 6.2387931, 8.3190391],
+ [0.3455842, 1.8216181, 2.3304371],
+ [1.6968428, 4.9053559, 5.4463746],
+ [5.4630468, 7.5811181, 8.3645724],
]
)
data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
@@ -70,9 +70,10 @@ class TestNoise:
)
def test_add_noise(self, data, noise_type, sigma, data_noise):
"""Test add_noise."""
- np.random.seed(1)
+ rng = np.random.default_rng(1)
npt.assert_almost_equal(
- noise.add_noise(data, sigma=sigma, noise_type=noise_type), data_noise
+ noise.add_noise(data, sigma=sigma, noise_type=noise_type, rng=rng),
+ data_noise,
)
@pytest.mark.parametrize(
@@ -241,13 +242,13 @@ class TestValidation:
def test_transpose_test(self):
"""Test transpose_test."""
- np.random.seed(2)
npt.assert_equal(
validation.transpose_test(
lambda x_val, y_val: x_val.dot(y_val),
lambda x_val, y_val: x_val.dot(y_val.T),
self.array33.shape,
x_args=self.array33,
+ rng=2,
),
None,
)
From 22c016fcc81b18af0794dd07fb144ed8409b2566 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 11:31:50 +0100
Subject: [PATCH 27/45] fix: RUF012
Mutable class arguments. We don't want type annotations.
---
src/modopt/opt/linear/wavelet.py | 2 +-
tests/test_base.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index ae92efa7..ff434287 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -281,7 +281,7 @@ def _adj_op(self, coeffs):
class TorchWaveletTransform:
"""Wavelet transform using pytorch."""
- wavedec3_keys = ["aad", "ada", "add", "daa", "dad", "dda", "ddd"]
+ wavedec3_keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd")
def __init__(
self,
diff --git a/tests/test_base.py b/tests/test_base.py
index 298253d6..62e09095 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -149,7 +149,7 @@ def test_matrix2cube(self):
class TestType:
"""Test for type module."""
- data_list = list(range(5))
+ data_list = list(range(5)) # noqa: RUF012
data_int = np.arange(5)
data_flt = np.arange(5).astype(float)
From 59f58dddf5132829a115d7b89945321ea864ee0b Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 12:17:56 +0100
Subject: [PATCH 28/45] fix: B026
---
src/modopt/opt/cost.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/modopt/opt/cost.py b/src/modopt/opt/cost.py
index 4bead130..5c5701b1 100644
--- a/src/modopt/opt/cost.py
+++ b/src/modopt/opt/cost.py
@@ -187,7 +187,7 @@ def get_cost(self, *args, **kwargs):
print(" - ITERATION:", self._iteration)
# Calculate the current cost
- self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs)
+ self.cost = self._calc_cost(*args, verbose=self._verbose, **kwargs)
self._cost_list.append(self.cost)
if self._verbose:
From b9827ee38bbac3ff7b6d1d899491fc556d52fca6 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 12:19:45 +0100
Subject: [PATCH 29/45] fix: B028
---
src/modopt/__init__.py | 1 +
src/modopt/interface/errors.py | 2 +-
src/modopt/opt/linear/wavelet.py | 4 +++-
3 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py
index e8ae1c3c..e93e45ff 100644
--- a/src/modopt/__init__.py
+++ b/src/modopt/__init__.py
@@ -19,6 +19,7 @@
warn(
"Could not extract package metadata. Make sure the package is "
+ "correctly installed.",
+ stacklevel=1,
)
__version__ = _version
diff --git a/src/modopt/interface/errors.py b/src/modopt/interface/errors.py
index 93e9ed1b..84031e3c 100644
--- a/src/modopt/interface/errors.py
+++ b/src/modopt/interface/errors.py
@@ -41,7 +41,7 @@ def warn(warn_string, log=None):
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- warnings.warn(warn_string)
+ warnings.warn(warn_string, stacklevel=2)
def catch_error(exception, log=None):
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index ff434287..fa450ba2 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -201,7 +201,9 @@ def __init__(
self.n_batch = n_batch
if self.n_batch == 1 and self.n_jobs != 1:
- warnings.warn("Making n_jobs = 1 for WaveletTransform as n_batchs = 1")
+ warnings.warn(
+ "Making n_jobs = 1 for WaveletTransform as n_batchs = 1", stacklevel=1
+ )
self.n_jobs = 1
self.backend = backend
n_proc = self.n_jobs
From 0811b7a9582a289ba4f9b24ba12a8257e2528bbf Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 12:20:37 +0100
Subject: [PATCH 30/45] fix: ruff format
---
src/modopt/base/observable.py | 2 --
src/modopt/math/matrix.py | 2 --
src/modopt/opt/algorithms/base.py | 2 --
src/modopt/opt/algorithms/forward_backward.py | 4 ----
src/modopt/opt/algorithms/primal_dual.py | 1 -
src/modopt/opt/cost.py | 2 --
src/modopt/opt/gradient.py | 8 --------
src/modopt/opt/linear/base.py | 6 ------
src/modopt/opt/linear/wavelet.py | 2 --
src/modopt/opt/proximity.py | 13 -------------
src/modopt/opt/reweight.py | 1 -
11 files changed, 43 deletions(-)
diff --git a/src/modopt/base/observable.py b/src/modopt/base/observable.py
index 15996dfa..bf8371c3 100644
--- a/src/modopt/base/observable.py
+++ b/src/modopt/base/observable.py
@@ -31,7 +31,6 @@ class Observable:
"""
def __init__(self, signals):
-
# Define class parameters
self._allowed_signals = []
self._observers = {}
@@ -213,7 +212,6 @@ def __init__(
wind=6,
eps=1.0e-3,
):
-
self.name = name
self.metric = metric
self.mapping = mapping
diff --git a/src/modopt/math/matrix.py b/src/modopt/math/matrix.py
index e52cbfc5..b200f15d 100644
--- a/src/modopt/math/matrix.py
+++ b/src/modopt/math/matrix.py
@@ -63,7 +63,6 @@ def gram_schmidt(matrix, return_opt="orthonormal"):
e_vec = []
for vector in matrix:
-
if u_vec:
u_now = vector - sum(project(u_i, vector) for u_i in u_vec)
else:
@@ -304,7 +303,6 @@ def __init__(
verbose=False,
rng=None,
):
-
self._operator = operator
self._data_shape = data_shape
self._data_type = data_type
diff --git a/src/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py
index dbb73be0..f7391063 100644
--- a/src/modopt/opt/algorithms/base.py
+++ b/src/modopt/opt/algorithms/base.py
@@ -110,7 +110,6 @@ def metrics(self):
@metrics.setter
def metrics(self, metrics):
-
if isinstance(metrics, type(None)):
self._metrics = {}
elif isinstance(metrics, dict):
@@ -271,7 +270,6 @@ def _iterations(self, max_iter, progbar=None):
# We do not call metrics if metrics is empty or metric call
# period is None
if self.metrics and self.metric_call_period is not None:
-
metric_conditions = (
self.idx % self.metric_call_period == 0
or self.idx == (max_iter - 1)
diff --git a/src/modopt/opt/algorithms/forward_backward.py b/src/modopt/opt/algorithms/forward_backward.py
index 4c1cb35c..31927eb0 100644
--- a/src/modopt/opt/algorithms/forward_backward.py
+++ b/src/modopt/opt/algorithms/forward_backward.py
@@ -72,7 +72,6 @@ def __init__(
r_lazy=4,
**kwargs,
):
-
if isinstance(a_cd, type(None)):
self.mode = "regular"
self.p_lazy = p_lazy
@@ -355,7 +354,6 @@ def __init__(
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -588,7 +586,6 @@ def __init__(
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -875,7 +872,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
diff --git a/src/modopt/opt/algorithms/primal_dual.py b/src/modopt/opt/algorithms/primal_dual.py
index 24908993..fee49a25 100644
--- a/src/modopt/opt/algorithms/primal_dual.py
+++ b/src/modopt/opt/algorithms/primal_dual.py
@@ -95,7 +95,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
diff --git a/src/modopt/opt/cost.py b/src/modopt/opt/cost.py
index 5c5701b1..37771f16 100644
--- a/src/modopt/opt/cost.py
+++ b/src/modopt/opt/cost.py
@@ -81,7 +81,6 @@ def __init__(
verbose=True,
plot_output=None,
):
-
self.cost = initial_cost
self._cost_list = []
self._cost_interval = cost_interval
@@ -112,7 +111,6 @@ def _check_cost(self):
# Check if enough cost values have been collected
if len(self._test_list) == self._test_range:
-
# The mean of the first half of the test list
t1 = xp.mean(
xp.array(self._test_list[len(self._test_list) // 2 :]),
diff --git a/src/modopt/opt/gradient.py b/src/modopt/opt/gradient.py
index 3c5f0031..fe9b87d8 100644
--- a/src/modopt/opt/gradient.py
+++ b/src/modopt/opt/gradient.py
@@ -69,7 +69,6 @@ def __init__(
input_data_writeable=False,
verbose=True,
):
-
self.verbose = verbose
self._input_data_writeable = input_data_writeable
self._grad_data_type = data_type
@@ -98,7 +97,6 @@ def obs_data(self):
@obs_data.setter
def obs_data(self, input_data):
-
if self._grad_data_type in {float, np.floating}:
input_data = check_float(input_data)
check_npndarray(
@@ -126,7 +124,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -145,7 +142,6 @@ def trans_op(self):
@trans_op.setter
def trans_op(self, operator):
-
self._trans_op = check_callable(operator)
@property
@@ -155,7 +151,6 @@ def get_grad(self):
@get_grad.setter
def get_grad(self, method):
-
self._get_grad = check_callable(method)
@property
@@ -165,7 +160,6 @@ def grad(self):
@grad.setter
def grad(self, input_value):
-
if self._grad_data_type in {float, np.floating}:
input_value = check_float(input_value)
self._grad = input_value
@@ -177,7 +171,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
def trans_op_op(self, input_data):
@@ -241,7 +234,6 @@ class GradBasic(GradParent):
"""
def __init__(self, *args, **kwargs):
-
super().__init__(*args, **kwargs)
self.get_grad = self._get_grad_method
self.cost = self._cost_method
diff --git a/src/modopt/opt/linear/base.py b/src/modopt/opt/linear/base.py
index 9fa35187..af748a73 100644
--- a/src/modopt/opt/linear/base.py
+++ b/src/modopt/opt/linear/base.py
@@ -30,7 +30,6 @@ class LinearParent:
"""
def __init__(self, op, adj_op):
-
self.op = op
self.adj_op = adj_op
@@ -41,7 +40,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -51,7 +49,6 @@ def adj_op(self):
@adj_op.setter
def adj_op(self, operator):
-
self._adj_op = check_callable(operator)
@@ -67,7 +64,6 @@ class Identity(LinearParent):
"""
def __init__(self):
-
self.op = lambda input_data: input_data
self.adj_op = self.op
self.cost = lambda *args, **kwargs: 0
@@ -127,7 +123,6 @@ class LinearCombo(LinearParent):
"""
def __init__(self, operators, weights=None):
-
operators, weights = self._check_inputs(operators, weights)
self.operators = operators
self.weights = weights
@@ -199,7 +194,6 @@ def _check_inputs(self, operators, weights):
operators = self._check_type(operators)
for operator in operators:
-
if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index fa450ba2..e554150e 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -46,7 +46,6 @@ class WaveletConvolve(LinearParent):
"""
def __init__(self, filters, method="scipy"):
-
self._filters = check_float(filters)
self.op = lambda input_data: filter_convolve_stack(
input_data,
@@ -94,7 +93,6 @@ def __init__(
compute_backend="numpy",
**kwargs,
):
-
if compute_backend == "cupy" and ptwt_available:
self.operator = CupyWaveletTransform(
wavelet=wavelet_name, shape=shape, level=level, mode=mode
diff --git a/src/modopt/opt/proximity.py b/src/modopt/opt/proximity.py
index 10e69a98..204a168d 100644
--- a/src/modopt/opt/proximity.py
+++ b/src/modopt/opt/proximity.py
@@ -46,7 +46,6 @@ class ProximityParent:
"""
def __init__(self, op, cost):
-
self.op = op
self.cost = cost
@@ -57,7 +56,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -77,7 +75,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
@@ -97,7 +94,6 @@ class IdentityProx(ProximityParent):
"""
def __init__(self):
-
self.op = lambda x_val: x_val
self.cost = lambda x_val: 0
@@ -115,7 +111,6 @@ class Positivity(ProximityParent):
"""
def __init__(self):
-
self.op = lambda input_data: positive(input_data)
self.cost = self._cost_method
@@ -166,7 +161,6 @@ class SparseThreshold(ProximityParent):
"""
def __init__(self, linear, weights, thresh_type="soft"):
-
self._linear = linear
self.weights = weights
self._thresh_type = thresh_type
@@ -276,7 +270,6 @@ def __init__(
initial_rank=None,
operator=None,
):
-
self.thresh = threshold
self.thresh_type = thresh_type
self.lowr_type = lowr_type
@@ -468,7 +461,6 @@ class ProximityCombo(ProximityParent):
"""
def __init__(self, operators):
-
operators = self._check_operators(operators)
self.operators = operators
self.op = self._op_method
@@ -737,7 +729,6 @@ class Ridge(ProximityParent):
"""
def __init__(self, linear, weights, thresh_type="soft"):
-
self._linear = linear
self.weights = weights
self.op = self._op_method
@@ -824,7 +815,6 @@ class ElasticNet(ProximityParent):
"""
def __init__(self, linear, alpha, beta):
-
self._linear = linear
self.alpha = alpha
self.beta = beta
@@ -1080,12 +1070,10 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
midpoint = 0
while (first_idx <= last_idx) and not found and (cnt < alpha.shape[0]):
-
midpoint = (first_idx + last_idx) // 2
cnt += 1
if prev_midpoint == midpoint:
-
# Particular case
sum0 = self._compute_theta(
data_abs,
@@ -1287,7 +1275,6 @@ def _find_q(self, sorted_data):
and not cnt == self._k_value
and (first_idx <= last_idx < self._k_value)
):
-
q_val = (first_idx + last_idx) // 2
cnt += 1
l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val)
diff --git a/src/modopt/opt/reweight.py b/src/modopt/opt/reweight.py
index b37fc6fb..4a9bf44b 100644
--- a/src/modopt/opt/reweight.py
+++ b/src/modopt/opt/reweight.py
@@ -43,7 +43,6 @@ class cwbReweight:
"""
def __init__(self, weights, thresh_factor=1.0, verbose=False):
-
self.weights = check_float(weights)
self.original_weights = np.copy(self.weights)
self.thresh_factor = check_float(thresh_factor)
From 067e29f18a67ecb33c067538ee6245df75e93688 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 13:59:01 +0100
Subject: [PATCH 31/45] fix: NPY002
---
tests/test_opt.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/tests/test_opt.py b/tests/test_opt.py
index 0a73e835..039cb9fd 100644
--- a/tests/test_opt.py
+++ b/tests/test_opt.py
@@ -36,6 +36,8 @@
except ImportError:
PYWT_AVAILABLE = False
+rng = np.random.default_rng()
+
# Basic functions to be used as operators or as dummy functions
def func_identity(x_val):
@@ -461,7 +463,7 @@ def case_prox_grouplasso(self, use_weights):
else:
weights = np.tile(np.zeros((3, 3)), (4, 1, 1))
- random_data = 3 * np.random.random(weights[0].shape)
+ random_data = 3 * rng.random(weights[0].shape)
random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1))
if use_weights:
gl_result_data = 2 * random_data_tile - 3
From 0040d1a7def41be40089e377325447a1c459684f Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 14:23:37 +0100
Subject: [PATCH 32/45] proj: add pytest-xdist for parallel testing.
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 835dae3d..1e9a7f20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ doc=["myst-parser==0.16.1",
"sphinxawesome-theme==3.2.1",
"sphinxcontrib-bibtex"]
dev=["black", "ruff"]
-test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar"]
+test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-xdist", "pytest-sugar"]
[build-system]
requires=["setuptools", "setuptools-scm[toml]", "wheel"]
From 4189da6307237041875d89f0c9c82de47f5c9f9e Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 14:32:17 +0100
Subject: [PATCH 33/45] remove type annotations.
---
src/modopt/opt/linear/wavelet.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index e554150e..8012a072 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -285,10 +285,10 @@ class TorchWaveletTransform:
def __init__(
self,
- shape: tuple[int, ...],
- wavelet: str,
- level: int,
- mode: str,
+ shape,
+ wavelet,
+ level,
+ mode,
):
self.wavelet = wavelet
self.level = level
@@ -417,10 +417,10 @@ class CupyWaveletTransform(LinearParent):
def __init__(
self,
- shape: tuple[int, ...],
- wavelet: str,
- level: int,
- mode: str,
+ shape,
+ wavelet,
+ level,
+ mode,
):
self.wavelet = wavelet
self.level = level
From 81953519875e012a0e1e8752f7db5638f1c45696 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 14:33:34 +0100
Subject: [PATCH 34/45] proj: use importlib_metadata for python 3.8 compat.
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 1e9a7f20..6f9637d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"},
readme="README.md"
license={file="LICENCE.txt"}
-dependencies = ["numpy", "scipy", "tqdm"]
+dependencies = ["numpy", "scipy", "tqdm", "importlib_metadata"]
[project.optional-dependencies]
gpu=["torch", "ptwt"]
From 51ffef3b714351b46a515e578b14629fb2e8bb52 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 14:42:21 +0100
Subject: [PATCH 35/45] fix: update ssim values
---
tests/test_math.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tests/test_math.py b/tests/test_math.py
index 37b71b1e..5c466e5e 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -205,8 +205,8 @@ class TestMetrics:
data1 = np.arange(49).reshape(7, 7)
mask = np.ones(data1.shape)
- ssim_res = 0.8963363560519094
- ssim_mask_res = 0.805154442543846
+ ssim_res = 0.8958315888566867
+ ssim_mask_res = 0.8023827544418249
snr_res = 10.134554256920536
psnr_res = 14.860761791850397
mse_res = 0.03265305507330247
From 495680387323817a1e897582f428b87653e9172f Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 15:32:51 +0100
Subject: [PATCH 36/45] adapt doc to new format.
---
docs/source/conf.py | 9 +++------
examples/example_lasso_forward_backward.py | 1 +
pyproject.toml | 11 ++++++-----
3 files changed, 10 insertions(+), 11 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index cd39ee08..e9d88229 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -21,9 +21,6 @@
copyright = f"2020, {author}"
gh_user = "sfarrens"
-# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = "3.3"
-
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
@@ -38,7 +35,7 @@
"sphinx.ext.napoleon",
"sphinx.ext.todo",
"sphinx.ext.viewcode",
- "sphinxawesome_theme",
+ "sphinxawesome_theme.highlighting",
"sphinxcontrib.bibtex",
"myst_parser",
"nbsphinx",
@@ -103,7 +100,7 @@
}
html_collapsible_definitions = True
html_awesome_headerlinks = True
-html_logo = "modopt_logo.jpg"
+html_logo = "modopt_logo.png"
html_permalinks_icon = (
'