-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a29fa3d
commit 5181356
Showing
7 changed files
with
228 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
cmake_minimum_required(VERSION 3.27) | ||
|
||
if (POLICY CMP0076) | ||
# target_sources() converts relative paths to absolute | ||
cmake_policy(SET CMP0076 NEW) | ||
endif() | ||
|
||
project(neighbors_convert CXX) | ||
|
||
# Set a default build type if none was specified | ||
if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR}) | ||
if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "") | ||
message(STATUS "Setting build type to 'relwithdebinfo' as none was specified.") | ||
set( | ||
CMAKE_BUILD_TYPE "relwithdebinfo" | ||
CACHE STRING | ||
"Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel." | ||
FORCE | ||
) | ||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none) | ||
endif() | ||
endif() | ||
|
||
# add path to the cmake configuration of the version of libtorch used | ||
# by the Python torch module. PYTHON_EXECUTABLE is provided by skbuild | ||
execute_process( | ||
COMMAND ${PYTHON_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)" | ||
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT | ||
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT | ||
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR | ||
) | ||
|
||
if (NOT ${TORCH_CMAKE_PATH_RESULT} EQUAL 0) | ||
message(FATAL_ERROR "failed to find your pytorch installation\n${TORCH_CMAKE_PATH_ERROR}") | ||
endif() | ||
|
||
string(STRIP ${TORCH_CMAKE_PATH_OUTPUT} TORCH_CMAKE_PATH_OUTPUT) | ||
set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};${TORCH_CMAKE_PATH_OUTPUT}") | ||
|
||
find_package(Torch 2.3 REQUIRED) | ||
|
||
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/_build_torch_version.py "BUILD_TORCH_VERSION = '${Torch_VERSION}'") | ||
|
||
add_library(neighbors_convert SHARED | ||
"src/pet_neighbors_convert/neighbors_convert.cpp" | ||
) | ||
|
||
# only link to `torch_cpu_library` instead of `torch`, which could also include | ||
# `libtorch_cuda`. | ||
target_link_libraries(neighbors_convert PUBLIC torch_cpu_library) | ||
target_include_directories(neighbors_convert PUBLIC "${TORCH_INCLUDE_DIRS}") | ||
target_compile_definitions(neighbors_convert PUBLIC "${TORCH_CXX_FLAGS}") | ||
|
||
target_compile_features(neighbors_convert PUBLIC cxx_std_17) | ||
|
||
install(TARGETS neighbors_convert | ||
LIBRARY DESTINATION "lib" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ graft src | |
|
||
include LICENSE | ||
include README.rst | ||
include CMakeLists.txt | ||
|
||
prune .github | ||
prune .tox | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,113 @@ | ||
import os | ||
import subprocess | ||
import sys | ||
from setuptools import setup, Extension | ||
|
||
from setuptools import Extension, setup | ||
from setuptools.command.bdist_egg import bdist_egg | ||
from setuptools.command.build_ext import build_ext | ||
from wheel.bdist_wheel import bdist_wheel | ||
|
||
|
||
class CustomBuildExt(build_ext): | ||
def build_extensions(self): | ||
# Import torch here, when we actually need it | ||
from torch.utils.cpp_extension import include_paths, library_paths | ||
ROOT = os.path.realpath(os.path.dirname(__file__)) | ||
|
||
# Update extension settings with torch-specific information | ||
for ext in self.extensions: | ||
ext.include_dirs = include_paths() | ||
ext.library_dirs = library_paths() | ||
ext.libraries = ["c10", "torch", "torch_cpu"] | ||
|
||
super().build_extensions() | ||
class universal_wheel(bdist_wheel): | ||
# When building the wheel, the `wheel` package assumes that if we have a | ||
# binary extension then we are linking to `libpython.so`; and thus the wheel | ||
# is only usable with a single python version. This is not the case for | ||
# here, and the wheel will be compatible with any Python >=3.7. This is | ||
# tracked in https://github.com/pypa/wheel/issues/185, but until then we | ||
# manually override the wheel tag. | ||
def get_tag(self): | ||
tag = bdist_wheel.get_tag(self) | ||
# tag[2:] contains the os/arch tags, we want to keep them | ||
return ("py3", "none") + tag[2:] | ||
|
||
|
||
if __name__ == "__main__": | ||
# Basic compilation settings | ||
extra_compile_args = ["-std=c++17"] | ||
extra_link_args = [] | ||
if sys.platform == "darwin": | ||
shared_lib_ext = ".dylib" | ||
extra_compile_args.append("-stdlib=libc++") | ||
extra_link_args.extend(["-stdlib=libc++", "-mmacosx-version-min=10.9"]) | ||
elif sys.platform == "linux": | ||
extra_compile_args.append("-fPIC") | ||
shared_lib_ext = ".so" | ||
else: | ||
raise RuntimeError(f"Unsupported platform {sys.platform}") | ||
|
||
neighbors_convert_extension = Extension( | ||
name="pet_neighbors_convert.neighbors_convert", | ||
sources=["src/pet_neighbors_convert/neighbors_convert.cpp"], | ||
language="c++", | ||
extra_compile_args=extra_compile_args, | ||
extra_link_args=extra_link_args, | ||
) | ||
class cmake_ext(build_ext): | ||
"""Build the native library using cmake""" | ||
|
||
def run(self): | ||
import torch | ||
|
||
torch_major, torch_minor, *_ = torch.__version__.split(".") | ||
|
||
source_dir = ROOT | ||
build_dir = os.path.join(ROOT, "build", "cmake-build") | ||
install_dir = os.path.join( | ||
os.path.realpath(self.build_lib), | ||
f"pet_neighbors_convert/torch-{torch_major}.{torch_minor}", | ||
) | ||
|
||
os.makedirs(build_dir, exist_ok=True) | ||
|
||
cmake_options = [ | ||
"-DCMAKE_BUILD_TYPE=Release", | ||
f"-DCMAKE_INSTALL_PREFIX={install_dir}", | ||
f"-DPYTHON_EXECUTABLE={sys.executable}", | ||
] | ||
|
||
if sys.platform.startswith("darwin"): | ||
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0") | ||
|
||
# ARCHFLAGS is used by cibuildwheel to pass the requested arch to the | ||
# compilers | ||
ARCHFLAGS = os.environ.get("ARCHFLAGS") | ||
if ARCHFLAGS is not None: | ||
cmake_options.append(f"-DCMAKE_C_FLAGS={ARCHFLAGS}") | ||
cmake_options.append(f"-DCMAKE_CXX_FLAGS={ARCHFLAGS}") | ||
|
||
subprocess.run( | ||
["cmake", source_dir, *cmake_options], | ||
cwd=build_dir, | ||
check=True, | ||
) | ||
build_command = [ | ||
"cmake", | ||
"--build", | ||
build_dir, | ||
"--target", | ||
"install", | ||
] | ||
|
||
subprocess.run(build_command, check=True) | ||
|
||
|
||
class bdist_egg_disabled(bdist_egg): | ||
"""Disabled version of bdist_egg | ||
Prevents setup.py install performing setuptools' default easy_install, | ||
which it should never ever do. | ||
""" | ||
|
||
def run(self): | ||
sys.exit( | ||
"Aborting implicit building of eggs. " | ||
"Use `pip install .` to install from source." | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
try: | ||
import torch | ||
|
||
# if we have torch, we are building a wheel, which will only be compatible with | ||
# a single torch version | ||
# if we have torch, we are building a wheel - requires specific torch version | ||
torch_v_major, torch_v_minor, *_ = torch.__version__.split(".") | ||
torch_version = f"== {torch_v_major}.{torch_v_minor}.*" | ||
except ImportError: | ||
# otherwise we are building a sdist | ||
torch_version = ">= 1.12" | ||
torch_version = ">= 2.1" | ||
|
||
install_requires = [f"torch {torch_version}"] | ||
|
||
setup( | ||
install_requires=[f"torch {torch_version}"], | ||
ext_modules=[neighbors_convert_extension], | ||
cmdclass={"build_ext": CustomBuildExt}, | ||
package_data={"pet_neighbors_convert": [f"neighbors_convert{shared_lib_ext}"]}, | ||
install_requires=install_requires, | ||
ext_modules=[ | ||
Extension(name="neighbors_convert", sources=[]), | ||
], | ||
cmdclass={ | ||
"build_ext": cmake_ext, | ||
"bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled, | ||
"bdist_wheel": universal_wheel, | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,72 @@ | ||
import os | ||
import sys | ||
from collections import namedtuple | ||
|
||
import torch | ||
import importlib.resources as pkg_resources | ||
from ._version import __version__ # noqa | ||
|
||
import re | ||
import glob | ||
|
||
|
||
Version = namedtuple("Version", ["major", "minor", "patch"]) | ||
|
||
|
||
def parse_version(version): | ||
match = re.match(r"(\d+)\.(\d+)\.(\d+).*", version) | ||
if match: | ||
return Version(*map(int, match.groups())) | ||
else: | ||
raise ValueError("Invalid version string format") | ||
|
||
|
||
_HERE = os.path.realpath(os.path.dirname(__file__)) | ||
|
||
|
||
def _lib_path(): | ||
torch_version = parse_version(torch.__version__) | ||
expected_prefix = os.path.join( | ||
_HERE, f"torch-{torch_version.major}.{torch_version.minor}" | ||
) | ||
if os.path.exists(expected_prefix): | ||
if sys.platform.startswith("darwin"): | ||
path = os.path.join(expected_prefix, "lib", "libneighbors_convert.dylib") | ||
elif sys.platform.startswith("linux"): | ||
path = os.path.join(expected_prefix, "lib", "libneighbors_convert.so") | ||
elif sys.platform.startswith("win"): | ||
path = os.path.join(expected_prefix, "bin", "neighbors_convert.dll") | ||
else: | ||
raise ImportError("Unknown platform. Please edit this file") | ||
|
||
if os.path.isfile(path): | ||
return path | ||
else: | ||
raise ImportError( | ||
"Could not find neighbors_convert shared library at " + path | ||
) | ||
|
||
# gather which torch version(s) the current install was built | ||
# with to create the error message | ||
existing_versions = [] | ||
for prefix in glob.glob(os.path.join(_HERE, "torch-*")): | ||
existing_versions.append(os.path.basename(prefix)[11:]) | ||
|
||
print(existing_versions) | ||
|
||
def load_neighbors_convert(): | ||
try: | ||
# Locate the shared object file in the package | ||
with pkg_resources.files(__name__).joinpath("neighbors_convert.so") as lib_path: | ||
# Load the shared object file | ||
torch.ops.load_library(str(lib_path)) | ||
except Exception as e: | ||
print(f"Failed to load neighbors_convert.so: {e}") | ||
if len(existing_versions) == 1: | ||
raise ImportError( | ||
f"Trying to load neighbors-convert with torch v{torch.__version__}, " | ||
f"but it was compiled against torch v{existing_versions[0]}, which " | ||
"is not ABI compatible" | ||
) | ||
else: | ||
all_versions = ", ".join(map(lambda version: f"v{version}", existing_versions)) | ||
raise ImportError( | ||
f"Trying to load neighbors-convert with torch v{torch.__version__}, " | ||
f"we found builds for torch {all_versions}; which are not ABI compatible.\n" | ||
"You can try to re-install from source with " | ||
"`pip install neighbors-convert --no-binary=neighbors-convert`" | ||
) | ||
|
||
|
||
load_neighbors_convert() | ||
# load the C++ operators and custom classes | ||
torch.classes.load_library(_lib_path()) |