diff --git a/.gitignore b/.gitignore index 750baeb..7b3ce21 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ result result-* +.ipynb_checkpoints/ diff --git a/flake.lock b/flake.lock index 3f4f54a..c327708 100644 --- a/flake.lock +++ b/flake.lock @@ -2,7 +2,9 @@ "nodes": { "flake-parts": { "inputs": { - "nixpkgs-lib": "nixpkgs-lib" + "nixpkgs-lib": [ + "nixpkgs" + ] }, "locked": { "lastModified": 1714641030, @@ -18,38 +20,175 @@ "type": "github" } }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "genjax": { + "flake": false, + "locked": { + "lastModified": 1706565447, + "narHash": "sha256-dMTB2YPnmboU+ZFpNF3+ZrcK4uH2ZjQ4hQKzXu49sjc=", + "owner": "probcomp", + "repo": "genjax", + "rev": "3357b75b7ae64121b2848254e11c5b79ee7f1820", + "type": "github" + }, + "original": { + "owner": "probcomp", + "ref": "v0.1.1", + "repo": "genjax", + "type": "github" + } + }, + "nix-github-actions": { + "inputs": { + "nixpkgs": [ + "poetry2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1703863825, + "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=", + "owner": "nix-community", + "repo": "nix-github-actions", + "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nix-github-actions", + "type": "github" + } + }, "nixpkgs": { "locked": { - "lastModified": 1714253743, - "narHash": "sha256-mdTQw2XlariysyScCv2tTE45QSU9v/ezLcHJ22f0Nxc=", + "lastModified": 1720028458, + "narHash": "sha256-DuQi7Eaa7hfpl5WufMiWDq/a4p5qngRtwhXNGHmdv4Y=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "d8724afca4565614164dd81345f6137c4c6eab21", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "d8724afca4565614164dd81345f6137c4c6eab21", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-llvm-10": { + "locked": { + "lastModified": 1706589919, + "narHash": "sha256-pNHnDITxSI3a17GOF1RUF3jBO1OiNYTRH2yV/cJG4m4=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "58a1abdbae3217ca6b702f03d3b35125d88a2994", + "rev": "222c1940fafeda4dea161858ffe6ebfc853d3db5", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-unstable", "repo": "nixpkgs", + "rev": "222c1940fafeda4dea161858ffe6ebfc853d3db5", "type": "github" } }, - "nixpkgs-lib": { + "poetry2nix": { + "inputs": { + "flake-utils": "flake-utils", + "nix-github-actions": "nix-github-actions", + "nixpkgs": [ + "nixpkgs" + ], + "systems": "systems_2", + "treefmt-nix": "treefmt-nix" + }, "locked": { - "lastModified": 1714640452, - "narHash": "sha256-QBx10+k6JWz6u7VsohfSw8g8hjdBZEf8CFzXH1/1Z94=", - "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/50eb7ecf4cd0a5756d7275c8ba36790e5bd53e33.tar.gz" + "lastModified": 1718726452, + "narHash": "sha256-w4hJSYvACz0i5XHtxc6XNyHwbxpisN13M2kA2Y7937o=", + "owner": "nix-community", + "repo": "poetry2nix", + "rev": "53e534a08c0cd2a9fa7587ed1c3e7f6aeb804a2c", + "type": "github" }, "original": { - "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/50eb7ecf4cd0a5756d7275c8ba36790e5bd53e33.tar.gz" + "owner": "nix-community", + "repo": "poetry2nix", + "type": "github" } }, "root": { "inputs": { "flake-parts": "flake-parts", - "nixpkgs": "nixpkgs" + "genjax": "genjax", + "nixpkgs": "nixpkgs", + "nixpkgs-llvm-10": "nixpkgs-llvm-10", + "poetry2nix": "poetry2nix" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "id": "systems", + "type": "indirect" + } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "poetry2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1718522839, + "narHash": "sha256-ULzoKzEaBOiLRtjeY3YoGFJMwWSKRYOic6VNw2UyTls=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "68eb1dc333ce82d0ab0c0357363ea17c31ea1f81", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" } } }, diff --git a/flake.nix b/flake.nix index 27b865f..7cc2073 100644 --- a/flake.nix +++ b/flake.nix @@ -3,9 +3,22 @@ inputs = { flake-parts.url = "github:hercules-ci/flake-parts"; - nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-parts.inputs.nixpkgs-lib.follows = "nixpkgs"; + + nixpkgs.url = "github:NixOS/nixpkgs?ref=d8724afca4565614164dd81345f6137c4c6eab21"; + nixpkgs-llvm-10.url = "github:NixOS/nixpkgs?rev=222c1940fafeda4dea161858ffe6ebfc853d3db5"; + + genjax.url = "github:probcomp/genjax?ref=v0.1.1"; + genjax.flake = false; + + poetry2nix.url = "github:nix-community/poetry2nix"; + poetry2nix.inputs.nixpkgs.follows = "nixpkgs"; }; + nixConfig.extra-substituters = [ "https://numtide.cachix.org" ]; + nixConfig.extra-trusted-public-keys = [ "numtide.cachix.org-1:2ps1kLBUWjxIneOy1Ik6cQjb41X0iXVXeHigGmycPPE=" ]; + nixConfig.sandbox = "relaxed"; + outputs = inputs@{ self, nixpkgs, flake-parts, ... }: flake-parts.lib.mkFlake { inherit inputs; } { imports = [ @@ -30,23 +43,121 @@ inherit nixpkgs; basicTools = self.lib.basicTools; }; + + poetry2nix = inputs.poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }; scopes = (self.lib.mkScopes { - inherit pkgs; + inherit pkgs internalPackages inputs poetry2nix; basicTools = self.lib.basicTools; }); loom = scopes.callPy3Package ./pkgs/loom { }; - packages = loom.more_packages // { + # TODO: make this cleaner + bayes3d = self'.legacyPackages.python3Packages.bayes3d; + open3d = scopes.callPy3Package ./pkgs/python-modules/open3d { }; + + internalPackages = { + #jaxlib = scopes.callPy3Package ./pkgs/jaxlib { }; + #jax = scopes.callPy3Package ./pkgs/jax { }; + jaxtyping = scopes.callPy3Package ./pkgs/jaxtyping { }; + tinygltf = scopes.callPackage ./pkgs/tinygltf { }; + PoissonRecon = scopes.callPackage ./pkgs/PoissonRecon { }; + goftests = scopes.callPackage ./pkgs/goftests { }; + parsable = scopes.callPackage ./pkgs/parsable { }; + pymetis = scopes.callPackage ./pkgs/pymetis { }; + distributions = scopes.callPackage ./pkgs/distributions { }; + genjax = scopes.callPy3Package ./pkgs/genjax { }; + distinctipy = scopes.callPy3Package ./pkgs/distinctipy { }; + pyransac3d = scopes.callPy3Package ./pkgs/pyransac3d { }; + opencv-python = scopes.callPy3Package ./pkgs/opencv-python { }; + oryx = scopes.callPy3Package ./pkgs/oryx { }; + plum-dispatch = scopes.callPy3Package ./pkgs/plum-dispatch { }; + } // packages; + + packages = { inherit loom sppl ociImgBase - ; + + bayes3d + open3d + ; + }; + + loadPackages = callPackage: path: + let + entries = builtins.readDir path; + in + pkgs.lib.mapAttrs (name: type: + if type != "directory" then (throw "${toString path}/${name} is not a directory") + else + callPackage "${toString path}/${name}" { } + ) + entries; + + # For fixing existing packages that live in nixpkgs + # TODO: put in separate file + pythonOverrides = super: pythonSuper: { + # so we can pull from flake inputs + inherit inputs; + + # FIXME: I don't think this is working as expected. Better to change nixpkgs wthfor now. + + # Use the pre-built version of tensorflow + tensorflow = pythonSuper.tensorflow-bin; + + # Use the pre-built version of jaxlib + jaxlib = super.jaxlib-bin; + + # Use the pre-built version of libjax + libjax = super.libjax-bin; }; in { + _module.args.pkgs = import inputs.nixpkgs { + inherit system; + config = { + # FIXME: commenting these out to see if they fix the duplicate dependency issue when building bayes3d + allowUnfree = true; + # Only enable CUDA on Linux + cudaSupport = (system == "x86_64-linux" || system == "aarch64-linux"); + }; + overlays = [ + (final: prev: { + # FIXME: say why this was added. + inherit (inputs.nixpkgs-llvm-10.legacyPackages.${system}) llvmPackages_10; + }) + ]; + }; + inherit packages; + + legacyPackages.python3Packages = + (pkgs.python311Packages.overrideScope pythonOverrides).overrideScope (super: superPython: + loadPackages super.callPackage ./pkgs/python-modules + ); + + devShells.default = pkgs.mkShell { + packages = [ + self'.legacyPackages.python3Packages.python-lsp-server + (self'.legacyPackages.python3Packages.python.withPackages (p: [ + self'.legacyPackages.python3Packages.bayes3d + self'.legacyPackages.python3Packages.jax + p.jupyter + p.scipy + ])) + ]; + + shellHook = '' + export EXTRA_LDFLAGS="-L/lib -L${pkgs.linuxPackages.nvidia_x11}/lib" + export EXTRA_CCFLAGS="-I/usr/include" + export CUDA_PATH=${pkgs.cudatoolkit_11} + export B3D_ASSET_PATH="${bayes3d.src}/assets" + + jupyter notebook + ''; + }; }; # NOTE: this property is consumed by flake-parts.mkFlake to define fields diff --git a/lib/mkScopes/default.nix b/lib/mkScopes/default.nix index 3218d45..b7488e3 100644 --- a/lib/mkScopes/default.nix +++ b/lib/mkScopes/default.nix @@ -1,19 +1,37 @@ -{ pkgs, basicTools }: let +{ pkgs, basicTools, internalPackages, poetry2nix, inputs }: +let callPackage = pkgs.newScope ( pkgs // { - inherit callPackage; + inherit + callPackage + callPy3Package + inputs + poetry2nix + ; basicTools = basicTools pkgs; } + // internalPackages ); + callPy3Package = pkgs.newScope ( pkgs // pkgs.python3Packages // { - inherit callPackage; + inherit + callPackage + callPy3Package + inputs + poetry2nix + ; basicTools = basicTools pkgs; } + // internalPackages ); -in { - inherit callPackage callPy3Package ; +in +{ + inherit + callPackage + callPy3Package + ; } diff --git a/notebooks/assets/.gitkeep b/notebooks/assets/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb new file mode 100644 index 0000000..43d207a --- /dev/null +++ b/notebooks/demo.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "id": "fc70511d-4bf5-47cb-9935-c17b4da6cf11", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/nix-shell.EBRV2Z/ipykernel_252673/268569837.py:7: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n", + " from IPython.core.display import Image, display\n", + "[E rasterize_gl.cpp:121] OpenGL version reported as 4.6\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)\n", + "Centering mesh with translation [0.0000000e+00 2.9802322e-08 0.0000000e+00]\n", + "Number of frames: 60\n", + "observed_images.shape (60, 100, 100, 4)\n", + "Time elapsed: 0.3095111846923828\n", + "FPS: 193.85406074947772\n" + ] + } + ], + "source": [ + "import os\n", + "import time\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from IPython import embed\n", + "from IPython.core.display import Image, display\n", + "from scipy.spatial.transform import Rotation as R\n", + "\n", + "import bayes3d as b\n", + "\n", + "# Can be helpful for debugging:\n", + "# jax.config.update('jax_enable_checks', True)\n", + "\n", + "assets_dir = os.getenv(\"B3D_ASSET_PATH\")\n", + "\n", + "intrinsics = b.Intrinsics(\n", + " height=100, width=100, fx=50.0, fy=50.0, cx=50.0, cy=50.0, near=0.001, far=6.0\n", + ")\n", + "\n", + "b.setup_renderer(intrinsics)\n", + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(assets_dir, \"sample_objs/bunny.obj\")\n", + ")\n", + "\n", + "num_frames = 60\n", + "\n", + "poses = [b.t3d.transform_from_pos(jnp.array([-3.0, 0.0, 3.5]))]\n", + "delta_pose = b.t3d.transform_from_rot_and_pos(\n", + " R.from_euler(\"zyx\", [-1.0, 0.1, 2.0], degrees=True).as_matrix(),\n", + " jnp.array([0.09, 0.05, 0.02]),\n", + ")\n", + "for t in range(num_frames - 1):\n", + " poses.append(poses[-1].dot(delta_pose))\n", + "poses = jnp.stack(poses)\n", + "print(\"Number of frames: \", poses.shape[0])\n", + "\n", + "observed_images = b.RENDERER.render_many(poses[:, None, ...], jnp.array([0]))\n", + "print(\"observed_images.shape\", observed_images.shape)\n", + "\n", + "translation_deltas = b.utils.make_translation_grid_enumeration(\n", + " -0.2, -0.2, -0.2, 0.2, 0.2, 0.2, 5, 5, 5\n", + ")\n", + "rotation_deltas = jax.vmap(\n", + " lambda key: b.distributions.gaussian_vmf_zero_mean(key, 0.00001, 800.0)\n", + ")(jax.random.split(jax.random.PRNGKey(30), 100))\n", + "\n", + "likelihood = jax.vmap(\n", + " b.threedp3_likelihood_old, in_axes=(None, 0, None, None, None, None, None)\n", + ")\n", + "\n", + "\n", + "def update_pose_estimate(pose_estimate, gt_image):\n", + " proposals = jnp.einsum(\"ij,ajk->aik\", pose_estimate, translation_deltas)\n", + " rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(\n", + " proposals[:, None, ...], jnp.array([0])\n", + " )\n", + " weights_new = likelihood(gt_image, rendered_images, 0.05, 0.1, 10**3, 0.1, 3)\n", + " pose_estimate = proposals[jnp.argmax(weights_new)]\n", + "\n", + " proposals = jnp.einsum(\"ij,ajk->aik\", pose_estimate, rotation_deltas)\n", + " rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(\n", + " proposals[:, None, ...], jnp.array([0])\n", + " )\n", + " weights_new = likelihood(gt_image, rendered_images, 0.05, 0.1, 10**3, 0.1, 3)\n", + " pose_estimate = proposals[jnp.argmax(weights_new)]\n", + " return pose_estimate, pose_estimate\n", + "\n", + "\n", + "inference_program = jax.jit(lambda p, x: jax.lax.scan(update_pose_estimate, p, x)[1])\n", + "inferred_poses = inference_program(poses[0], observed_images)\n", + "\n", + "start = time.time()\n", + "pose_estimates_over_time = inference_program(poses[0], observed_images)\n", + "end = time.time()\n", + "print(\"Time elapsed:\", end - start)\n", + "print(\"FPS:\", poses.shape[0] / (end - start))\n", + "\n", + "rerendered_images = b.RENDERER.render_many(\n", + " pose_estimates_over_time[:, None, ...], jnp.array([0])\n", + ")\n", + "\n", + "viz_images = [\n", + " b.viz.multi_panel(\n", + " [\n", + " b.viz.scale_image(b.viz.get_depth_image(d[:, :, 2]), 3),\n", + " b.viz.scale_image(b.viz.get_depth_image(r[:, :, 2]), 3),\n", + " ],\n", + " labels=[\"Observed\", \"Rerendered\"],\n", + " label_fontsize=20,\n", + " )\n", + " for (r, d) in zip(rerendered_images, observed_images)\n", + "]\n", + "b.make_gif_from_pil_images(viz_images, \"assets/demo.gif\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c39a88e9-b18d-4d91-b001-f2a3abdd6804", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "expected str, bytes or os.PathLike object, not list", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m display(\u001b[43mImage\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mviz_images\u001b[49m\u001b[43m)\u001b[49m)\n", + "File \u001b[0;32m/nix/store/6514avl2wlhjgdhvvw80f3qm8ps8kv5b-python3-3.11.9-env/lib/python3.11/site-packages/IPython/core/display.py:923\u001b[0m, in \u001b[0;36mImage.__init__\u001b[0;34m(self, data, url, filename, format, embed, width, height, retina, unconfined, metadata, alt)\u001b[0m\n\u001b[1;32m 921\u001b[0m ext \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_find_ext(filename)\n\u001b[1;32m 922\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m url \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 923\u001b[0m ext \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_find_ext\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 924\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m data \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 925\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo image data found. Expecting filename, url, or data.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m/nix/store/6514avl2wlhjgdhvvw80f3qm8ps8kv5b-python3-3.11.9-env/lib/python3.11/site-packages/IPython/core/display.py:1074\u001b[0m, in \u001b[0;36mImage._find_ext\u001b[0;34m(self, s)\u001b[0m\n\u001b[1;32m 1073\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_find_ext\u001b[39m(\u001b[38;5;28mself\u001b[39m, s):\n\u001b[0;32m-> 1074\u001b[0m base, ext \u001b[38;5;241m=\u001b[39m \u001b[43msplitext\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1076\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m ext:\n\u001b[1;32m 1077\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m base\n", + "File \u001b[0;32m:118\u001b[0m, in \u001b[0;36msplitext\u001b[0;34m(p)\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: expected str, bytes or os.PathLike object, not list" + ] + } + ], + "source": [ + "display(Image(url=viz_images))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a450ad59-1f75-4e1d-b150-fa62e47ba345", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pkgs/disabled-modules/beartype_0_16_4/default.nix b/pkgs/disabled-modules/beartype_0_16_4/default.nix new file mode 100644 index 0000000..3765da2 --- /dev/null +++ b/pkgs/disabled-modules/beartype_0_16_4/default.nix @@ -0,0 +1,12 @@ +{ fetchPypi +, beartype +}: +(beartype.overrideAttrs (final: prev: rec { + version = "0.16.4"; + + src = fetchPypi { + inherit (prev) pname; + inherit version; + hash = "sha256-GtqJzy1usw624Vbu0utUkzV3gpN5ENdDgJGOU8Lq4L8="; + }; +})) diff --git a/pkgs/disabled-modules/jax_/default.nix b/pkgs/disabled-modules/jax_/default.nix new file mode 100644 index 0000000..e84b240 --- /dev/null +++ b/pkgs/disabled-modules/jax_/default.nix @@ -0,0 +1,147 @@ +{ + lib, + blas, + buildPythonPackage, + callPackage, + setuptools, + importlib-metadata, + fetchFromGitHub, + jaxlib-bin, + hypothesis, + lapack, + matplotlib, + ml-dtypes, + numpy, + opt-einsum, + pytestCheckHook, + pytest-xdist, + pythonOlder, + scipy, + stdenv, +}: + +let + usingMKL = false; # blas.implementation == "mkl" || lapack.implementation == "mkl"; +in +buildPythonPackage rec { + pname = "jax"; + version = "0.4.28"; + pyproject = true; + + disabled = pythonOlder "3.9"; + + src = fetchFromGitHub { + owner = "google"; + repo = "jax"; + # google/jax contains tags for jax and jaxlib. Only use jax tags! + rev = "refs/tags/jax-v${version}"; + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; + }; + + nativeBuildInputs = [ setuptools ]; + + # The version is automatically set to ".dev" if this variable is not set. + # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3 + JAX_RELEASE = "1"; + + # jaxlib is _not_ included in propagatedBuildInputs because there are + # different versions of jaxlib depending on the desired target hardware. The + # JAX project ships separate wheels for CPU, GPU, and TPU. + propagatedBuildInputs = [ + ml-dtypes + numpy + opt-einsum + scipy + ] ++ lib.optional (pythonOlder "3.10") importlib-metadata; + + nativeCheckInputs = [ + hypothesis + jaxlib-bin + matplotlib + pytestCheckHook + pytest-xdist + ]; + + # high parallelism will result in the tests getting stuck + dontUsePytestXdist = true; + + # NOTE: Don't run the tests in the expiremental directory as they require flax + # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. + # Not a big deal, this is how the JAX docs suggest running the test suite + # anyhow. + pytestFlagsArray = [ + "--numprocesses=4" + "-W ignore::DeprecationWarning" + "tests/" + ]; + + # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with + # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py' + # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241 + # NOTE: this doesn't seem to be an issue on linux + preCheck = lib.optionalString stdenv.isDarwin '' + export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d) + ''; + + # FIXME: disable the checks manually + doCheck = false; + + disabledTests = + [ + # Exceeds tolerance when the machine is busy + "test_custom_linear_solve_aux" + # UserWarning: Explicitly requested dtype + # requested in astype is not available, and will be truncated to + # dtype float32. (With numpy 1.24) + "testKde3" + "testKde5" + "testKde6" + # Invokes python manually in a subprocess, which does not have the correct dependencies + # ImportError: This version of jax requires jaxlib version >= 0.4.19. + "test_no_log_spam" + ] + ++ lib.optionals usingMKL [ + # See + # * https://github.com/google/jax/issues/9705 + # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921 + # * https://github.com/NixOS/nixpkgs/issues/161960 + "test_custom_linear_solve_cholesky" + "test_custom_root_with_aux" + "testEigvalsGrad_shape" + ] + ++ lib.optionals stdenv.isAarch64 [ + # See https://github.com/google/jax/issues/14793. + "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop" + "testQdwhWithRandomMatrix3" + "testScanGrad_jit_scan" + + # See https://github.com/google/jax/issues/17867. + "test_array" + "test_async" + "test_copy0" + "test_device_put" + "test_make_array_from_callback" + "test_make_array_from_single_device_arrays" + + # Fails on some hardware due to some numerical error + # See https://github.com/google/jax/issues/18535 + "testQdwhWithOnRankDeficientInput5" + ]; + + disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ + # RuntimeWarning: invalid value encountered in cast + "tests/lax_test.py" + ]; + + pythonImportsCheck = [ "jax" ]; + + # updater fails to pick the correct branch + passthru.skipBulkUpdate = true; + + meta = with lib; { + description = "Differentiate, compile, and transform Numpy code"; + homepage = "https://github.com/google/jax"; + license = licenses.asl20; + maintainers = with maintainers; [ samuela ]; + }; +} diff --git a/pkgs/disabled-modules/jaxlib_/default.nix b/pkgs/disabled-modules/jaxlib_/default.nix new file mode 100644 index 0000000..e1d365c --- /dev/null +++ b/pkgs/disabled-modules/jaxlib_/default.nix @@ -0,0 +1,156 @@ +{ + lib, + stdenv, + fetchurl, + + # Build-time dependencies: + buildPythonPackage, + curl, + fetchFromGitHub, + jsoncpp, + autoPatchelfHook, + + # Python dependencies: + absl-py, + flatbuffers, + ml-dtypes, + numpy, + scipy, + six, + + # Runtime dependencies: + double-conversion, + giflib, + libjpeg_turbo, + python, + snappy, + +}: + +let + pname = "jaxlib"; + version = "0.4.28"; + + # REMOVEME + effectiveStdenv = stdenv; + + meta = with lib; { + description = "JAX is Autograd and XLA, brought together for high-performance machine learning research"; + homepage = "https://github.com/google/jax"; + license = licenses.asl20; + maintainers = with maintainers; [ ndl ]; + platforms = platforms.unix; + # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136 + # however even with that fix applied, it doesn't work for everyone: + # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129 + # NOTE: We always build with NCCL; if it is unsupported, then our build is broken. + # broken = effectiveStdenv.isDarwin || nccl.meta.unsupported; + }; + + + arch = + # KeyError: ('Linux', 'arm64') + if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then + "aarch64" + else + effectiveStdenv.hostPlatform.linuxArch; + + xla = effectiveStdenv.mkDerivation { + pname = "xla-src"; + version = "unstable"; + + src = fetchFromGitHub { + owner = "openxla"; + repo = "xla"; + # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl. + rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4"; + hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E="; + }; + + dontBuild = true; + + # This is necessary for patchShebangs to know the right path to use. + nativeBuildInputs = [ python ]; + + # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl + postPatch = '' + patchShebangs . + ''; + + installPhase = '' + cp -r . $out + ''; + }; + + platformTag = + if effectiveStdenv.hostPlatform.isLinux then + "manylinux2014_${arch}" + else if effectiveStdenv.system == "x86_64-darwin" then + "macosx_10_9_${arch}" + else if effectiveStdenv.system == "aarch64-darwin" then + "macosx_11_0_${arch}" + else + throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}"; + + wheelUrls = { + "x86_64-linux" = { + url = "https://files.pythonhosted.org/packages/8e/d7/65b1f5cf05d9159abd5882a51695d4d1b386bc8e26140eff7159854777f2/jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl"; + hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU="; + }; + + "aarch64-linux" = { + url = "https://files.pythonhosted.org/packages/f2/87/0c07ec3ba047ca58c940d1c3050cd08c4390bca992cdfeeb2d9d356cd2c6/jaxlib-0.4.28-cp311-cp311-manylinux2014_aarch64.whl"; + hash = ""; + }; + + "x86_64-darwin" = { + url = "https://files.pythonhosted.org/packages/e0/b2/896d8d1f35e16e9f88ae6a753012e6d5a6882507ea58e7f0dd5af68ee1e8/jaxlib-0.4.28-cp311-cp311-macosx_10_14_x86_64.whl"; + hash = ""; + }; + + "aarch64-darwin" = { + url = "https://files.pythonhosted.org/packages/75/f3/1ce8b092ca68dfcfa6a0ee0a8a410f6d877e1628c05799c5d03757682c66/jaxlib-0.4.28-cp311-cp311-macosx_11_0_arm64.whl"; + hash = ""; + }; + }; +in +buildPythonPackage { + inherit meta pname version; + format = "wheel"; + + src = fetchurl ( + if builtins.hasAttr stdenv.system wheelUrls + then wheelUrls.${stdenv.system} + else throw "Unsupported system '${stdenv.system}'" + ); + + nativeBuildInputs = [ + autoPatchelfHook + ]; + + buildInputs = [ + stdenv.cc.cc.lib + ]; + + dependencies = [ + absl-py + curl + double-conversion + flatbuffers + giflib + jsoncpp + libjpeg_turbo + ml-dtypes + numpy + scipy + six + snappy + ]; + + pythonImportsCheck = [ + "jaxlib" + # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade. + "jaxlib.cpu_feature_guard" + "jaxlib.xla_client" + ]; +} diff --git a/pkgs/disabled-modules/jaxlib_/default.nix.bak b/pkgs/disabled-modules/jaxlib_/default.nix.bak new file mode 100644 index 0000000..325aadc --- /dev/null +++ b/pkgs/disabled-modules/jaxlib_/default.nix.bak @@ -0,0 +1,86 @@ +{ buildPythonPackage +, fetchFromGitHub +, inputs +, fetchPypi +, stdenv +, scipy +, numpy +, ml-dtypes +, pip +, opt-einsum +, absl-py +, breakpointHook +, matplotlib +} +: +let + wheelUrls = { + "x86_64-linux" = { + url = "https://files.pythonhosted.org/packages/8e/d7/65b1f5cf05d9159abd5882a51695d4d1b386bc8e26140eff7159854777f2/jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl"; + hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU="; + }; + + "aarch64-linux" = { + url = "https://files.pythonhosted.org/packages/f2/87/0c07ec3ba047ca58c940d1c3050cd08c4390bca992cdfeeb2d9d356cd2c6/jaxlib-0.4.28-cp311-cp311-manylinux2014_aarch64.whl"; + hash = ""; + }; + + "x86_64-darwin" = { + url = "https://files.pythonhosted.org/packages/e0/b2/896d8d1f35e16e9f88ae6a753012e6d5a6882507ea58e7f0dd5af68ee1e8/jaxlib-0.4.28-cp311-cp311-macosx_10_14_x86_64.whl"; + hash = ""; + }; + + "aarch64-darwin" = { + url = "https://files.pythonhosted.org/packages/75/f3/1ce8b092ca68dfcfa6a0ee0a8a410f6d877e1628c05799c5d03757682c66/jaxlib-0.4.28-cp311-cp311-macosx_11_0_arm64.whl"; + hash = ""; + }; + }; +in +buildPythonPackage rec { + pname = "jaxlib"; + version = "0.4.28"; + + src = fetchFromGitHub { + owner = "google"; + repo = "jax"; + rev = "jaxlib-v${version}"; + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; + }; + + #unpackPhase = '' + #mkdir source + #cp -rv $src/jaxlib source + #chmod -R +w source + #ls -alsph source + + #ln -rsv $src/jax/version.py source/jaxlib + #cd source/jaxlib + #''; + + preConfigure = '' + cd jaxlib + mkdir jaxlib + ln -rsv $src/jax/version.py jaxlib/version.py + ln -rsv $src/third_party/xla jaxlib/xla_extension + ls -alsph . + ''; + + nativeBuildInputs = [ + breakpointHook + ]; + + nativeCheckInputs = [ + pip + ]; + + propagatedBuildInputs = [ + scipy + numpy + ml-dtypes + opt-einsum + absl-py + matplotlib + ]; + + preferLocalBuild = true; +} diff --git a/pkgs/disabled-modules/jaxtyping_/default.nix b/pkgs/disabled-modules/jaxtyping_/default.nix new file mode 100644 index 0000000..298ae13 --- /dev/null +++ b/pkgs/disabled-modules/jaxtyping_/default.nix @@ -0,0 +1,81 @@ +{ + lib, + buildPythonPackage, + pythonOlder, + fetchFromGitHub, + hatchling, + pythonRelaxDepsHook, + numpy, + typeguard, + typing-extensions, + cloudpickle, + equinox, + ipython, + jax, + jaxlib, + pytestCheckHook, + tensorflow, + torch, +}: + +let + self = buildPythonPackage rec { + pname = "jaxtyping"; + version = "0.2.28"; + pyproject = true; + + disabled = pythonOlder "3.9"; + + src = fetchFromGitHub { + owner = "google"; + repo = "jaxtyping"; + rev = "refs/tags/v${version}"; + hash = "sha256-xDFrgPecUIfCACg/xkMQ8G1+6hNiUUDg9eCZKNpNfzs="; + }; + + nativeBuildInputs = [ + hatchling + pythonRelaxDepsHook + ]; + + propagatedBuildInputs = [ + numpy + typeguard + typing-extensions + ]; + + pythonRelaxDeps = [ "typeguard" ]; + + nativeCheckInputs = [ + cloudpickle + equinox + ipython + jax + jaxlib + pytestCheckHook + tensorflow + torch + ]; + + doCheck = false; + + # Enable tests via passthru to avoid cyclic dependency with equinox. + passthru.tests = { + check = self.overridePythonAttrs { + # We disable tests because they complain about the version of typeguard being too new. + doCheck = false; + catchConflicts = false; + }; + }; + + pythonImportsCheck = [ "jaxtyping" ]; + + meta = with lib; { + description = "Type annotations and runtime checking for JAX arrays and PyTrees"; + homepage = "https://github.com/google/jaxtyping"; + license = licenses.mit; + maintainers = with maintainers; [ GaetanLepage ]; + }; + }; +in +self diff --git a/pkgs/disabled-modules/optax_0_1_7/default.nix b/pkgs/disabled-modules/optax_0_1_7/default.nix new file mode 100644 index 0000000..e84a8ac --- /dev/null +++ b/pkgs/disabled-modules/optax_0_1_7/default.nix @@ -0,0 +1,12 @@ +{ fetchFromGitHub +, optax +}: +(optax.overrideAttrs rec { + version = "0.1.7"; + src = fetchFromGitHub { + owner = "deepmind"; + repo = "optax"; + rev = "refs/tags/v${version}"; + hash = "sha256-zSMJxagPe2rkhrawJ+TWXUzk6V58IY6MhWmEqLVtOoA="; + }; +}) diff --git a/pkgs/python-modules/bayes3d/default.nix b/pkgs/python-modules/bayes3d/default.nix new file mode 100644 index 0000000..71ade64 --- /dev/null +++ b/pkgs/python-modules/bayes3d/default.nix @@ -0,0 +1,115 @@ +{ lib +, fetchFromGitHub +, breakpointHook +, buildPythonPackage +, cudaPackages_11 +, which +, libglvnd +, libGLU +, open3d +, symlinkJoin +, genjax +, distinctipy +, pyransac3d +, opencv-python +, setuptools +, setuptools-scm +, torch +, pytorchWithCuda +, graphviz +, imageio +, matplotlib +, meshcat +, natsort +, opencv4 +, plyfile +, liblzfse +, tensorflow-probability +, timm +, trimesh +}: +let + rev = "8113f643a7ba084e0ca2288cf06f95a23e39d1c7"; + + cuda-common-redist = with cudaPackages_11; [ + cuda_cccl # + libcublas # cublas_v2.h + libcurand + libcusolver # cusolverDn.h + libcusparse # cusparse.h + ]; + + cuda-native-redist = symlinkJoin { + name = "cuda-native-redist-${cudaPackages_11.cudaVersion}"; + paths = + with cudaPackages_11; + [ + cuda_cudart # cuda_runtime.h cuda_runtime_api.h + cuda_nvcc + ] + ++ cuda-common-redist; + }; +in +buildPythonPackage rec { + pname = "bayes3d"; + version = "0.1.0+${builtins.substring 0 8 rev}"; + + src = fetchFromGitHub { + repo = pname; + owner = "srounce"; + inherit rev; + hash = "sha256-6AtxR8ZsByliDTQE/hEJs5+LKwdfS/sRGYXf+mgFHxw="; + }; + + pyproject = true; + + nativeBuildInputs = [ + setuptools + setuptools-scm + which + #breakpointHook + ]; + + buildInputs = [ + # cudaPackages.cuda_nvcc + # cudaPackages.cuda_cudart + # cudaPackages.libcusparse + # cudaPackages.cuda_cccl + # cudaPackages.libcublas + # cudaPackages.libcusolver + cudaPackages_11.cudatoolkit.lib + pytorchWithCuda + libglvnd + libGLU + ]; + + propagatedBuildInputs = [ + distinctipy + genjax + graphviz + imageio + liblzfse + matplotlib + meshcat + natsort + open3d + opencv-python + opencv4 + plyfile + pyransac3d + tensorflow-probability + timm + #torch + trimesh + ]; + + preBuild = '' + export CUDA_HOME=${cuda-native-redist} + ''; + + #preferLocalBuild = true; + + pythonImportsCheck = [ + "bayes3d" + ]; +} diff --git a/pkgs/python-modules/distinctipy/default.nix b/pkgs/python-modules/distinctipy/default.nix new file mode 100644 index 0000000..1716aa3 --- /dev/null +++ b/pkgs/python-modules/distinctipy/default.nix @@ -0,0 +1,26 @@ +{ fetchPypi +, buildPythonPackage +, setuptools +, numpy +}: +buildPythonPackage rec { + pname = "distinctipy"; + version = "1.3.4"; + format = "pyproject"; + + src = fetchPypi { + inherit pname version; + hash = "sha256-/tl6//Gvtz7KqHyFRhAh8LqJ+uYwZ8ASW5ZzUmUQqsQ="; + }; + + doCheck = false; + + nativeBuildInputs = [ + setuptools + ]; + + propagatedBuildInputs = [ + numpy + ]; +} + diff --git a/pkgs/distributions/default.nix b/pkgs/python-modules/distributions/default.nix similarity index 94% rename from pkgs/distributions/default.nix rename to pkgs/python-modules/distributions/default.nix index 7d4da3f..23db5b1 100644 --- a/pkgs/distributions/default.nix +++ b/pkgs/python-modules/distributions/default.nix @@ -20,7 +20,8 @@ let distributions-shared = callPackage ./distributions-shared.nix { inherit version src; }; - imageio = python3Packages.buildPythonPackage rec { + # TODO: move into own package + imageio_2_6_1 = python3Packages.buildPythonPackage rec { pname = "imageio"; version = "2.6.1"; @@ -75,7 +76,7 @@ python3Packages.buildPythonPackage { # TODO: be more precise. Some tests seem to be still in Python 2. doCheck = false; nativeCheckInputs = with python3Packages; [ - imageio + imageio_2_6_1 nose goftests pytest diff --git a/pkgs/distributions/distributions-shared.nix b/pkgs/python-modules/distributions/distributions-shared.nix similarity index 100% rename from pkgs/distributions/distributions-shared.nix rename to pkgs/python-modules/distributions/distributions-shared.nix diff --git a/pkgs/distributions/use-imread-instead-of-scipy.patch b/pkgs/python-modules/distributions/use-imread-instead-of-scipy.patch similarity index 100% rename from pkgs/distributions/use-imread-instead-of-scipy.patch rename to pkgs/python-modules/distributions/use-imread-instead-of-scipy.patch diff --git a/pkgs/python-modules/genjax/default.nix b/pkgs/python-modules/genjax/default.nix new file mode 100644 index 0000000..37cc7af --- /dev/null +++ b/pkgs/python-modules/genjax/default.nix @@ -0,0 +1,73 @@ +{ buildPythonPackage +, fetchFromGitHub +, inputs +, poetry-core +, poetry-dynamic-versioning +, fetchPypi +, stdenv +, beartype +, deprecated +, dill +, jax +, jaxtyping +, equinox +, numpy +, optax +, oryx +, plum-dispatch +, pygments +, rich +, tensorflow-probability +, typing-extensions +}: +let + src = stdenv.mkDerivation { + name = "genjax-source"; + version = inputs.genjax.shortRev; + src = inputs.genjax; + + patches = [ + ./set-pyproject-version.patch + ./use-beartype-0.18.0-version.patch + ]; + + installPhase = '' + mkdir $out + ls -alsph $src + cp -rfv ./. $out + ''; + }; +in +buildPythonPackage { + __noChroot = true; + + pname = "genjax"; + version = "0.1.1"; + inherit src; + format = "pyproject"; + + nativeBuildInputs = [ + #poetry + poetry-core + poetry-dynamic-versioning + ]; + + propagatedBuildInputs = [ + beartype + deprecated + dill + jax + jaxtyping + equinox + numpy + optax + oryx + plum-dispatch + pygments + rich + tensorflow-probability + typing-extensions + ]; + + pythonImportsCheck = [ "genjax" ]; +} diff --git a/pkgs/python-modules/genjax/set-pyproject-version.patch b/pkgs/python-modules/genjax/set-pyproject-version.patch new file mode 100644 index 0000000..2b516d1 --- /dev/null +++ b/pkgs/python-modules/genjax/set-pyproject-version.patch @@ -0,0 +1,11 @@ +diff a/pyproject.toml b/pyproject.toml +--- a/pyproject.toml ++++ b/pyproject.toml +@@ -1,6 +1,6 @@ + [tool.poetry] + name = "genjax" +-version = "0.0.0" ++version = "0.1.1" + description = "Probabilistic programming with Gen, built on top of JAX." + authors = [ + "McCoy R. Becker ", diff --git a/pkgs/python-modules/genjax/use-beartype-0.18.0-version.patch b/pkgs/python-modules/genjax/use-beartype-0.18.0-version.patch new file mode 100644 index 0000000..ec4c01e --- /dev/null +++ b/pkgs/python-modules/genjax/use-beartype-0.18.0-version.patch @@ -0,0 +1,14 @@ +diff +--- a/pyproject.toml ++++ b/pyproject.toml +@@ -41,7 +41,7 @@ tensorflow-probability = "^0.23.0" + rich = "^13.7.0" + jaxtyping = "^0.2.24" +-optax = "^0.1.7" ++optax = "^0.2.2" +-beartype = "^0.16.4" ++beartype = "^0.18.0" + dill = "^0.3.7" + pygments = "^2.17.2" + plum-dispatch = "^2.2.2" + diff --git a/pkgs/goftests/default.nix b/pkgs/python-modules/goftests/default.nix similarity index 76% rename from pkgs/goftests/default.nix rename to pkgs/python-modules/goftests/default.nix index a8d70cd..b640623 100644 --- a/pkgs/goftests/default.nix +++ b/pkgs/python-modules/goftests/default.nix @@ -1,8 +1,11 @@ { fetchPypi, - python3Packages, + buildPythonPackage, + numpy, + scipy, }: -python3Packages.buildPythonPackage rec { +# TODO: upstream +buildPythonPackage rec { pname = "goftests"; version = "0.2.7"; format = "setuptools"; @@ -12,7 +15,7 @@ python3Packages.buildPythonPackage rec { hash = "sha256-5s0NugSus2TuZIInesCNJNAtxEHnZLQIjn0pxGgwL/o="; }; - buildInputs = with python3Packages; [ numpy scipy ]; + buildInputs = [ numpy scipy ]; doCheck = false; diff --git a/pkgs/loom/default.nix b/pkgs/python-modules/loom/default.nix similarity index 89% rename from pkgs/loom/default.nix rename to pkgs/python-modules/loom/default.nix index 7a8292e..09cead6 100644 --- a/pkgs/loom/default.nix +++ b/pkgs/python-modules/loom/default.nix @@ -27,14 +27,14 @@ , zlib , eigen , gperftools -, dockerTools -, basicTools +# , dockerTools +# , basicTools +, distributions +, goftests +, parsable +, pymetis }: let - goftests = callPackage ./../goftests { }; - parsable = callPackage ./../parsable { }; - pymetis = callPackage ./../pymetis { }; - distributions = callPackage ./../distributions {inherit goftests parsable;}; protobuf = protobuf3_20; @@ -191,13 +191,14 @@ let ; }; - passthru.ociImg = dockerTools.buildLayeredImage { - name = "probcomp/loom"; - contents = - with pkgs; [ loom bashInteractive ] ++ - basicTools - ; - }; + # TODO: move it to a different package + # passthru.ociImg = dockerTools.buildLayeredImage { + # name = "probcomp/loom"; + # contents = + # with pkgs; [ loom bashInteractive ] ++ + # basicTools + # ; + # }; passthru.tests.run = callPackage ./test.nix { inherit src; }; diff --git a/pkgs/loom/test.nix b/pkgs/python-modules/loom/test.nix similarity index 100% rename from pkgs/loom/test.nix rename to pkgs/python-modules/loom/test.nix diff --git a/pkgs/python-modules/open3d/default.nix b/pkgs/python-modules/open3d/default.nix new file mode 100644 index 0000000..3dac4c2 --- /dev/null +++ b/pkgs/python-modules/open3d/default.nix @@ -0,0 +1,219 @@ +{ stdenv +, lib +, pkgs +, fetchPypi +, fetchurl + +, tree +, unzip +, zip + +, autoPatchelfHook +, python +, tensorflow-bin +, libusb +, cudaPackages_11 +, buildPythonPackage +, ipywidgets +, matplotlib +, numpy +, pandas +, plyfile +, pytorchWithCuda +, pyyaml +, scikitlearn +, scipy +, tqdm + +, libGL +, libglvnd +, libdrm +, expat +, xorg +, llvmPackages_10 +, buildEnv +, runCommand +}: +let + libllvm-wrapped = + let + libllvm = llvmPackages_10.libllvm.lib; + name = libllvm.name; + in + buildEnv { + inherit name; + paths = [ + llvmPackages_10.libllvm.lib + (runCommand "${name}.1" {} "mkdir -p $out/lib && ln -sf ${libllvm}/lib/libLLVM-10.so $out/lib/libLLVM-10.so.1") + ]; + }; + + version = "0.18.0"; + pname = "open3d"; + pythonAbi = "cp311"; + + prebuiltSrcs = { + "3.8-x86_64-linux" = { + platform = "manylinux_2_27_x86_64"; + dist = "cp38"; + hash = ""; + }; + "3.8-aarch64-linux" = { + platform = "manylinux_2_27_aarch64"; + dist = "cp38"; + hash = ""; + }; + "3.8-x86_64-darwin" = { + platform = "macosx_11_0_x86_64"; + dist = "cp38"; + hash = ""; + }; + "3.8-aarch64-darwin" = { + platform = "macosx_13_0_aarch64"; + dist = "cp38"; + hash = ""; + }; + + "3.9-x86_64-linux" = { + platform = "manylinux_2_27_x86_64"; + dist = "cp39"; + hash = ""; + }; + "3.9-aarch64-linux" = { + platform = "manylinux_2_27_aarch64"; + dist = "cp39"; + hash = ""; + }; + "3.9-x86_64-darwin" = { + platform = "macosx_11_0_x86_64"; + dist = "cp39"; + hash = ""; + }; + "3.9-aarch64-darwin" = { + platform = "macosx_13_0_aarch64"; + dist = "cp39"; + hash = ""; + }; + + "3.10-x86_64-linux" = { + platform = "manylinux_2_27_x86_64"; + dist = "cp310"; + hash = ""; + }; + "3.10-aarch64-linux" = { + platform = "manylinux_2_27_aarch64"; + dist = "cp310"; + hash = ""; + }; + "3.10-x86_64-darwin" = { + platform = "macosx_11_0_x86_64"; + dist = "cp310"; + hash = ""; + }; + "3.10-aarch64-darwin" = { + platform = "macosx_13_0_aarch64"; + dist = "cp310"; + hash = ""; + }; + + "3.11-x86_64-linux" = { + platform = "manylinux_2_27_x86_64"; + dist = "cp311"; + hash = "sha256-jj0dGQCo9NlW9oGcJGx4CBclubCIj4VJ0qeknI2qEwM="; + }; + "3.11-aarch64-linux" = { + platform = "manylinux_2_27_aarch64"; + dist = "cp311"; + hash = ""; + }; + "3.11-x86_64-darwin" = { + platform = "macosx_11_0_x86_64"; + dist = "cp311"; + hash = ""; + }; + "3.11-aarch64-darwin" = { + platform = "macosx_13_0_aarch64"; + dist = "cp311"; + hash = ""; + }; + }; + + pyVersion = lib.versions.majorMinor python.version; + srcInputs = prebuiltSrcs."${pyVersion}-${stdenv.system}" or (throw "open3d-bin for Python version '${pyVersion}' is not supported on '${stdenv.system}'"); + + src = fetchPypi rec { + inherit pname version; + inherit (srcInputs) platform dist hash; + + python = dist; + abi = dist; + + format = "wheel"; + }; +in +buildPythonPackage { + inherit pname version; + format = "wheel"; + + # TODO: make this multiplatform + inherit src; + + patchPhase = '' + ${unzip}/bin/unzip ./dist/open3d-${version}-${srcInputs.dist}-${srcInputs.dist}-${srcInputs.platform}.whl -d tmp + rm ./dist/open3d-${version}-${srcInputs.dist}-${srcInputs.dist}-${srcInputs.platform}.whl + #sed -i 's/sklearn/scikit-learn/g' tmp/open3d-${version}.dist-info/METADATA + cd tmp + ${zip}/bin/zip -0 -r ../dist/open3d-${version}-${srcInputs.dist}-${srcInputs.dist}-${srcInputs.platform}.whl ./* + cd ../ + ''; + + nativeBuildInputs = [ + autoPatchelfHook + ]; + + buildInputs = [ + # so deps + stdenv.cc.cc.lib + libusb.out + pytorchWithCuda + tensorflow-bin + cudaPackages_11.cudatoolkit.lib + #cudaPackages_11.cuda_cudart.lib + libGL + libglvnd + libdrm + expat + xorg.libXxf86vm + xorg.libXfixes + libllvm-wrapped + pkgs.mesa + pkgs.zstd + ]; + + propagatedBuildInputs = [ + # py deps + ipywidgets + tqdm + pyyaml + pandas + plyfile + scipy + scikitlearn + numpy + #addict + matplotlib + ]; + + #preBuild = '' + #mkdir $out + #''; + + preFixup = '' + echo "OUTPUT TO: $out" + cd $out/lib/python3.*/site-packages/open3d + rm libGL.so.1 libEGL.so.1 + ln -s ${libGL}/lib/libGL.so.1 libGL.so.1 + ln -s ${libGL}/lib/libEGL.so.1 libEGL.so.1 + #exit 1 + ''; +} diff --git a/pkgs/python-modules/opencv-python/default.nix b/pkgs/python-modules/opencv-python/default.nix new file mode 100644 index 0000000..0f8af01 --- /dev/null +++ b/pkgs/python-modules/opencv-python/default.nix @@ -0,0 +1,109 @@ +{ lib +, stdenv +, buildPythonPackage +, fetchurl +, fetchPypi +, autoPatchelfHook + +, unzip +, zip + +, git +, cmake +, numpy +, pip +, scikit-build +, setuptools +, wheel +, gcc +, libGL +, xorg +, libz +, qt5 + +, breakpointHook +}: +let + version = "4.10.0.84"; + pname = "opencv-python"; + pythonAbi = "cp311"; + pythonPlatform = "manylinux_2_27_x86_64"; + + wheelUrls = { + "x86_64-linux" = { + url = "https://files.pythonhosted.org/packages/3f/a4/d2537f47fd7fcfba966bd806e3ec18e7ee1681056d4b0a9c8d983983e4d5/opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"; + hash = "sha256-ms4UD8bWR/vhxpK8sqvOdolzSRIiwGfBMdgJV8WVtx8="; + }; + + "aarch64-linux" = { + url = "https://files.pythonhosted.org/packages/81/e4/7a987ebecfe5ceaf32db413b67ff18eb3092c598408862fff4d7cc3fd19b/opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl"; + hash = ""; + }; + + "x86_64-darwin" = { + url = "https://files.pythonhosted.org/packages/64/4a/016cda9ad7cf18c58ba074628a4eaae8aa55f3fd06a266398cef8831a5b9/opencv_python-4.10.0.84-cp37-abi3-macosx_12_0_x86_64.whl"; + hash = ""; + }; + + "aarch64-darwin" = { + url = "https://files.pythonhosted.org/packages/66/82/564168a349148298aca281e342551404ef5521f33fba17b388ead0a84dc5/opencv_python-4.10.0.84-cp37-abi3-macosx_11_0_arm64.whl"; + hash = ""; + }; + }; + + src = fetchurl ( + if builtins.hasAttr stdenv.system wheelUrls + then wheelUrls.${stdenv.system} + else throw "Unsupported system" + ); +in +buildPythonPackage rec { + inherit pname version; + format = "wheel"; + + inherit src; + + #patchPhase = '' + #pwd + + #${unzip}/bin/unzip $src/${pname}-${version}-${pythonAbi}-${pythonAbi}-${pythonPlatform}.whl -d tmp + #cd tmp + #${zip}/bin/zip -0 -r ../dist/${pname}-${version}-${pythonAbi}-${pythonAbi}-${pythonPlatform}.whl ./* + #cd ../ + #''; + + nativeBuildInputs = [ + breakpointHook + ] + ++ lib.optionals stdenv.isLinux [ + autoPatchelfHook + ]; + + dontWrapQtApps = true; + + buildInputs = [ + stdenv.cc.cc.lib + libGL + libz + qt5.qtbase + ] ++ lib.optionals stdenv.isLinux [ + xorg.libxcb + xorg.libXext + xorg.libX11 + xorg.libSM + xorg.libICE + ]; + + propagatedBuildInputs = [ + numpy + ]; + + pythonImportsCheck = [ "cv2" ]; + + meta = with lib; { + description = "Wrapper package for OpenCV python bindings"; + homepage = "https://pypi.org/project/opencv-python"; + license = with licenses; [ asl20 mit ]; + maintainers = with maintainers; [ ]; + }; +} diff --git a/pkgs/python-modules/opencv-python/relax-dependency-ranges.patch b/pkgs/python-modules/opencv-python/relax-dependency-ranges.patch new file mode 100644 index 0000000..8349092 --- /dev/null +++ b/pkgs/python-modules/opencv-python/relax-dependency-ranges.patch @@ -0,0 +1,26 @@ +diff a/pyproject.toml b/pyproject.toml +--- a/pyproject.toml ++++ b/pyproject.toml +@@ -6,8 +6,8 @@ requires = [ + "numpy==1.17.5; python_version=='3.8' and platform_machine != 'aarch64' and platform_machine != 'arm64'", + "numpy==1.19.3; python_version<'3.9' and sys_platform == 'linux' and platform_machine == 'aarch64'", + "numpy==1.21.0; python_version<'3.9' and sys_platform == 'darwin' and platform_machine == 'arm64'", ++ "numpy>=1.21.0; python_version>'3.9'", +- "numpy>=2.0.0; python_version>='3.9'", + "pip", + "scikit-build>=0.14.0", +- "setuptools==59.2.0", ++ "setuptools>=59.2.0", + ] + diff a/setup.py b/setup.py +--- a/setup.py ++++ b/setup.py +@@ -32,7 +32,7 @@ def main(): + 'numpy>=1.21.0; python_version<="3.9" and platform_system=="Darwin" and platform_machine=="arm64"', + 'numpy>=1.21.4; python_version>="3.10" and platform_system=="Darwin"', + "numpy>=1.23.5; python_version>='3.11'", +- "numpy>=1.26.0; python_version>='3.12'" ++ "numpy>=1.26.0; python_version>='3.11'" + ] + + python_version = cmaker.CMaker.get_python_version() diff --git a/pkgs/python-modules/oryx/default.nix b/pkgs/python-modules/oryx/default.nix new file mode 100644 index 0000000..82589ed --- /dev/null +++ b/pkgs/python-modules/oryx/default.nix @@ -0,0 +1,38 @@ +{ lib +, buildPythonPackage +, fetchPypi +, poetry-core +, jax +, jaxlib-bin +, tensorflow-probability +}: + +buildPythonPackage rec { + pname = "oryx"; + version = "0.2.6"; + pyproject = true; + + src = fetchPypi { + inherit pname version; + hash = "sha256-8spdSJN9e9jDdc6KxSmh7Z+NoxJjNNLgz91rfxuepI8="; + }; + + nativeBuildInputs = [ + poetry-core + ]; + + propagatedBuildInputs = [ + jax + jaxlib-bin + tensorflow-probability + ]; + + pythonImportsCheck = [ "oryx" ]; + + meta = with lib; { + description = "Probabilistic programming and deep learning in JAX"; + homepage = "https://pypi.org/project/oryx"; + license = licenses.asl20; + maintainers = with maintainers; [ ]; + }; +} diff --git a/pkgs/parsable/default.nix b/pkgs/python-modules/parsable/default.nix similarity index 80% rename from pkgs/parsable/default.nix rename to pkgs/python-modules/parsable/default.nix index 2a5ea1b..f9aa16e 100644 --- a/pkgs/parsable/default.nix +++ b/pkgs/python-modules/parsable/default.nix @@ -1,8 +1,8 @@ { fetchPypi, - python3Packages + buildPythonPackage }: -python3Packages.buildPythonPackage rec { +buildPythonPackage rec { pname = "parsable"; version = "0.3.1"; format = "setuptools"; diff --git a/pkgs/python-modules/plum-dispatch/default.nix b/pkgs/python-modules/plum-dispatch/default.nix new file mode 100644 index 0000000..01acda7 --- /dev/null +++ b/pkgs/python-modules/plum-dispatch/default.nix @@ -0,0 +1,76 @@ +{ lib +, buildPythonPackage +, fetchPypi +, hatch-vcs +, hatchling +, beartype +, rich +, typing-extensions +, black +, build +, coveralls +, ghp-import +, ipython +, jupyter-book +, mypy +, numpy +, pre-commit +, pyright +, pytest +, pytest-cov +, ruff +, tox +, wheel +}: + +buildPythonPackage rec { + pname = "plum-dispatch"; + version = "2.3.5"; + pyproject = true; + + src = fetchPypi { + pname = "plum_dispatch"; + inherit version; + hash = "sha256-eticwgKdh7Djusx8x3Pxlq4ynOEV8wi2Ly0GxosYo40="; + }; + + nativeBuildInputs = [ + hatch-vcs + hatchling + ]; + + propagatedBuildInputs = [ + beartype + rich + typing-extensions + ]; + + passthru.optional-dependencies = { + dev = [ + black + build + coveralls + ghp-import + ipython + jupyter-book + mypy + numpy + pre-commit + pyright + pytest + pytest-cov + ruff + tox + wheel + ]; + }; + + pythonImportsCheck = [ "plum" ]; + + meta = with lib; { + description = "Multiple dispatch in Python"; + homepage = "https://pypi.org/project/plum-dispatch"; + license = licenses.mit; + maintainers = with maintainers; [ ]; + }; +} diff --git a/pkgs/pymetis/default.nix b/pkgs/python-modules/pymetis/default.nix similarity index 74% rename from pkgs/pymetis/default.nix rename to pkgs/python-modules/pymetis/default.nix index f6ed7f5..2f663ef 100644 --- a/pkgs/pymetis/default.nix +++ b/pkgs/python-modules/pymetis/default.nix @@ -1,7 +1,8 @@ { fetchPypi -, python3Packages +, buildPythonPackage +, pybind11 }: -python3Packages.buildPythonPackage rec { +buildPythonPackage rec { pname = "PyMetis"; version = "2023.1.1"; format = "setuptools"; @@ -14,6 +15,6 @@ python3Packages.buildPythonPackage rec { doCheck = false; nativeBuildInputs = [ - python3Packages.pybind11 + pybind11 ]; } diff --git a/pkgs/python-modules/pyransac3d/default.nix b/pkgs/python-modules/pyransac3d/default.nix new file mode 100644 index 0000000..89b6b39 --- /dev/null +++ b/pkgs/python-modules/pyransac3d/default.nix @@ -0,0 +1,31 @@ +{ fetchFromGitHub +, python3 +, buildPythonPackage +, setuptools +, wheel +, numpy +}: +# TODO: upstream me +buildPythonPackage rec { + pname = "pyransac3d"; + version = "0.6.0"; + pyproject = true; + + src = fetchFromGitHub { + owner = "leomariga"; + repo = "pyRANSAC-3D"; + rev = "v${version}"; + hash = "sha256-QplIgH+zjkZgPWMvvpV2yM/HEEBRea4D+dG7G7h2jdQ="; + }; + + nativeBuildInputs = [ + setuptools + wheel + ]; + + propagatedBuildInputs = [ + numpy + ]; + + pythonImportsCheck = [ "pyransac3d" ]; +} diff --git a/pkgs/sppl/default.nix b/pkgs/python-modules/sppl/default.nix similarity index 94% rename from pkgs/sppl/default.nix rename to pkgs/python-modules/sppl/default.nix index eb9c4ec..edb97bd 100644 --- a/pkgs/sppl/default.nix +++ b/pkgs/python-modules/sppl/default.nix @@ -2,6 +2,7 @@ system, ... }: let + # FIXME: check if we still need python 3.9. Can it be switched to 3.11? # relies on specific versions of deps that are no longer present in # nixpkgs stable; we must checkout a specific SHA diff --git a/pkgs/python-modules/tensorflow-probability/default.nix b/pkgs/python-modules/tensorflow-probability/default.nix new file mode 100644 index 0000000..01ba07f --- /dev/null +++ b/pkgs/python-modules/tensorflow-probability/default.nix @@ -0,0 +1,125 @@ +{ + lib, + stdenv, + fetchFromGitHub, + bazel_6, + buildBazelPackage, + buildPythonPackage, + python, + setuptools, + wheel, + absl-py, + tensorflow, + six, + numpy, + dm-tree, + keras, + decorator, + cloudpickle, + gast, + hypothesis, + scipy, + pandas, + mpmath, + matplotlib, + mock, + pytest, + darwin, +}: + +let + version = "0.23.0"; + pname = "tensorflow-probability"; + + inherit (darwin) cctools; + + # first build all binaries and generate setup.py using bazel + bazel-wheel = buildBazelPackage { + name = "tensorflow_probability-${version}-py2.py3-none-any.whl"; + src = fetchFromGitHub { + owner = "tensorflow"; + repo = "probability"; + rev = "refs/tags/v${version}"; + hash = "sha256-cZTlWfg3pIFnJz5xrQhcS1uvRMfIOOxcUxN747txd28="; + }; + nativeBuildInputs = [ + # needed to create the output wheel in installPhase + python + setuptools + wheel + absl-py + tensorflow + ]; + + bazel = bazel_6; + + bazelTargets = [ ":pip_pkg" ]; + LIBTOOL = lib.optionalString stdenv.isDarwin "${cctools}/bin/libtool"; + + fetchAttrs = { + sha256 = "sha256-TbWcWYidyXuAMgBnO2/k0NKCzc4wThf2uUeC3QxdBJY="; + }; + + buildAttrs = { + preBuild = '' + patchShebangs . + ''; + + installPhase = '' + # work around timestamp issues + # https://github.com/NixOS/nixpkgs/issues/270#issuecomment-467583872 + export SOURCE_DATE_EPOCH=315532800 + + # First build, then move. Otherwise pip_pkg would create the dir $out + # and then put the wheel in that directory. However we want $out to + # point directly to the wheel file. + ./bazel-bin/pip_pkg . --release + mv *.whl "$out" + ''; + }; + }; +in +buildPythonPackage { + inherit version pname; + format = "wheel"; + + src = bazel-wheel; + + propagatedBuildInputs = [ + tensorflow + six + numpy + decorator + cloudpickle + gast + dm-tree + keras + ]; + + # Listed here: + # https://github.com/tensorflow/probability/blob/f3777158691787d3658b5e80883fe1a933d48989/testing/dependency_install_lib.sh#L83 + nativeCheckInputs = [ + hypothesis + pytest + scipy + pandas + mpmath + matplotlib + mock + ]; + + # Ideally, we run unit tests with pytest, but in checkPhase, only the Bazel-build wheel is available. + # But it seems not guaranteed that running the tests with pytest will even work, see + # https://github.com/tensorflow/probability/blob/c2a10877feb2c4c06a4dc58281e69c37a11315b9/CONTRIBUTING.md?plain=1#L69 + # Ideally, tests would be run using Bazel. For now, lets's do a... + + # sanity check + pythonImportsCheck = [ "tensorflow_probability" ]; + + meta = with lib; { + description = "Library for probabilistic reasoning and statistical analysis"; + homepage = "https://www.tensorflow.org/probability/"; + license = licenses.asl20; + maintainers = with maintainers; [ GaetanLepage ]; + }; +}