diff --git a/pyproject.toml b/pyproject.toml index 5c02020..7c57f69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ pytest-testinfra = [ { version = ">=8.0", python = ">= 3.8" } ] dataclasses = { version = ">=0.8", python = "< 3.7" } -typing-extensions = { version = ">=3.0", markers="python_version < '3.8'" } +typing-extensions = { version = ">=3.0", markers="python_version < '3.10'" } cached-property = { version = "^1.5", markers="python_version < '3.8'" } filelock = "^3.4" deprecation = "^2.1" @@ -65,3 +65,12 @@ strict = true [[tool.mypy.overrides]] module = "testinfra,deprecation" ignore_missing_imports = true + +[tool.pytest.ini_options] +xfail_strict = true +addopts = "--strict-markers" +markers = [ + 'secretleapmark', + 'othersecretmark', + 'secretpodmark', +] diff --git a/pytest_container/container.py b/pytest_container/container.py index 4419a11..ed96fec 100644 --- a/pytest_container/container.py +++ b/pytest_container/container.py @@ -15,6 +15,7 @@ import time import warnings from abc import ABC +from abc import ABCMeta from abc import abstractmethod from dataclasses import dataclass from dataclasses import field @@ -34,6 +35,11 @@ from typing import List from typing import Optional from typing import overload + +try: + from typing import Self +except ImportError: + from typing_extensions import Self from typing import Tuple from typing import Type from typing import Union @@ -44,6 +50,8 @@ import testinfra from filelock import BaseFileLock from filelock import FileLock +from pytest import Mark +from pytest import MarkDecorator from pytest import param from pytest_container.helpers import get_always_pull_option from pytest_container.inspect import ContainerHealth @@ -493,6 +501,11 @@ class ContainerBase: default_factory=list ) + #: optional list of marks applied to this container image under test + _marks: Collection[Union[MarkDecorator, Mark]] = field( + default_factory=list + ) + _is_local: bool = False def __post_init__(self) -> None: @@ -503,6 +516,9 @@ def __post_init__(self) -> None: def __str__(self) -> str: return self.url or self.container_id + def __bool__(self) -> bool: + return True + @property def _build_tag(self) -> str: """Internal build tag assigned to each immage, either the image url or @@ -519,6 +535,18 @@ def local_image(self) -> bool: """ return self._is_local + @property + def marks(self) -> Collection[Union[MarkDecorator, Mark]]: + return self._marks + + @property + def values(self) -> Tuple[Self, ...]: + return (self,) + + @property + def id(self) -> str: + return str(self) + def get_launch_cmd( self, container_runtime: OciRuntimeBase, @@ -656,8 +684,25 @@ def baseurl(self) -> Optional[str]: """ +class _HackMROMeta(ABCMeta): + def mro(cls): + return ( + cls, + ContainerBase, + ContainerBaseABC, + tuple, + _pytest.mark.ParameterSet, + object, + ) + + @dataclass(unsafe_hash=True) -class Container(ContainerBase, ContainerBaseABC): +class Container( + ContainerBase, + ContainerBaseABC, + _pytest.mark.ParameterSet, + metaclass=_HackMROMeta, +): """This class stores information about the Container Image under test.""" def pull_container(self, container_runtime: OciRuntimeBase) -> None: @@ -696,7 +741,12 @@ def baseurl(self) -> Optional[str]: @dataclass(unsafe_hash=True) -class DerivedContainer(ContainerBase, ContainerBaseABC): +class DerivedContainer( + ContainerBase, + ContainerBaseABC, + _pytest.mark.ParameterSet, + metaclass=_HackMROMeta, +): """Class for storing information about the Container Image under test, that is build from a :file:`Containerfile`/:file:`Dockerfile` from a different image (can be any image from a registry or an instance of @@ -723,6 +773,23 @@ class DerivedContainer(ContainerBase, ContainerBaseABC): #: has been built add_build_tags: List[str] = field(default_factory=list) + @staticmethod + def _get_recursive_marks( + ctr: Union[Container, "DerivedContainer", str] + ) -> Collection[Union[MarkDecorator, Mark]]: + if isinstance(ctr, str): + return [] + if isinstance(ctr, Container): + return ctr._marks + + return tuple(ctr._marks) + tuple( + DerivedContainer._get_recursive_marks(ctr.base) + ) + + @property + def marks(self) -> Collection[Union[MarkDecorator, Mark]]: + return DerivedContainer._get_recursive_marks(self) + def __post_init__(self) -> None: super().__post_init__() if not self.base: diff --git a/pytest_container/plugin.py b/pytest_container/plugin.py index 47e947f..8eb821e 100644 --- a/pytest_container/plugin.py +++ b/pytest_container/plugin.py @@ -7,10 +7,13 @@ from subprocess import run from typing import Callable from typing import Generator +from typing import Union +from pytest_container.container import Container from pytest_container.container import container_and_marks_from_pytest_param from pytest_container.container import ContainerData from pytest_container.container import ContainerLauncher +from pytest_container.container import DerivedContainer from pytest_container.helpers import get_extra_build_args from pytest_container.helpers import get_extra_pod_create_args from pytest_container.helpers import get_extra_run_args @@ -77,13 +80,12 @@ def fixture_funct( pytest_generate_tests. """ - try: - container, _ = container_and_marks_from_pytest_param(request.param) - except AttributeError as attr_err: - raise RuntimeError( - "This fixture was not parametrized correctly, " - "did you forget to call `auto_container_parametrize` in `pytest_generate_tests`?" - ) from attr_err + container: Union[DerivedContainer, Container] = ( + request.param + if isinstance(request.param, (DerivedContainer, Container)) + else request.param[0] + ) + assert isinstance(container, (DerivedContainer, Container)) _logger.debug("Requesting the container %s", str(container)) if scope == "session" and container.singleton: diff --git a/pytest_container/pod.py b/pytest_container/pod.py index 3ec627d..e11fa6e 100644 --- a/pytest_container/pod.py +++ b/pytest_container/pod.py @@ -1,17 +1,27 @@ """Module for managing podman pods.""" import contextlib import json +from abc import ABCMeta from dataclasses import dataclass from dataclasses import field from pathlib import Path from subprocess import check_output from types import TracebackType +from typing import Collection from typing import List from typing import Optional + +try: + from typing import Self +except ImportError: + from typing_extensions import Self +from typing import Tuple from typing import Type from typing import Union from _pytest.mark import ParameterSet +from pytest import Mark +from pytest import MarkDecorator from pytest_container.container import Container from pytest_container.container import ContainerData from pytest_container.container import ContainerLauncher @@ -24,8 +34,18 @@ from pytest_container.runtime import PodmanRuntime +class _HackMROMeta(ABCMeta): + def mro(cls): + return ( + cls, + tuple, + ParameterSet, + object, + ) + + @dataclass -class Pod: +class Pod(ParameterSet, metaclass=_HackMROMeta): """A pod is a collection of containers that share the same network and port forwards. Currently only :command:`podman` supports creating pods. @@ -40,6 +60,30 @@ class Pod: #: ports exposed by the pod forwarded_ports: List[PortForwarding] = field(default_factory=list) + _marks: Collection[Union[MarkDecorator, Mark]] = field( + default_factory=list + ) + + @property + def values(self) -> Tuple[Self]: + return (self,) + + @property + def marks(self) -> Collection[Union[MarkDecorator, Mark]]: + marks = tuple(self._marks) + for ctr in self.containers: + marks += tuple(ctr.marks) + return marks + + @property + def id(self) -> str: + return "Pod with containers: " + ",".join( + str(c) for c in self.containers + ) + + def __bool__(self) -> bool: + return True + @dataclass(frozen=True) class PodData: diff --git a/tests/test_marks.py b/tests/test_marks.py new file mode 100644 index 0000000..5d8b304 --- /dev/null +++ b/tests/test_marks.py @@ -0,0 +1,99 @@ +import pytest +from _pytest.mark import ParameterSet +from pytest_container.container import Container +from pytest_container.container import ContainerBase +from pytest_container.container import DerivedContainer +from pytest_container.pod import Pod + +from tests.images import LEAP_URL + +LEAP_WITH_MARK = Container(url=LEAP_URL, _marks=[pytest.mark.secretleapmark]) + +DERIVED_ON_LEAP_WITH_MARK = DerivedContainer(base=LEAP_WITH_MARK) + +SECOND_DERIVED_ON_LEAP = DerivedContainer( + base=DERIVED_ON_LEAP_WITH_MARK, _marks=[pytest.mark.othersecretmark] +) + +INDEPENDENT_OTHER_LEAP = Container( + url=LEAP_URL, _marks=[pytest.mark.othersecretmark] +) + +UNMARKED_POD = Pod(containers=[LEAP_WITH_MARK, INDEPENDENT_OTHER_LEAP]) + +MARKED_POD = Pod( + containers=[LEAP_WITH_MARK, INDEPENDENT_OTHER_LEAP], + _marks=[pytest.mark.secretpodmark], +) + + +def test_marks() -> None: + assert list(LEAP_WITH_MARK.marks) == [pytest.mark.secretleapmark] + assert list(DERIVED_ON_LEAP_WITH_MARK.marks) == [ + pytest.mark.secretleapmark + ] + assert list(SECOND_DERIVED_ON_LEAP.marks) == [ + pytest.mark.othersecretmark, + pytest.mark.secretleapmark, + ] + assert not DerivedContainer( + base=LEAP_URL, containerfile="ENV HOME=/root" + ).marks + + pod_marks = UNMARKED_POD.marks + assert ( + len(pod_marks) == 2 + and pytest.mark.othersecretmark in pod_marks + and pytest.mark.secretleapmark in pod_marks + ) + + pod_marks = MARKED_POD.marks + assert ( + len(pod_marks) == 3 + and pytest.mark.othersecretmark in pod_marks + and pytest.mark.secretleapmark in pod_marks + and pytest.mark.secretpodmark in pod_marks + ) + + +@pytest.mark.parametrize( + "ctr", + [ + LEAP_WITH_MARK, + DERIVED_ON_LEAP_WITH_MARK, + SECOND_DERIVED_ON_LEAP, + INDEPENDENT_OTHER_LEAP, + ], +) +def test_container_is_pytest_param(ctr) -> None: + + assert isinstance(ctr, ParameterSet) + assert isinstance(ctr, (Container, DerivedContainer)) + + +@pytest.mark.parametrize( + "ctr", + [ + LEAP_WITH_MARK, + DERIVED_ON_LEAP_WITH_MARK, + SECOND_DERIVED_ON_LEAP, + INDEPENDENT_OTHER_LEAP, + ], +) +def test_container_is_truthy(ctr: ContainerBase) -> None: + """Regression test that we don't accidentally inherit __bool__ from tuple + and the container is False by default. + + """ + assert ctr + + +@pytest.mark.parametrize("pd", [MARKED_POD, UNMARKED_POD]) +def test_pod_is_pytest_param(pd: Pod) -> None: + assert isinstance(pd, ParameterSet) + assert isinstance(pd, Pod) + + +@pytest.mark.parametrize("pd", [MARKED_POD, UNMARKED_POD]) +def test_pod_is_truthy(pd: Pod) -> None: + assert pd