Skip to content

Commit

Permalink
improve stubgen tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bluenote10 committed Jan 7, 2024
1 parent fbb738a commit 1e56aa4
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 32 deletions.
13 changes: 11 additions & 2 deletions misc/test-stubgenc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,24 @@ function stubgenc_test() {
rm -rf "${STUBGEN_OUTPUT_FOLDER:?}/*"
stubgen -o "$STUBGEN_OUTPUT_FOLDER" "${@:2}"

# Check if generated stubs can actually be type checked by mypy
if ! mypy "$STUBGEN_OUTPUT_FOLDER";
then
echo "Stubgen test failed, because generated stubs failed to type check."
EXIT=1
fi

# Compare generated stubs to expected ones
if ! git diff --exit-code "$STUBGEN_OUTPUT_FOLDER";
then
echo "Stubgen test failed, because generated stubs differ from expected outputs."
EXIT=1
fi
}

# create stubs without docstrings
stubgenc_test stubgen -p pybind11_mypy_demo
stubgenc_test expected_stubs_no_docs -p pybind11_mypy_demo
# create stubs with docstrings
stubgenc_test stubgen-include-docs -p pybind11_mypy_demo --include-docstrings
stubgenc_test expected_stubs_with_docs -p pybind11_mypy_demo --include-docstrings

exit $EXIT
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
from . import demo as demo
from typing import List, Optional, Tuple

class TestStruct:
field_readwrite: int
field_readwrite_docstring: int
def __init__(self, *args, **kwargs) -> None: ...
@property
def field_readonly(self) -> int: ...

def func_incomplete_signature(*args, **kwargs): ...
def func_returning_optional() -> Optional[int]: ...
def func_returning_pair() -> Tuple[int, float]: ...
def func_returning_path() -> os.PathLike: ...
def func_returning_vector() -> List[float]: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from . import demo as demo
from typing import List, Optional, Tuple

class TestStruct:
field_readwrite: int
field_readwrite_docstring: int
def __init__(self, *args, **kwargs) -> None:
"""Initialize self. See help(type(self)) for accurate signature."""
@property
def field_readonly(self) -> int: ...

def func_incomplete_signature(*args, **kwargs):
"""func_incomplete_signature() -> dummy_sub_namespace::HasNoBinding"""
def func_returning_optional() -> Optional[int]:
"""func_returning_optional() -> Optional[int]"""
def func_returning_pair() -> Tuple[int, float]:
"""func_returning_pair() -> Tuple[int, float]"""
def func_returning_path() -> os.PathLike:
"""func_returning_path() -> os.PathLike"""
def func_returning_vector() -> List[float]:
"""func_returning_vector() -> List[float]"""
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ class Point:
degree: ClassVar[Point.AngleUnit] = ...
radian: ClassVar[Point.AngleUnit] = ...
def __init__(self, value: int) -> None:
"""__init__(self: pybind11_mypy_demo.basics.Point.AngleUnit, value: int) -> None"""
"""__init__(self: pybind11_mypy_demo.demo.Point.AngleUnit, value: int) -> None"""
def __eq__(self, other: object) -> bool:
"""__eq__(self: object, other: object) -> bool"""
def __hash__(self) -> int:
"""__hash__(self: object) -> int"""
def __index__(self) -> int:
"""__index__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int"""
"""__index__(self: pybind11_mypy_demo.demo.Point.AngleUnit) -> int"""
def __int__(self) -> int:
"""__int__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int"""
"""__int__(self: pybind11_mypy_demo.demo.Point.AngleUnit) -> int"""
def __ne__(self, other: object) -> bool:
"""__ne__(self: object, other: object) -> bool"""
@property
Expand All @@ -33,15 +33,15 @@ class Point:
mm: ClassVar[Point.LengthUnit] = ...
pixel: ClassVar[Point.LengthUnit] = ...
def __init__(self, value: int) -> None:
"""__init__(self: pybind11_mypy_demo.basics.Point.LengthUnit, value: int) -> None"""
"""__init__(self: pybind11_mypy_demo.demo.Point.LengthUnit, value: int) -> None"""
def __eq__(self, other: object) -> bool:
"""__eq__(self: object, other: object) -> bool"""
def __hash__(self) -> int:
"""__hash__(self: object) -> int"""
def __index__(self) -> int:
"""__index__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int"""
"""__index__(self: pybind11_mypy_demo.demo.Point.LengthUnit) -> int"""
def __int__(self) -> int:
"""__int__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int"""
"""__int__(self: pybind11_mypy_demo.demo.Point.LengthUnit) -> int"""
def __ne__(self, other: object) -> bool:
"""__ne__(self: object, other: object) -> bool"""
@property
Expand All @@ -60,38 +60,38 @@ class Point:
"""__init__(*args, **kwargs)
Overloaded function.
1. __init__(self: pybind11_mypy_demo.basics.Point) -> None
1. __init__(self: pybind11_mypy_demo.demo.Point) -> None
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
2. __init__(self: pybind11_mypy_demo.demo.Point, x: float, y: float) -> None
"""
@overload
def __init__(self, x: float, y: float) -> None:
"""__init__(*args, **kwargs)
Overloaded function.
1. __init__(self: pybind11_mypy_demo.basics.Point) -> None
1. __init__(self: pybind11_mypy_demo.demo.Point) -> None
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
2. __init__(self: pybind11_mypy_demo.demo.Point, x: float, y: float) -> None
"""
def as_list(self) -> List[float]:
"""as_list(self: pybind11_mypy_demo.basics.Point) -> List[float]"""
"""as_list(self: pybind11_mypy_demo.demo.Point) -> List[float]"""
@overload
def distance_to(self, x: float, y: float) -> float:
"""distance_to(*args, **kwargs)
Overloaded function.
1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float
1. distance_to(self: pybind11_mypy_demo.demo.Point, x: float, y: float) -> float
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
2. distance_to(self: pybind11_mypy_demo.demo.Point, other: pybind11_mypy_demo.demo.Point) -> float
"""
@overload
def distance_to(self, other: Point) -> float:
"""distance_to(*args, **kwargs)
Overloaded function.
1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float
1. distance_to(self: pybind11_mypy_demo.demo.Point, x: float, y: float) -> float
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
2. distance_to(self: pybind11_mypy_demo.demo.Point, other: pybind11_mypy_demo.demo.Point) -> float
"""
@property
def length(self) -> float: ...
Expand Down
109 changes: 96 additions & 13 deletions test-data/pybind11_mypy_demo/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,87 @@
*/

