Skip to content

Commit

Permalink
Deprecation handling (#516)
Browse files Browse the repository at this point in the history
* add: add deprecation warning decorators and tests that assert deprecated functions are indeed removed
  • Loading branch information
jnsbck authored Nov 20, 2024
1 parent 32068e9 commit 8fe6e68
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 27 deletions.
1 change: 1 addition & 0 deletions jaxley/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from jaxley.modules import *
from jaxley.optimize import ParamTransform
from jaxley.stimulus import datapoint_to_step_currents, step_current
from jaxley.utils.misc_utils import deprecated, deprecated_kwargs
57 changes: 57 additions & 0 deletions jaxley/utils/misc_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from typing import List, Optional, Union

import jax.numpy as jnp
Expand Down Expand Up @@ -34,3 +35,59 @@ def is_str_all(arg, force: bool = True) -> bool:
assert arg == "all", "Only 'all' is allowed"
return arg == "all"
return False


class deprecated:
"""Decorator to mark a function as deprecated.
Can be used to mark functions that will be removed in future versions. This will
also be tested in the CI pipeline to ensure that deprecated functions are removed.
Warns with: "func_name is deprecated and will be removed in version version."
Args:
version: The version in which the function will be removed, i.e. "0.1.0".
amend_msg: An optional message to append to the deprecation warning.
"""

def __init__(self, version: str, amend_msg: str = ""):
self._version: str = version
self._amend_msg: str = amend_msg

def __call__(self, func):
def wrapper(*args, **kwargs):
msg = f"{func.__name__} is deprecated and will be removed in version {self._version}."
warnings.warn(msg + self._amend_msg)
return func(*args, **kwargs)

return wrapper


class deprecated_kwargs:
"""Decorator to mark a keyword arguemnt of a function as deprecated.
Can be used to mark kwargs that will be removed in future versions. This will
also be tested in the CI pipeline to ensure that deprecated kwargs are removed.
Warns with: "kwarg is deprecated and will be removed in version version."
Args:
version: The version in which the keyword argument will be removed, i.e. "0.1.0".
deprecated_kwargs: A list of keyword arguments that are deprecated.
amend_msg: An optional message to append to the deprecation warning.
"""

def __init__(self, version: str, kwargs: List = [], amend_msg: str = ""):
self._version: str = version
self._amend_msg: str = amend_msg
self._depcrecated_kwargs: List = kwargs

def __call__(self, func):
def wrapper(*args, **kwargs):
for deprecated_kwarg in self._depcrecated_kwargs:
if deprecated_kwarg in kwargs and kwargs[deprecated_kwarg] is not None:
msg = f"{deprecated_kwarg} is deprecated and will be removed in version {self._version}."
warnings.warn(msg + self._amend_msg)
return func(*args, **kwargs)

return wrapper
27 changes: 0 additions & 27 deletions tests/test_license.py

This file was deleted.

60 changes: 60 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import os
import re
from pathlib import Path
from typing import List

import numpy as np
import pytest


def list_files(directory):
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith(".py"):
yield os.path.join(root, file)


license_txt = """# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>"""


@pytest.mark.parametrize("dir", ["../jaxley", "."])
def test_license(dir):
for i, file in enumerate(list_files(dir)):
with open(file, "r") as f:
header = f.read(len(license_txt))
assert (
header == license_txt
), f"File {file} does not have the correct license header"


def test_rm_all_deprecated_functions():
from jaxley.__version__ import __version__ as package_version

package_version = np.array([int(s) for s in package_version.split(".")])

decorator_pattern = r"@deprecated(?:_signature)?"
version_pattern = r"[v]?(\d+\.\d+\.\d+)"

package_dir = Path(__file__).parent.parent / "jaxley"

violations = []
for py_file in package_dir.rglob("*.py"):
with open(py_file, "r") as f:
for line_num, line in enumerate(f, 1):
if re.search(decorator_pattern, line):
version_match = re.search(version_pattern, line)
if version_match:
depr_version_str = version_match.group(1)
depr_version = np.array(
[int(s) for s in depr_version_str.split(".")]
)
if not np.all(package_version <= depr_version):
violations.append(f"{py_file}:L{line_num}")

assert not violations, "\n".join(
["Found deprecated items that should have been removed:", *violations]
)

0 comments on commit 8fe6e68

Please sign in to comment.