Skip to content

Commit

Permalink
Switch to cmake-based build
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Feb 14, 2025
1 parent a29fa3d commit 5181356
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 56 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build-wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ jobs:
run: python -m pip install build

- name: build sdist
env:
PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu"
run: python -m build . --outdir=dist/

- uses: actions/upload-artifact@v4
Expand Down
58 changes: 58 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
cmake_minimum_required(VERSION 3.27)

if (POLICY CMP0076)
# target_sources() converts relative paths to absolute
cmake_policy(SET CMP0076 NEW)
endif()

project(neighbors_convert CXX)

# Set a default build type if none was specified
if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR})
if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
message(STATUS "Setting build type to 'relwithdebinfo' as none was specified.")
set(
CMAKE_BUILD_TYPE "relwithdebinfo"
CACHE STRING
"Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel."
FORCE
)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none)
endif()
endif()

# add path to the cmake configuration of the version of libtorch used
# by the Python torch module. PYTHON_EXECUTABLE is provided by skbuild
execute_process(
COMMAND ${PYTHON_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
)

if (NOT ${TORCH_CMAKE_PATH_RESULT} EQUAL 0)
message(FATAL_ERROR "failed to find your pytorch installation\n${TORCH_CMAKE_PATH_ERROR}")
endif()

string(STRIP ${TORCH_CMAKE_PATH_OUTPUT} TORCH_CMAKE_PATH_OUTPUT)
set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};${TORCH_CMAKE_PATH_OUTPUT}")

find_package(Torch 2.3 REQUIRED)

file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/_build_torch_version.py "BUILD_TORCH_VERSION = '${Torch_VERSION}'")

add_library(neighbors_convert SHARED
"src/pet_neighbors_convert/neighbors_convert.cpp"
)

# only link to `torch_cpu_library` instead of `torch`, which could also include
# `libtorch_cuda`.
target_link_libraries(neighbors_convert PUBLIC torch_cpu_library)
target_include_directories(neighbors_convert PUBLIC "${TORCH_INCLUDE_DIRS}")
target_compile_definitions(neighbors_convert PUBLIC "${TORCH_CXX_FLAGS}")

target_compile_features(neighbors_convert PUBLIC cxx_std_17)

install(TARGETS neighbors_convert
LIBRARY DESTINATION "lib"
)
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ graft src

include LICENSE
include README.rst
include CMakeLists.txt

prune .github
prune .tox
Expand Down
2 changes: 1 addition & 1 deletion build-backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
if FORCED_TORCH_VERSION is not None:
TORCH_DEP = f"torch =={FORCED_TORCH_VERSION}"
else:
TORCH_DEP = "torch >=1.12"
TORCH_DEP = "torch >=2.3"

# ==================================================================================== #
# Build backend functions definition #
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ classifiers = [
"Topic :: Scientific/Engineering",
"Topic :: Software Development :: Libraries :: Python Modules"
]
description = "Extension for PET model for reordering the neighbors during message passing."
description = "PET model extension for processing neighbor lists"
dynamic = ["version", "dependencies"]
license = {text = "BSD-3-Clause"}
name = "pet-neighbors-convert"
readme = "README.rst"
requires-python = ">=3.9"

[tool.check-manifest]
ignore = ["src/pet_neighbors_convert/_version.py"]
# [tool.check-manifest]
# ignore = ["src/pet_neighbors_convert/_version.py"]

[tool.setuptools_scm]
version_file = "src/pet_neighbors_convert/_version.py"
# [tool.setuptools_scm]
# version_file = "src/pet_neighbors_convert/_version.py"

[tool.setuptools.packages.find]
where = ["src"]
133 changes: 94 additions & 39 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,113 @@
import os
import subprocess
import sys
from setuptools import setup, Extension

from setuptools import Extension, setup
from setuptools.command.bdist_egg import bdist_egg
from setuptools.command.build_ext import build_ext
from wheel.bdist_wheel import bdist_wheel


class CustomBuildExt(build_ext):
def build_extensions(self):
# Import torch here, when we actually need it
from torch.utils.cpp_extension import include_paths, library_paths
ROOT = os.path.realpath(os.path.dirname(__file__))

# Update extension settings with torch-specific information
for ext in self.extensions:
ext.include_dirs = include_paths()
ext.library_dirs = library_paths()
ext.libraries = ["c10", "torch", "torch_cpu"]

super().build_extensions()
class universal_wheel(bdist_wheel):
# When building the wheel, the `wheel` package assumes that if we have a
# binary extension then we are linking to `libpython.so`; and thus the wheel
# is only usable with a single python version. This is not the case for
# here, and the wheel will be compatible with any Python >=3.7. This is
# tracked in https://github.com/pypa/wheel/issues/185, but until then we
# manually override the wheel tag.
def get_tag(self):
tag = bdist_wheel.get_tag(self)
# tag[2:] contains the os/arch tags, we want to keep them
return ("py3", "none") + tag[2:]


if __name__ == "__main__":
# Basic compilation settings
extra_compile_args = ["-std=c++17"]
extra_link_args = []
if sys.platform == "darwin":
shared_lib_ext = ".dylib"
extra_compile_args.append("-stdlib=libc++")
extra_link_args.extend(["-stdlib=libc++", "-mmacosx-version-min=10.9"])
elif sys.platform == "linux":
extra_compile_args.append("-fPIC")
shared_lib_ext = ".so"
else:
raise RuntimeError(f"Unsupported platform {sys.platform}")

