From 4db79ea5f16f8c7e4268efd0f158e9f80c199e29 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 14 Feb 2025 15:08:59 +0100 Subject: [PATCH 1/4] Add CI for building wheels --- .github/workflows/build-wheels.yml | 236 +++++++++++++++++++++++++ CMakeLists.txt | 58 ++++++ MANIFEST.in | 5 +- build-backend/backend.py | 26 +++ pyproject.toml | 10 +- scripts/create-torch-versions-range.py | 44 +++++ setup.py | 150 ++++++++++++---- src/pet_neighbors_convert/__init__.py | 76 ++++++-- tests/test_init.py | 7 +- tox.ini | 2 - 10 files changed, 554 insertions(+), 60 deletions(-) create mode 100644 .github/workflows/build-wheels.yml create mode 100644 CMakeLists.txt create mode 100644 build-backend/backend.py create mode 100755 scripts/create-torch-versions-range.py diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml new file mode 100644 index 0000000..c0ecc55 --- /dev/null +++ b/.github/workflows/build-wheels.yml @@ -0,0 +1,236 @@ +name: Build wheels + +on: + push: + branches: [main] + tags: ["*"] + pull_request: + # Check all PR + + +concurrency: + group: python-wheels-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + + +jobs: + build-wheels: + runs-on: ${{ matrix.os }} + name: ${{ matrix.name }} (torch v${{ matrix.torch-version }}) + strategy: + matrix: + torch-version: ["2.3", "2.4", "2.5", "2.6"] + arch: ["arm64", "x86_64"] + os: ["ubuntu-22.04", "ubuntu-22.04-arm", "macos-14"] + exclude: + # remove mismatched arch/os pairs + - {os: macos-14, arch: x86_64} + - {os: ubuntu-22.04, arch: arm64} + - {os: ubuntu-22.04-arm, arch: x86_64} + include: + # add `cibw-arch` to the different configurations + - name: x86_64 Linux + os: ubuntu-22.04 + arch: x86_64 + cibw-arch: x86_64 + - name: arm64 Linux + os: ubuntu-22.04-arm + arch: arm64 + cibw-arch: aarch64 + - name: arm64 macOS + os: macos-14 + arch: arm64 + cibw-arch: arm64 + # set the right manylinux image to use + - {torch-version: '2.3', manylinux-version: 
"2014"} + - {torch-version: '2.4', manylinux-version: "2014"} + - {torch-version: '2.5', manylinux-version: "2014"} + # only torch >= 2.6 on arm64-linux needs the newer manylinux + - {torch-version: '2.6', arch: x86_64, manylinux-version: "2014"} + - {torch-version: '2.6', arch: arm64, manylinux-version: "_2_28"} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: install dependencies + run: python -m pip install cibuildwheel + + - name: build wheel + run: python -m cibuildwheel --output-dir ./wheelhouse + env: + CIBW_BUILD: cp312-* + CIBW_SKIP: "*musllinux*" + CIBW_ARCHS: "${{ matrix.cibw-arch }}" + CIBW_BUILD_VERBOSITY: 1 + CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux${{ matrix.manylinux-version }}_x86_64 + CIBW_MANYLINUX_AARCH64_IMAGE: quay.io/pypa/manylinux${{ matrix.manylinux-version }}_aarch64 + CIBW_ENVIRONMENT: > + PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu + MACOSX_DEPLOYMENT_TARGET=11 + PETNC_BUILD_WITH_TORCH_VERSION=${{ matrix.torch-version }}.* + # do not complain for missing libtorch.so + CIBW_REPAIR_WHEEL_COMMAND_MACOS: | + delocate-wheel --ignore-missing-dependencies --require-archs {delocate_archs} -w {dest_dir} -v {wheel} + CIBW_REPAIR_WHEEL_COMMAND_LINUX: | + auditwheel repair --exclude libtorch.so --exclude libtorch_cpu.so --exclude libc10.so -w {dest_dir} {wheel} + + - uses: actions/upload-artifact@v4 + with: + name: single-version-wheel-${{ matrix.torch-version }}-${{ matrix.os }}-${{ matrix.arch }} + path: ./wheelhouse/*.whl + + merge-wheels: + needs: build-wheels + runs-on: ubuntu-22.04 + name: merge wheels for ${{ matrix.name }} + strategy: + matrix: + include: + - name: x86_64 Linux + os: ubuntu-22.04 + arch: x86_64 + - name: arm64 Linux + os: ubuntu-22.04-arm + arch: arm64 + - name: arm64 macOS + os: macos-14 + arch: arm64 + steps: + - uses: actions/checkout@v4 + + - name: Download wheels + uses: 
actions/download-artifact@v4 + with: + pattern: single-version-wheel-*-${{ matrix.os }}-${{ matrix.arch }} + merge-multiple: false + path: dist + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: install dependencies + run: python -m pip install twine wheel + + - name: merge wheels + run: | + # collect all torch versions used for the build + REQUIRES_TORCH=$(find dist -name "*.whl" -exec unzip -p {} "pet_neighbors_convert-*.dist-info/METADATA" \; | grep "Requires-Dist: torch") + MERGED_TORCH_REQUIRE=$(python scripts/create-torch-versions-range.py "$REQUIRES_TORCH") + + echo MERGED_TORCH_REQUIRE=$MERGED_TORCH_REQUIRE + + # unpack all single torch versions wheels in the same directory + mkdir dist/unpacked + find dist -name "*.whl" -print -exec python -m wheel unpack --dest dist/unpacked/ {} ';' + + sed -i "s/Requires-Dist: torch.*/$MERGED_TORCH_REQUIRE/" dist/unpacked/pet_neighbors_convert-*/pet_neighbors_convert-*.dist-info/METADATA + + echo "\n\n METADATA = \n\n" + cat dist/unpacked/pet_neighbors_convert-*/pet_neighbors_convert-*.dist-info/METADATA + + # check the right metadata was added to the file. 
grep will exit with + # code `1` if the line is not found, which will stop CI + grep "$MERGED_TORCH_REQUIRE" dist/unpacked/pet_neighbors_convert-*/pet_neighbors_convert-*.dist-info/METADATA + + # repack the directory as a new wheel + mkdir wheelhouse + python -m wheel pack --dest wheelhouse/ dist/unpacked/* + + - name: check wheels with twine + run: twine check wheelhouse/* + + - uses: actions/upload-artifact@v4 + with: + name: wheel-${{ matrix.os }}-${{ matrix.arch }} + path: ./wheelhouse/*.whl + + build-sdist: + name: sdist + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: install dependencies + run: python -m pip install build + + - name: build sdist + env: + PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu" + run: python -m build . --outdir=dist/ + + - uses: actions/upload-artifact@v4 + with: + name: sdist + path: dist/*.tar.gz + + merge-and-release: + name: Merge and release wheels/sdists + needs: [merge-wheels, build-sdist] + runs-on: ubuntu-22.04 + permissions: + contents: write + id-token: write + pull-requests: write + environment: + name: pypi + url: https://pypi.org/project/pet-neighbors-convert + steps: + - name: Download wheels + uses: actions/download-artifact@v4 + with: + path: wheels + pattern: wheel-* + merge-multiple: true + + - name: Download sdist + uses: actions/download-artifact@v4 + with: + path: wheels + name: sdist + + - name: Re-upload a single wheels artifact + uses: actions/upload-artifact@v4 + with: + name: wheels + path: wheels/* + + - name: Comment with download link + uses: PicoCentauri/comment-artifact@v1 + if: github.event.pull_request.head.repo.fork == false + with: + name: wheels + description: ⚙️ Download Python wheels for this pull-request (you can install these with pip) + + - name: Publish distribution to PyPI + if: startsWith(github.ref, 'refs/tags/v') + uses: 
pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: wheels + + - name: upload to GitHub release + if: startsWith(github.ref, 'refs/tags/v') + uses: softprops/action-gh-release@v2 + with: + files: | + wheels/*.tar.gz + wheels/*.whl + prerelease: ${{ contains(github.ref, '-rc') }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..3c52e80 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,58 @@ +cmake_minimum_required(VERSION 3.27) + +if (POLICY CMP0076) + # target_sources() converts relative paths to absolute + cmake_policy(SET CMP0076 NEW) +endif() + +project(neighbors_convert CXX) + +# Set a default build type if none was specified +if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR}) + if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "") + message(STATUS "Setting build type to 'relwithdebinfo' as none was specified.") + set( + CMAKE_BUILD_TYPE "relwithdebinfo" + CACHE STRING + "Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel." + FORCE + ) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none) + endif() +endif() + +# add path to the cmake configuration of the version of libtorch used +# by the Python torch module. 
PYTHON_EXECUTABLE is provided by skbuild +execute_process( + COMMAND ${PYTHON_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)" + RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT + OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT + ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR +) + +if (NOT ${TORCH_CMAKE_PATH_RESULT} EQUAL 0) +message(FATAL_ERROR "failed to find your pytorch installation\n${TORCH_CMAKE_PATH_ERROR}") +endif() + +string(STRIP ${TORCH_CMAKE_PATH_OUTPUT} TORCH_CMAKE_PATH_OUTPUT) +set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};${TORCH_CMAKE_PATH_OUTPUT}") + +find_package(Torch 2.3 REQUIRED) + +file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/_build_torch_version.py "BUILD_TORCH_VERSION = '${Torch_VERSION}'") + +add_library(neighbors_convert SHARED + "src/pet_neighbors_convert/neighbors_convert.cpp" +) + +# only link to `torch_cpu_library` instead of `torch`, which could also include +# `libtorch_cuda`. +target_link_libraries(neighbors_convert PUBLIC torch_cpu_library) +target_include_directories(neighbors_convert PUBLIC "${TORCH_INCLUDE_DIRS}") +target_compile_definitions(neighbors_convert PUBLIC "${TORCH_CXX_FLAGS}") + +target_compile_features(neighbors_convert PUBLIC cxx_std_17) + +install(TARGETS neighbors_convert + LIBRARY DESTINATION "lib" +) diff --git a/MANIFEST.in b/MANIFEST.in index d906467..ffc799c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,12 +2,15 @@ graft src include LICENSE include README.rst +include CMakeLists.txt -prune tests prune .github prune .tox +prune tests +prune scripts exclude .gitignore exclude tox.ini +recursive-include build-backend *.py global-exclude *.py[cod] __pycache__/* *.so *.dylib diff --git a/build-backend/backend.py b/build-backend/backend.py new file mode 100644 index 0000000..794e7a1 --- /dev/null +++ b/build-backend/backend.py @@ -0,0 +1,26 @@ +# This is a custom Python build backend wrapping setuptool's to add a build-time +# dependencies on torch when building the wheel and not the sdist +import os +from setuptools 
import build_meta + +FORCED_TORCH_VERSION = os.environ.get("PETNC_BUILD_WITH_TORCH_VERSION") +if FORCED_TORCH_VERSION is not None: + TORCH_DEP = f"torch =={FORCED_TORCH_VERSION}" +else: + TORCH_DEP = "torch >=2.3" + +# ==================================================================================== # +# Build backend functions definition # +# ==================================================================================== # + +# Use the default version of these +prepare_metadata_for_build_wheel = build_meta.prepare_metadata_for_build_wheel +get_requires_for_build_sdist = build_meta.get_requires_for_build_sdist +build_wheel = build_meta.build_wheel +build_sdist = build_meta.build_sdist + + +# Special dependencies to build the wheels +def get_requires_for_build_wheel(config_settings=None): + defaults = build_meta.get_requires_for_build_wheel(config_settings) + return defaults + [TORCH_DEP] diff --git a/pyproject.toml b/pyproject.toml index 45c38a8..92ebd32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,7 @@ [build-system] -requires = ["setuptools >= 68", "setuptools_scm>=8", "wheel", "torch >= 2.3"] -build-backend = "setuptools.build_meta" +requires = ["setuptools >= 68", "setuptools_scm>=8", "wheel >= 0.36"] +build-backend = "backend" +backend-path = ["build-backend"] [project] authors = [{name = "lab-cosmo developers"}] @@ -15,9 +16,8 @@ classifiers = [ "Topic :: Scientific/Engineering", "Topic :: Software Development :: Libraries :: Python Modules" ] -dependencies = ["torch >= 2.3"] -description = "Extension for PET model for reordering the neighbors during message passing." 
-dynamic = ["version"] +description = "PET model extension for processing neighbor lists" +dynamic = ["version", "dependencies"] license = {text = "BSD-3-Clause"} name = "pet-neighbors-convert" readme = "README.rst" diff --git a/scripts/create-torch-versions-range.py b/scripts/create-torch-versions-range.py new file mode 100755 index 0000000..08c2336 --- /dev/null +++ b/scripts/create-torch-versions-range.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +""" +This script updates the `Requires-Dist` information in the wheel METADATA +to contain the range of compatible torch versions. It expects newline separated +`Requires-Dist: torch ==...` information (corresponding to wheels built against a single +torch version) and will print `Requires-Dist: torch >=$MIN_VERSION,<${MAX_VERSION+1}` on +the standard output. + +This output can the be used in the merged wheel containing the build against all torch +versions. +""" +import re +import sys + + +if __name__ == "__main__": + torch_versions_raw = sys.argv[1] + + torch_versions = [] + for version in torch_versions_raw.split("\n"): + if version.strip() == "": + continue + + match = re.match(r"Requires-Dist: torch[ ]?==(\d+)\.(\d+)\.\*", version) + if match is None: + raise ValueError(f"unexpected Requires-Dist format: {version}") + + major, minor = match.groups() + major = int(major) + minor = int(minor) + + version = (major, minor) + + if version in torch_versions: + raise ValueError(f"duplicate torch version: {version}") + + torch_versions.append(version) + + torch_versions = list(sorted(torch_versions)) + + min_version = f"{torch_versions[0][0]}.{torch_versions[0][1]}" + max_version = f"{torch_versions[-1][0]}.{torch_versions[-1][1] + 1}" + + print(f"Requires-Dist: torch >={min_version},<{max_version}") diff --git a/setup.py b/setup.py index 011c7fc..50eceff 100644 --- a/setup.py +++ b/setup.py @@ -1,39 +1,113 @@ +import os +import subprocess import sys -from setuptools import setup, Extension -from torch.utils.cpp_extension 
import BuildExtension, include_paths, library_paths - -# Collecting include and library paths -include_dirs = include_paths() -library_dirs = library_paths() -libraries = ["c10", "torch", "torch_cpu"] - -extra_compile_args = ["-std=c++17"] -extra_link_args = [] -if sys.platform == "darwin": - shared_lib_ext = ".dylib" - extra_compile_args.append("-stdlib=libc++") - extra_link_args.extend(["-stdlib=libc++", "-mmacosx-version-min=10.9"]) -elif sys.platform == "linux": - extra_compile_args.append("-fPIC") - shared_lib_ext = ".so" -else: - raise RuntimeError(f"Unsupported platform {sys.platform}") - - -# Define the extension -neighbors_convert_extension = Extension( - name="pet_neighbors_convert.neighbors_convert", - sources=["src/pet_neighbors_convert/neighbors_convert.cpp"], - include_dirs=include_dirs, - library_dirs=library_dirs, - libraries=libraries, - language="c++", - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, -) - -setup( - ext_modules=[neighbors_convert_extension], - cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)}, - package_data={"pet_neighbors_convert": [f"neighbors_convert{shared_lib_ext}"]}, -) + +from setuptools import Extension, setup +from setuptools.command.bdist_egg import bdist_egg +from setuptools.command.build_ext import build_ext +from wheel.bdist_wheel import bdist_wheel + + +ROOT = os.path.realpath(os.path.dirname(__file__)) + + +class universal_wheel(bdist_wheel): + # When building the wheel, the `wheel` package assumes that if we have a + # binary extension then we are linking to `libpython.so`; and thus the wheel + # is only usable with a single python version. This is not the case for + # here, and the wheel will be compatible with any Python >=3.7. This is + # tracked in https://github.com/pypa/wheel/issues/185, but until then we + # manually override the wheel tag. 
+ def get_tag(self): + tag = bdist_wheel.get_tag(self) + # tag[2:] contains the os/arch tags, we want to keep them + return ("py3", "none") + tag[2:] + + +class cmake_ext(build_ext): + """Build the native library using cmake""" + + def run(self): + import torch + + torch_major, torch_minor, *_ = torch.__version__.split(".") + + source_dir = ROOT + build_dir = os.path.join(ROOT, "build", "cmake-build") + install_dir = os.path.join( + os.path.realpath(self.build_lib), + f"pet_neighbors_convert/torch-{torch_major}.{torch_minor}", + ) + + os.makedirs(build_dir, exist_ok=True) + + cmake_options = [ + "-DCMAKE_BUILD_TYPE=Release", + f"-DCMAKE_INSTALL_PREFIX={install_dir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + ] + + if sys.platform.startswith("darwin"): + cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0") + + # ARCHFLAGS is used by cibuildwheel to pass the requested arch to the + # compilers + ARCHFLAGS = os.environ.get("ARCHFLAGS") + if ARCHFLAGS is not None: + cmake_options.append(f"-DCMAKE_C_FLAGS={ARCHFLAGS}") + cmake_options.append(f"-DCMAKE_CXX_FLAGS={ARCHFLAGS}") + + subprocess.run( + ["cmake", source_dir, *cmake_options], + cwd=build_dir, + check=True, + ) + build_command = [ + "cmake", + "--build", + build_dir, + "--target", + "install", + ] + + subprocess.run(build_command, check=True) + + +class bdist_egg_disabled(bdist_egg): + """Disabled version of bdist_egg + + Prevents setup.py install performing setuptools' default easy_install, + which it should never ever do. + """ + + def run(self): + sys.exit( + "Aborting implicit building of eggs. " + "Use `pip install .` to install from source." 
+ ) + + +if __name__ == "__main__": + try: + import torch + + # if we have torch, we are building a wheel - requires specific torch version + torch_v_major, torch_v_minor, *_ = torch.__version__.split(".") + torch_version = f"== {torch_v_major}.{torch_v_minor}.*" + except ImportError: + # otherwise we are building a sdist + torch_version = ">= 2.3" + + install_requires = [f"torch {torch_version}"] + + setup( + install_requires=install_requires, + ext_modules=[ + Extension(name="neighbors_convert", sources=[]), + ], + cmdclass={ + "build_ext": cmake_ext, + "bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled, + "bdist_wheel": universal_wheel, + }, + ) diff --git a/src/pet_neighbors_convert/__init__.py b/src/pet_neighbors_convert/__init__.py index e35af9b..9a33ae8 100644 --- a/src/pet_neighbors_convert/__init__.py +++ b/src/pet_neighbors_convert/__init__.py @@ -1,16 +1,72 @@ +import os +import sys +from collections import namedtuple + import torch -import importlib.resources as pkg_resources + +import re +import glob + from ._version import __version__ # noqa -def load_neighbors_convert(): - try: - # Locate the shared object file in the package - with pkg_resources.files(__name__).joinpath("neighbors_convert.so") as lib_path: - # Load the shared object file - torch.ops.load_library(str(lib_path)) - except Exception as e: - print(f"Failed to load neighbors_convert.so: {e}") +_Version = namedtuple("Version", ["major", "minor", "patch"]) + + +def _parse_version(version): + match = re.match(r"(\d+)\.(\d+)\.(\d+).*", version) + if match: + return _Version(*map(int, match.groups())) + else: + raise ValueError("Invalid version string format") + + +_HERE = os.path.realpath(os.path.dirname(__file__)) + + +def _lib_path(): + torch_version = _parse_version(torch.__version__) + expected_prefix = os.path.join( + _HERE, f"torch-{torch_version.major}.{torch_version.minor}" + ) + if os.path.exists(expected_prefix): + if sys.platform.startswith("darwin"): + path 
= os.path.join(expected_prefix, "lib", "libneighbors_convert.dylib")
+        elif sys.platform.startswith("linux"):
+            path = os.path.join(expected_prefix, "lib", "libneighbors_convert.so")
+        elif sys.platform.startswith("win"):
+            path = os.path.join(expected_prefix, "bin", "neighbors_convert.dll")
+        else:
+            raise ImportError("Unknown platform. Please edit this file")
+
+        if os.path.isfile(path):
+            return path
+        else:
+            raise ImportError(
+                "Could not find neighbors_convert shared library at " + path
+            )
+
+    # gather which torch version(s) the current install was built
+    # with to create the error message
+    existing_versions = []
+    for prefix in glob.glob(os.path.join(_HERE, "torch-*")):
+        existing_versions.append(os.path.basename(prefix)[6:])
+
+    if len(existing_versions) == 1:
+        raise ImportError(
+            f"Trying to load neighbors-convert with torch v{torch.__version__}, "
+            f"but it was compiled against torch v{existing_versions[0]}, which "
+            "is not ABI compatible"
+        )
+    else:
+        all_versions = ", ".join(map(lambda version: f"v{version}", existing_versions))
+        raise ImportError(
+            f"Trying to load neighbors-convert with torch v{torch.__version__}, "
+            f"we found builds for torch {all_versions}; which are not ABI compatible.\n"
+            "You can try to re-install from source with "
+            "`pip install neighbors-convert --no-binary=neighbors-convert`"
+        )

-load_neighbors_convert()
+# load the C++ operators and custom classes
+torch.classes.load_library(_lib_path())
diff --git a/tests/test_init.py b/tests/test_init.py
index 72a490a..0ec96bd 100644
--- a/tests/test_init.py
+++ b/tests/test_init.py
@@ -1,5 +1,4 @@
-from pet_neighbors_convert import load_neighbors_convert
-
-
 def test_init():
-    load_neighbors_convert()
+    import torch
+    import pet_neighbors_convert
+    torch.ops.neighbors_convert.process
diff --git a/tox.ini b/tox.ini
index a665ddb..f2e23a8 100644
--- a/tox.ini
+++ b/tox.ini
@@ -8,7 +8,6 @@ passenv = *
 [testenv:build]
 description = Asserts package build integrity.
-usedevelop = true
 deps =
     build
     check-manifest
@@ -21,6 +20,5 @@ commands =
 [testenv:tests]
 description = Run test suite with pytest and {basepython}.
-usedevelop = true
 deps = pytest
 commands = pytest {posargs}

From 1e0113e26b3b759ec771da7b704dcc8bcef5f2ca Mon Sep 17 00:00:00 2001
From: Philip Loche
Date: Sat, 15 Feb 2025 14:08:03 +0100
Subject: [PATCH 2/4] cleanup neighbors_convert

---
 .github/workflows/build-wheels.yml | 2 +-
 .../neighbors_convert.cpp | 270 +++++++++---------
 2 files changed, 135 insertions(+), 137 deletions(-)

diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml
index c0ecc55..b983e62 100644
--- a/.github/workflows/build-wheels.yml
+++ b/.github/workflows/build-wheels.yml
@@ -219,7 +219,7 @@ jobs:
     - name: Publish distribution to PyPI
       if: startsWith(github.ref, 'refs/tags/v')
-      uses: pypa/gh-action-pypi-publish@release/v1
+      uses: pypa/gh-action-pypi-publish@release/v1
       with:
         packages-dir: wheels

diff --git a/src/pet_neighbors_convert/neighbors_convert.cpp b/src/pet_neighbors_convert/neighbors_convert.cpp
index 857a8ec..4cf90d9 100644
--- a/src/pet_neighbors_convert/neighbors_convert.cpp
+++ b/src/pet_neighbors_convert/neighbors_convert.cpp
@@ -1,15 +1,15 @@
-// #include
 #include
-#include // For std::fill
-#include // For c10::optional
+#include // For std::fill
+#include // For c10::optional
 #include
 #include

 // Template function to process the neighbors
 template
-std::vector process_neighbors_cpu(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list,
-                                  int64_t max_size, int64_t n_atoms, at::Tensor species,
-                                  at::Tensor all_species) {
+std::vector process_neighbors_cpu(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list,
+                                  int64_t max_size, int64_t n_atoms, at::Tensor species,
+                                  at::Tensor all_species)
+{
     // Ensure the tensors are on the CPU and are contiguous
     TORCH_CHECK(i_list.device().is_cpu(), "i_list must be on CPU");
     TORCH_CHECK(j_list.device().is_cpu(), "j_list must be on CPU");
@@ 
-38,65 +38,63 @@ std::vector process_neighbors_cpu(at::Tensor i_list, at::Tensor j_li at::Tensor neighbors_index = torch::zeros({n_atoms, max_size}, options_int); at::Tensor neighbors_shift = torch::zeros({n_atoms, max_size, 3}, options_int); at::Tensor relative_positions = torch::zeros({n_atoms, max_size, 3}, options_float); - at::Tensor nums = torch::zeros({n_atoms}, options_int); // Tensor to store the count of elements - at::Tensor mask = torch::ones({n_atoms, max_size}, options_bool); // Tensor to store the mask + at::Tensor nums = torch::zeros({n_atoms}, options_int); // Tensor to store the count of elements + at::Tensor mask = torch::ones({n_atoms, max_size}, options_bool); // Tensor to store the mask at::Tensor neighbor_species = all_species.size(0) * torch::ones({n_atoms, max_size}, options_int); // Temporary array to track the current population index - int_t* current_index = new int_t[n_atoms]; - std::fill(current_index, current_index + n_atoms, 0); // Fill the array with zeros + int_t *current_index = new int_t[n_atoms]; + std::fill(current_index, current_index + n_atoms, 0); // Fill the array with zeros // Get raw data pointers - int_t* i_list_ptr = i_list.data_ptr(); - int_t* j_list_ptr = j_list.data_ptr(); - int_t* S_list_ptr = S_list.data_ptr(); - float_t* D_list_ptr = D_list.data_ptr(); - int_t* species_ptr = species.data_ptr(); - int_t* all_species_ptr = all_species.data_ptr(); - - int_t* neighbors_index_ptr = neighbors_index.data_ptr(); - int_t* neighbors_shift_ptr = neighbors_shift.data_ptr(); - float_t* relative_positions_ptr = relative_positions.data_ptr(); - int_t* nums_ptr = nums.data_ptr(); - bool* mask_ptr = mask.data_ptr(); - int_t* neighbor_species_ptr = neighbor_species.data_ptr(); - + int_t *i_list_ptr = i_list.data_ptr(); + int_t *j_list_ptr = j_list.data_ptr(); + int_t *S_list_ptr = S_list.data_ptr(); + float_t *D_list_ptr = D_list.data_ptr(); + int_t *species_ptr = species.data_ptr(); + int_t *all_species_ptr = 
all_species.data_ptr(); + + int_t *neighbors_index_ptr = neighbors_index.data_ptr(); + int_t *neighbors_shift_ptr = neighbors_shift.data_ptr(); + float_t *relative_positions_ptr = relative_positions.data_ptr(); + int_t *nums_ptr = nums.data_ptr(); + bool *mask_ptr = mask.data_ptr(); + int_t *neighbor_species_ptr = neighbor_species.data_ptr(); + int64_t all_species_size = all_species.size(0); - + int_t all_species_maximum = -1; - for (int64_t k = 0; k < all_species_size; ++k) { - if (all_species_ptr[k] > all_species_maximum) { + for (int64_t k = 0; k < all_species_size; ++k) + { + if (all_species_ptr[k] > all_species_maximum) + { all_species_maximum = all_species_ptr[k]; } } - - int_t* mapping = new int_t[all_species_maximum + 1]; - for (int64_t k = 0; k < all_species_size; ++k) { + + int_t *mapping = new int_t[all_species_maximum + 1]; + for (int64_t k = 0; k < all_species_size; ++k) + { mapping[all_species_ptr[k]] = k; } - - - - // Populate the neighbors_index, neighbors_shift, relative_positions, neighbor_species, and neighbor_scalar_attributes tensors - + + // Populate the neighbors_index, neighbors_shift, relative_positions, + // neighbor_species, and neighbor_scalar_attributes tensors + int64_t shift_i; int_t i, j, idx; - for (int64_t k = 0; k < i_list.size(0); ++k) { + for (int64_t k = 0; k < i_list.size(0); ++k) + { i = i_list_ptr[k]; j = j_list_ptr[k]; idx = current_index[i]; - + shift_i = i * max_size; - if (idx < max_size) { + if (idx < max_size) + { neighbors_index_ptr[shift_i + idx] = j; neighbor_species_ptr[shift_i + idx] = mapping[species_ptr[j]]; - /*for (int64_t q = 0; q < all_species_size; ++q) { - if (all_species_ptr[q] == species_ptr[j]) { - neighbor_species_ptr[i * max_size + idx] = q; - break; - } - }*/ - + // Unroll the loop for better computational efficiency neighbors_shift_ptr[(shift_i + idx) * 3 + 0] = S_list_ptr[k * 3 + 0]; neighbors_shift_ptr[(shift_i + idx) * 3 + 1] = S_list_ptr[k * 3 + 1]; @@ -113,24 +111,28 @@ std::vector 
process_neighbors_cpu(at::Tensor i_list, at::Tensor j_li } // Copy current_index to nums - for (int64_t i = 0; i < n_atoms; ++i) { + for (int64_t i = 0; i < n_atoms; ++i) + { nums_ptr[i] = current_index[i]; } - + at::Tensor neighbors_pos = torch::zeros({n_atoms, max_size}, options_int); - int_t* neighbors_pos_ptr = neighbors_pos.data_ptr(); + int_t *neighbors_pos_ptr = neighbors_pos.data_ptr(); // Temporary array to track the current population index - int_t* current_index_two = new int_t[n_atoms]; - std::fill(current_index_two, current_index_two + n_atoms, 0); // Fill the array with zeros - + int_t *current_index_two = new int_t[n_atoms]; + std::fill(current_index_two, current_index_two + n_atoms, 0); // Fill the array with zeros + int64_t shift_j; - for (int64_t k = 0; k < i_list.size(0); ++k) { + for (int64_t k = 0; k < i_list.size(0); ++k) + { i = i_list_ptr[k]; j = j_list_ptr[k]; shift_j = j * max_size; - for (int64_t q = 0; q < current_index[j]; ++q) { - if (neighbors_index_ptr[shift_j + q] == i && neighbors_shift_ptr[(shift_j + q) * 3 + 0] == -S_list_ptr[k * 3 + 0] && neighbors_shift_ptr[(shift_j + q) * 3 + 1] == -S_list_ptr[k * 3 + 1] && neighbors_shift_ptr[(shift_j + q) * 3 + 2] == -S_list_ptr[k * 3 + 2]) { + for (int64_t q = 0; q < current_index[j]; ++q) + { + if (neighbors_index_ptr[shift_j + q] == i && neighbors_shift_ptr[(shift_j + q) * 3 + 0] == -S_list_ptr[k * 3 + 0] && neighbors_shift_ptr[(shift_j + q) * 3 + 1] == -S_list_ptr[k * 3 + 1] && neighbors_shift_ptr[(shift_j + q) * 3 + 2] == -S_list_ptr[k * 3 + 2]) + { neighbors_pos_ptr[i * max_size + current_index_two[i]] = q; current_index_two[i]++; break; @@ -141,53 +143,44 @@ std::vector process_neighbors_cpu(at::Tensor i_list, at::Tensor j_li // Clean up temporary memory delete[] current_index; delete[] current_index_two; - + at::Tensor species_mapped = torch::zeros({n_atoms}, options_int); - int_t* species_mapped_ptr = species_mapped.data_ptr(); - for (int64_t k = 0; k < n_atoms; ++k) { + int_t 
*species_mapped_ptr = species_mapped.data_ptr(); + for (int64_t k = 0; k < n_atoms; ++k) + { species_mapped_ptr[k] = mapping[species_ptr[k]]; } - - /*for (int64_t k = 0; k < n_atoms; ++k) { - for (int64_t q = 0; q < all_species_size; ++q) { - if (all_species_ptr[q] == species_ptr[k]) { - species_mapped_ptr[k] = q; - break; - } - } - }*/ - - delete[] mapping; - - return {neighbors_index, relative_positions, nums, mask, neighbor_species, neighbors_pos, species_mapped}; + + delete[] mapping; + + return {neighbors_index, relative_positions, nums, mask, neighbor_species, neighbors_pos, species_mapped}; } // Template function for backward pass template -at::Tensor process_neighbors_cpu_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) { +at::Tensor process_neighbors_cpu_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) +{ // Ensure the tensors are on the CPU and are contiguous TORCH_CHECK(grad_output.device().is_cpu(), "grad_output must be on CPU"); TORCH_CHECK(i_list.device().is_cpu(), "i_list must be on CPU"); grad_output = grad_output.contiguous(); i_list = i_list.contiguous(); - - // TORCH_CHECK(grad_output.is_contiguous(), "grad_output must be contiguous"); - // TORCH_CHECK(i_list.is_contiguous(), "i_list must be contiguous"); // Initialize gradient tensor for D_list with zeros auto options_float = torch::TensorOptions().dtype(grad_output.dtype()).device(torch::kCPU); at::Tensor grad_D_list = torch::zeros({i_list.size(0), 3}, options_float); - int_t* current_index = new int_t[n_atoms]; - std::fill(current_index, current_index + n_atoms, 0); // Fill the array with zeros + int_t *current_index = new int_t[n_atoms]; + std::fill(current_index, current_index + n_atoms, 0); // Fill the array with zeros - float_t* grad_D_list_ptr = grad_D_list.data_ptr(); - float_t* grad_output_ptr = grad_output.data_ptr(); - int_t* i_list_ptr = i_list.data_ptr(); + float_t *grad_D_list_ptr = grad_D_list.data_ptr(); 
+ float_t *grad_output_ptr = grad_output.data_ptr(); + int_t *i_list_ptr = i_list.data_ptr(); int_t i, idx; - for (int64_t k = 0; k < i_list.size(0); ++k) { + for (int64_t k = 0; k < i_list.size(0); ++k) + { i = i_list_ptr[k]; idx = current_index[i]; grad_D_list_ptr[k * 3 + 0] = grad_output_ptr[(i * max_size + idx) * 3 + 0]; @@ -201,7 +194,8 @@ at::Tensor process_neighbors_cpu_backward(at::Tensor grad_output, at::Tensor i_l } template -at::Tensor process_neighbors_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) { +at::Tensor process_neighbors_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) +{ // Ensure all tensors are on the same device auto device = grad_output.device(); TORCH_CHECK(i_list.device() == device, "i_list must be on the same device as grad_output"); @@ -218,24 +212,35 @@ at::Tensor process_neighbors_backward(at::Tensor grad_output, at::Tensor i_list, } // Dispatch function based on tensor types for backward -at::Tensor process_dispatch_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) { - if (i_list.scalar_type() == at::ScalarType::Int && grad_output.scalar_type() == at::ScalarType::Float) { +at::Tensor process_dispatch_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) +{ + if (i_list.scalar_type() == at::ScalarType::Int && grad_output.scalar_type() == at::ScalarType::Float) + { return process_neighbors_backward(grad_output, i_list, max_size, n_atoms); - } else if (i_list.scalar_type() == at::ScalarType::Int && grad_output.scalar_type() == at::ScalarType::Double) { + } + else if (i_list.scalar_type() == at::ScalarType::Int && grad_output.scalar_type() == at::ScalarType::Double) + { return process_neighbors_backward(grad_output, i_list, max_size, n_atoms); - } else if (i_list.scalar_type() == at::ScalarType::Long && grad_output.scalar_type() == at::ScalarType::Float) { + } + else if (i_list.scalar_type() 
== at::ScalarType::Long && grad_output.scalar_type() == at::ScalarType::Float) + { return process_neighbors_backward(grad_output, i_list, max_size, n_atoms); - } else if (i_list.scalar_type() == at::ScalarType::Long && grad_output.scalar_type() == at::ScalarType::Double) { + } + else if (i_list.scalar_type() == at::ScalarType::Long && grad_output.scalar_type() == at::ScalarType::Double) + { return process_neighbors_backward(grad_output, i_list, max_size, n_atoms); - } else { + } + else + { throw std::runtime_error("Unsupported tensor types"); } } template -std::vector process_neighbors(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, - int64_t max_size, int64_t n_atoms, at::Tensor species, - at::Tensor all_species) { +std::vector process_neighbors(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, + int64_t max_size, int64_t n_atoms, at::Tensor species, + at::Tensor all_species) +{ // Ensure all tensors are on the same device auto device = i_list.device(); TORCH_CHECK(j_list.device() == device, "j_list must be on the same device as i_list"); @@ -243,7 +248,7 @@ std::vector process_neighbors(at::Tensor i_list, at::Tensor j_list, TORCH_CHECK(D_list.device() == device, "D_list must be on the same device as i_list"); TORCH_CHECK(species.device() == device, "species must be on the same device as i_list"); TORCH_CHECK(all_species.device() == device, "all_species must be on the same device as i_list"); - + // Move all tensors to CPU auto i_list_cpu = i_list.cpu(); auto j_list_cpu = j_list.cpu(); @@ -256,7 +261,8 @@ std::vector process_neighbors(at::Tensor i_list, at::Tensor j_list, auto result = process_neighbors_cpu(i_list_cpu, j_list_cpu, S_list_cpu, D_list_cpu, max_size, n_atoms, species_cpu, all_species_cpu); // Move the output tensors back to the initial device - for (auto& tensor_opt : result) { + for (auto &tensor_opt : result) + { tensor_opt = tensor_opt.to(device); } @@ -264,31 +270,43 @@ std::vector 
process_neighbors(at::Tensor i_list, at::Tensor j_list, } // Dispatch function based on tensor types -std::vector process_dispatch(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, - int64_t max_size, int64_t n_atoms, at::Tensor species, - at::Tensor all_species) { +std::vector process_dispatch(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, + int64_t max_size, int64_t n_atoms, at::Tensor species, + at::Tensor all_species) +{ if (i_list.scalar_type() == at::ScalarType::Int && j_list.scalar_type() == at::ScalarType::Int && - S_list.scalar_type() == at::ScalarType::Int && D_list.scalar_type() == at::ScalarType::Float) { + S_list.scalar_type() == at::ScalarType::Int && D_list.scalar_type() == at::ScalarType::Float) + { return process_neighbors(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); - } else if (i_list.scalar_type() == at::ScalarType::Int && j_list.scalar_type() == at::ScalarType::Int && - S_list.scalar_type() == at::ScalarType::Int && D_list.scalar_type() == at::ScalarType::Double) { + } + else if (i_list.scalar_type() == at::ScalarType::Int && j_list.scalar_type() == at::ScalarType::Int && + S_list.scalar_type() == at::ScalarType::Int && D_list.scalar_type() == at::ScalarType::Double) + { return process_neighbors(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); - } else if (i_list.scalar_type() == at::ScalarType::Long && j_list.scalar_type() == at::ScalarType::Long && - S_list.scalar_type() == at::ScalarType::Long && D_list.scalar_type() == at::ScalarType::Float) { + } + else if (i_list.scalar_type() == at::ScalarType::Long && j_list.scalar_type() == at::ScalarType::Long && + S_list.scalar_type() == at::ScalarType::Long && D_list.scalar_type() == at::ScalarType::Float) + { return process_neighbors(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); - } else if (i_list.scalar_type() == at::ScalarType::Long && j_list.scalar_type() == 
at::ScalarType::Long && - S_list.scalar_type() == at::ScalarType::Long && D_list.scalar_type() == at::ScalarType::Double) { + } + else if (i_list.scalar_type() == at::ScalarType::Long && j_list.scalar_type() == at::ScalarType::Long && + S_list.scalar_type() == at::ScalarType::Long && D_list.scalar_type() == at::ScalarType::Double) + { return process_neighbors(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); - } else { + } + else + { throw std::runtime_error("Unsupported tensor types"); } } -class ProcessNeighborsFunction : public torch::autograd::Function { +class ProcessNeighborsFunction : public torch::autograd::Function +{ public: - static std::vector forward(torch::autograd::AutogradContext *ctx, at::Tensor i_list, at::Tensor j_list, - at::Tensor S_list, at::Tensor D_list, int64_t max_size, int64_t n_atoms, - at::Tensor species, at::Tensor all_species) { + static std::vector forward(torch::autograd::AutogradContext *ctx, at::Tensor i_list, at::Tensor j_list, + at::Tensor S_list, at::Tensor D_list, int64_t max_size, int64_t n_atoms, + at::Tensor species, at::Tensor all_species) + { auto outputs = process_dispatch(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); ctx->save_for_backward({i_list}); ctx->saved_data["max_size"] = max_size; @@ -296,12 +314,13 @@ class ProcessNeighborsFunction : public torch::autograd::Function backward(torch::autograd::AutogradContext *ctx, std::vector grad_outputs) { + static std::vector backward(torch::autograd::AutogradContext *ctx, std::vector grad_outputs) + { auto i_list = ctx->get_saved_variables()[0]; auto max_size = ctx->saved_data["max_size"].toInt(); auto n_atoms = ctx->saved_data["n_atoms"].toInt(); - auto grad_relative_positions = grad_outputs[1]; // Assuming this is the gradient w.r.t relative_positions tensor + auto grad_relative_positions = grad_outputs[1]; // Assuming this is the gradient w.r.t relative_positions tensor auto grad_D_list = 
process_dispatch_backward(grad_relative_positions, i_list, max_size, n_atoms); return {at::Tensor(), at::Tensor(), at::Tensor(), grad_D_list, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()}; @@ -309,36 +328,15 @@ class ProcessNeighborsFunction : public torch::autograd::Function process_neighbors_apply(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, - int64_t max_size, int64_t n_atoms, at::Tensor species, at::Tensor all_species) { +std::vector process_neighbors_apply(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, + int64_t max_size, int64_t n_atoms, at::Tensor species, at::Tensor all_species) +{ return ProcessNeighborsFunction::apply(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); } -/*TORCH_LIBRARY(neighbors_convert, m) { - m.def( - "convert_neighbors(Tensor i_list, Tensor j_list, Tensor S_list, Tensor D_list, int max_size, int n_atoms, Tensor species, Tensor all_species) -> Tensor[]", - &process_neighbors_apply - ); -}*/ - -TORCH_LIBRARY(neighbors_convert, m) { +TORCH_LIBRARY(neighbors_convert, m) +{ m.def( "process", - &process_neighbors_apply - ); + &process_neighbors_apply); } - -// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { -// m.def("process_neighbors(Tensor i_list, Tensor j_list, Tensor S_list, Tensor D_list, int max_size, int n_atoms, Tensor species, Tensor all_species) -> Tensor[]", &process_neighbors_apply, "Process neighbors and return tensors, including count tensor, mask, and neighbor_species"); -// } - -/*static auto registry = torch::RegisterOperators() - .op("neighbors_convert::process(Tensor i_list, Tensor j_list, Tensor S_list, Tensor D_list, int max_size, int n_atoms, Tensor species, Tensor all_species) -> Tensor[]", &process_neighbors_apply);*/ - -/*PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("process_neighbors", &process_neighbors_apply, "Process neighbors and return tensors, including count tensor, mask, and neighbor_species");*/ - -/*static auto 
registry = torch::RegisterOperators() - .op("neighbors_convert::process(Tensor i_list, Tensor j_list, Tensor S_list, Tensor D_list, int max_size, int n_atoms, Tensor species, Tensor all_species) -> Tensor[]", &process_neighbors_apply); - -*/ \ No newline at end of file From 11afef1e1e77a516921547ed720a86363c516bb9 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Sat, 15 Feb 2025 14:11:58 +0100 Subject: [PATCH 3/4] fix pypa action --- .github/workflows/build-wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml index b983e62..c0ecc55 100644 --- a/.github/workflows/build-wheels.yml +++ b/.github/workflows/build-wheels.yml @@ -219,7 +219,7 @@ jobs: - name: Publish distribution to PyPI if: startsWith(github.ref, 'refs/tags/v') - uses: pypa/s@release/v1 + uses: pypa/gh-action-pypi-publish@release/v1 with: packages-dir: wheels From 19890457080faf2251815c798c87e41f11da047c Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Sat, 15 Feb 2025 14:14:57 +0100 Subject: [PATCH 4/4] update classifiers --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 92ebd32..5863645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ backend-path = ["build-backend"] [project] authors = [{name = "lab-cosmo developers"}] classifiers = [ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Science/Research", "License :: OSI Approved :: BSD License", "Operating System :: POSIX",