#include <cmath>
#include <filesystem>
#include <optional>
#include <utility>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl/filesystem.h>

namespace py = pybind11;

namespace basics {
// ----------------------------------------------------------------------------
// Dedicated test cases
// ----------------------------------------------------------------------------

std::vector<float> funcReturningVector()
{
return std::vector<float>{1.0, 2.0, 3.0};
}

std::pair<int, float> funcReturningPair()
{
return std::pair{42, 1.0};
}

std::optional<int> funcReturningOptional()
{
return std::nullopt;
}

std::filesystem::path funcReturningPath()
{
// This example does not include <pybind11/stl/filesystem.h> on purpose
// to demonstrate the signature of an incomplete binding.
return std::filesystem::path{"foobar"};
}

namespace dummy_sub_namespace {
struct HasNoBinding{};
}

// We can enforce the case of an incomplete signature by referring to a type in
// some namespace that doesn't have a pybind11 binding.
dummy_sub_namespace::HasNoBinding funcIncompleteSignature()
{
return dummy_sub_namespace::HasNoBinding{};
}

struct TestStruct
{
int field_readwrite;
int field_readwrite_docstring;
int field_readonly;
};

// Bindings

void bind_test_cases(py::module& m) {
m.def("func_returning_vector", &funcReturningVector);
m.def("func_returning_pair", &funcReturningPair);
m.def("func_returning_optional", &funcReturningOptional);
m.def("func_returning_path", &funcReturningPath);

m.def("func_incomplete_signature", &funcIncompleteSignature);

py::class_<TestStruct>(m, "TestStruct")
.def_readwrite("field_readwrite", &TestStruct::field_readwrite)
.def_readwrite("field_readwrite_docstring", &TestStruct::field_readwrite_docstring, "some docstring")
.def_property_readonly(
"field_readonly",
[](const TestStruct& x) {
return x.field_readonly;
},
"some docstring");
}

// ----------------------------------------------------------------------------
// Original demo
// ----------------------------------------------------------------------------

namespace demo {

int answer() {
return 42;
Expand Down Expand Up @@ -118,20 +193,22 @@ const Point Point::y_axis = Point(0, 1);
Point::LengthUnit Point::length_unit = Point::LengthUnit::mm;
Point::AngleUnit Point::angle_unit = Point::AngleUnit::radian;

} // namespace: basics
} // namespace: demo

void bind_basics(py::module& basics) {
// Bindings

using namespace basics;
void bind_demo(py::module& m) {

using namespace demo;

// Functions
basics.def("answer", &answer, "answer docstring, with end quote\""); // tests explicit docstrings
basics.def("sum", &sum, "multiline docstring test, edge case quotes \"\"\"'''");
basics.def("midpoint", &midpoint, py::arg("left"), py::arg("right"));
basics.def("weighted_midpoint", weighted_midpoint, py::arg("left"), py::arg("right"), py::arg("alpha")=0.5);
m.def("answer", &answer, "answer docstring, with end quote\""); // tests explicit docstrings
m.def("sum", &sum, "multiline docstring test, edge case quotes \"\"\"'''");
m.def("midpoint", &midpoint, py::arg("left"), py::arg("right"));
m.def("weighted_midpoint", weighted_midpoint, py::arg("left"), py::arg("right"), py::arg("alpha")=0.5);

// Classes
py::class_<Point> pyPoint(basics, "Point");
py::class_<Point> pyPoint(m, "Point");
py::enum_<Point::LengthUnit> pyLengthUnit(pyPoint, "LengthUnit");
py::enum_<Point::AngleUnit> pyAngleUnit(pyPoint, "AngleUnit");

Expand Down Expand Up @@ -167,11 +244,17 @@ void bind_basics(py::module& basics) {
.value("degree", Point::AngleUnit::degree);

// Module-level attributes
basics.attr("PI") = std::acos(-1);
basics.attr("__version__") = "0.0.1";
m.attr("PI") = std::acos(-1);
m.attr("__version__") = "0.0.1";
}

// ----------------------------------------------------------------------------
// Module entry point
// ----------------------------------------------------------------------------

PYBIND11_MODULE(pybind11_mypy_demo, m) {
auto basics = m.def_submodule("basics");
bind_basics(basics);
bind_test_cases(m);

auto demo = m.def_submodule("demo");
bind_demo(demo);
}

This file was deleted.

This file was deleted.

0 comments on commit 1e56aa4

Please sign in to comment.