-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflake.nix
41 lines (38 loc) · 1.04 KB
/
flake.nix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
{
  description = "PyTorch";

  inputs = {
    nixpkgs.url = "github:nixos/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
    shell-utils.url = "github:waltermoreira/shell-utils";
    # Provides the Intel MPI / hpcKit package used in the dev shell.
    hpcKit-utils.url = "github:Nixify-Technology/intel-mpi-nix";
  };

  # Exposes a single development shell (devShells.<system>.default) with
  # PyTorch, cuDNN, MPICH and the hpcKit package from `hpcKit-utils`.
  outputs =
    { self
    , nixpkgs
    , flake-utils
    , shell-utils
    , hpcKit-utils
    }:
    # Only x86_64-linux is generated; NOTE(review): presumably because the
    # hpcKit / CUDA packages are x86_64-linux only — confirm before adding
    # more systems.
    flake-utils.lib.eachSystem [ flake-utils.lib.system.x86_64-linux ]
      (system:
        let
          pkgs = import nixpkgs {
            inherit system;
            # cudaPackages.cudnn (and the CUDA libraries it pulls in) carry
            # unfree licenses in nixpkgs. Pure flake evaluation ignores any
            # user-level nixpkgs config, so without this the dev shell fails
            # to evaluate ("has an unfree license ... refusing to evaluate").
            # This also makes `torchWithCuda` (see below) usable.
            config.allowUnfree = true;
          };
          # Shell constructor provided by the shell-utils flake.
          shell = shell-utils.myShell.${system};
          hpcKit = hpcKit-utils.packages.${system}.default;
        in
        {
          devShells.default = shell {
            name = "PyTorch";
            packages = [
              hpcKit
              # NOTE(review): hpcKit (Intel MPI) and mpich may both provide
              # MPI compiler wrappers / launchers; confirm which one is meant
              # to take precedence on PATH.
              pkgs.mpich
              pkgs.cudaPackages.cudnn
              # replace `torch` with `torchWithCuda` for a CUDA version
              (pkgs.python311.withPackages (ps: with ps; [ torch torchvision ]))
            ];
          };
        });
}