neighbors_convert_extension = Extension(
name="pet_neighbors_convert.neighbors_convert",
sources=["src/pet_neighbors_convert/neighbors_convert.cpp"],
language="c++",
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
class cmake_ext(build_ext):
"""Build the native library using cmake"""

def run(self):
import torch

torch_major, torch_minor, *_ = torch.__version__.split(".")

source_dir = ROOT
build_dir = os.path.join(ROOT, "build", "cmake-build")
install_dir = os.path.join(
os.path.realpath(self.build_lib),
f"pet_neighbors_convert/torch-{torch_major}.{torch_minor}",
)

os.makedirs(build_dir, exist_ok=True)

cmake_options = [
"-DCMAKE_BUILD_TYPE=Release",
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
f"-DPYTHON_EXECUTABLE={sys.executable}",
]

if sys.platform.startswith("darwin"):
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0")

# ARCHFLAGS is used by cibuildwheel to pass the requested arch to the
# compilers
ARCHFLAGS = os.environ.get("ARCHFLAGS")
if ARCHFLAGS is not None:
cmake_options.append(f"-DCMAKE_C_FLAGS={ARCHFLAGS}")
cmake_options.append(f"-DCMAKE_CXX_FLAGS={ARCHFLAGS}")

subprocess.run(
["cmake", source_dir, *cmake_options],
cwd=build_dir,
check=True,
)
build_command = [
"cmake",
"--build",
build_dir,
"--target",
"install",
]

subprocess.run(build_command, check=True)


class bdist_egg_disabled(bdist_egg):
"""Disabled version of bdist_egg
Prevents setup.py install performing setuptools' default easy_install,
which it should never ever do.
"""

def run(self):
sys.exit(
"Aborting implicit building of eggs. "
"Use `pip install .` to install from source."
)


if __name__ == "__main__":
try:
import torch

# if we have torch, we are building a wheel, which will only be compatible with
# a single torch version
# if we have torch, we are building a wheel - requires specific torch version
torch_v_major, torch_v_minor, *_ = torch.__version__.split(".")
torch_version = f"== {torch_v_major}.{torch_v_minor}.*"
except ImportError:
# otherwise we are building a sdist
torch_version = ">= 1.12"
torch_version = ">= 2.1"

install_requires = [f"torch {torch_version}"]

setup(
install_requires=[f"torch {torch_version}"],
ext_modules=[neighbors_convert_extension],
cmdclass={"build_ext": CustomBuildExt},
package_data={"pet_neighbors_convert": [f"neighbors_convert{shared_lib_ext}"]},
install_requires=install_requires,
ext_modules=[
Extension(name="neighbors_convert", sources=[]),
],
cmdclass={
"build_ext": cmake_ext,
"bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled,
"bdist_wheel": universal_wheel,
},
)
78 changes: 67 additions & 11 deletions src/pet_neighbors_convert/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,72 @@
import os
import sys
from collections import namedtuple

import torch
import importlib.resources as pkg_resources
from ._version import __version__ # noqa

import re
import glob


Version = namedtuple("Version", ["major", "minor", "patch"])


def parse_version(version):
match = re.match(r"(\d+)\.(\d+)\.(\d+).*", version)
if match:
return Version(*map(int, match.groups()))
else:
raise ValueError("Invalid version string format")


_HERE = os.path.realpath(os.path.dirname(__file__))


def _lib_path():
torch_version = parse_version(torch.__version__)
expected_prefix = os.path.join(
_HERE, f"torch-{torch_version.major}.{torch_version.minor}"
)
if os.path.exists(expected_prefix):
if sys.platform.startswith("darwin"):
path = os.path.join(expected_prefix, "lib", "libneighbors_convert.dylib")
elif sys.platform.startswith("linux"):
path = os.path.join(expected_prefix, "lib", "libneighbors_convert.so")
elif sys.platform.startswith("win"):
path = os.path.join(expected_prefix, "bin", "neighbors_convert.dll")
else:
raise ImportError("Unknown platform. Please edit this file")

if os.path.isfile(path):
return path
else:
raise ImportError(
"Could not find neighbors_convert shared library at " + path
)

# gather which torch version(s) the current install was built
# with to create the error message
existing_versions = []
for prefix in glob.glob(os.path.join(_HERE, "torch-*")):
existing_versions.append(os.path.basename(prefix)[11:])

print(existing_versions)

def load_neighbors_convert():
try:
# Locate the shared object file in the package
with pkg_resources.files(__name__).joinpath("neighbors_convert.so") as lib_path:
# Load the shared object file
torch.ops.load_library(str(lib_path))
except Exception as e:
print(f"Failed to load neighbors_convert.so: {e}")
if len(existing_versions) == 1:
raise ImportError(
f"Trying to load neighbors-convert with torch v{torch.__version__}, "
f"but it was compiled against torch v{existing_versions[0]}, which "
"is not ABI compatible"
)
else:
all_versions = ", ".join(map(lambda version: f"v{version}", existing_versions))
raise ImportError(
f"Trying to load neighbors-convert with torch v{torch.__version__}, "
f"we found builds for torch {all_versions}; which are not ABI compatible.\n"
"You can try to re-install from source with "
"`pip install neighbors-convert --no-binary=neighbors-convert`"
)


load_neighbors_convert()
# load the C++ operators and custom classes
torch.classes.load_library(_lib_path())

0 comments on commit 5181356

Please sign in to comment.