diff --git a/ci/Dockerfile b/ci/Dockerfile index 911d62e..447dd8b 100644 --- a/ci/Dockerfile +++ b/ci/Dockerfile @@ -1,8 +1,8 @@ -FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 +FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ - python3 \ - python3-pip \ + python3-full \ + python3-dev \ git \ libfftw3-dev diff --git a/ci/Jenkinsfile b/ci/Jenkinsfile index a097b46..1c0d4a6 100644 --- a/ci/Jenkinsfile +++ b/ci/Jenkinsfile @@ -18,27 +18,37 @@ pipeline { CMAKE_ARGS = "-DJAX_FINUFFT_USE_CUDA=ON" } steps { - sh 'python3 -m pip install -U pip' - sh 'python3 -m pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' - sh 'python3 -m pip install -v .[test]' + sh ''' + python3 -m venv venv + . venv/bin/activate + pip install -U pip + pip install -U "jax[cuda12]" + pip install -v ".[test]" + ''' } } stage('CPU Tests') { environment { JAX_PLATFORMS = "cpu" - OMP_NUM_THREADS = "4" + OMP_NUM_THREADS = "${env.PARALLEL}" } steps { - sh 'python3 -m pytest -v tests/' + sh ''' + . venv/bin/activate + pytest -v tests/ + ''' } } stage('GPU Tests') { environment { JAX_PLATFORMS = "cuda" - OMP_NUM_THREADS = "4" + OMP_NUM_THREADS = "${env.PARALLEL}" } steps { - sh 'python3 -m pytest -v tests/' + sh ''' + . venv/bin/activate + pytest -v tests/ + ''' } } }