Multi-GPU and CUDA Stream Support #60
base: main
Changes from all commits
`ci/docker/Dockerfile.base` (new file):

```dockerfile
FROM nvcr.io/nvidia/pytorch:23.10-py3

RUN apt-get update

# install boost test framework
RUN apt-get install -y libboost-test-dev
```
GitLab CI configuration (new file):

```yaml
include:
  - remote: 'https://gitlab.com/cscs-ci/recipes/-/raw/master/templates/v2/.ci-ext.yml'

stages:
  - build
  - test

build_base_image_job:
  stage: build
  extends: .container-builder-dynamic-name
  timeout: 2h
  variables:
    DOCKERFILE: ci/docker/Dockerfile.base
    WATCH_FILECHANGES: $DOCKERFILE
    PERSIST_IMAGE_NAME: $CSCS_REGISTRY_PATH/base/public/mops

test_job:
  stage: test
  extends: .container-runner-daint-gpu
  image: $BASE_IMAGE
  timeout: 2h
  script:
    - export CUDA_HOME="/usr/local/cuda"
    - python3 -m pip install --upgrade pip
    - echo "Install Tox"
    - python3 -m pip install tox
    - echo "Run the Tox Script"
    - tox
    - echo "Tox script completed"

  variables:
    SLURM_JOB_NUM_NODES: 1
    SLURM_PARTITION: normal
    SLURM_NTASKS: 1
    SLURM_TIMELIMIT: '00:40:00'
    GIT_STRATEGY: fetch
```
Changes to the torch bindings for `SparseAccumulationOfProducts` (SAP):

```diff
@@ -1,3 +1,8 @@
+#ifdef MOPS_CUDA_ENABLED
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
+#endif
+
 #include "mops/torch/sap.hpp"
 #include "mops/torch/utils.hpp"
```

Review comment on lines +1 to +5: I see that […]

Reply: I've fixed this for SAP. OPSAW and SASAW aren't implemented yet (SASAW is in a different branch), so when I get back to that I'll make it consistent.
```diff
@@ -59,6 +64,14 @@ torch::Tensor SparseAccumulationOfProducts::forward(
             );
         });
     } else if (A.device().is_cuda()) {
+#ifndef MOPS_CUDA_ENABLED
+        C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str());
+#else
+        c10::cuda::CUDAGuard deviceGuard{A.device()};
+        cudaStream_t currstream = c10::cuda::getCurrentCUDAStream();
+        void* stream = reinterpret_cast<void*>(currstream);
+
         output = torch::empty(
             {A.size(0), output_size},
             torch::TensorOptions().dtype(A.scalar_type()).device(A.device())
@@ -72,9 +85,11 @@ torch::Tensor SparseAccumulationOfProducts::forward(
                 details::torch_to_mops_1d<scalar_t>(C),
                 details::torch_to_mops_1d<int32_t>(indices_A),
                 details::torch_to_mops_1d<int32_t>(indices_B),
-                details::torch_to_mops_1d<int32_t>(indices_output)
+                details::torch_to_mops_1d<int32_t>(indices_output),
+                stream
             );
         });
+#endif
     } else {
         C10_THROW_ERROR(
             ValueError,
```
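Aside: all of the new CUDA branches in this PR follow the same pattern: pin the active device to the input tensor's device with an RAII guard, fetch PyTorch's current stream for that device, and hand it to the core library as an opaque `void*` so that `libmops` needs no torch/c10 headers. A minimal sketch of that pattern (not the actual mops code; `launch_mops_kernel` is a hypothetical stand-in for the real entry point):

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

// Hypothetical stand-in for the mops CUDA entry point; it takes the stream
// as an opaque pointer so the core library stays torch-free.
void launch_mops_kernel(void* /*stream*/) { /* kernel launch elided */ }

void run_on_tensors_device(const torch::Tensor& A) {
    // RAII: makes A's device the current CUDA device for this scope
    // and restores the previous device when the guard is destroyed.
    c10::cuda::CUDAGuard device_guard{A.device()};

    // The stream PyTorch is currently enqueueing work on for this device;
    // launching on it keeps the kernel ordered with surrounding torch ops.
    cudaStream_t current = c10::cuda::getCurrentCUDAStream();

    launch_mops_kernel(reinterpret_cast<void*>(current));
}
```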
```diff
@@ -170,6 +185,14 @@ std::vector<torch::Tensor> SparseAccumulationOfProductsBackward::forward(
             );
         });
     } else if (A.device().is_cuda()) {
+#ifndef MOPS_CUDA_ENABLED
+        C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str());
+#else
+        c10::cuda::CUDAGuard deviceGuard{A.device()};
+        cudaStream_t currstream = c10::cuda::getCurrentCUDAStream();
+        void* stream = reinterpret_cast<void*>(currstream);
+
         AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "sparse_accumulation_of_products_vjp", [&]() {
             auto mops_grad_A = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}};
             if (A.requires_grad()) {
@@ -192,9 +215,11 @@ std::vector<torch::Tensor> SparseAccumulationOfProductsBackward::forward(
                 details::torch_to_mops_1d<scalar_t>(C),
                 details::torch_to_mops_1d<int32_t>(indices_A),
                 details::torch_to_mops_1d<int32_t>(indices_B),
-                details::torch_to_mops_1d<int32_t>(indices_output)
+                details::torch_to_mops_1d<int32_t>(indices_output),
+                stream
             );
         });
+#endif
     } else {
         C10_THROW_ERROR(
             ValueError,
```
```diff
@@ -276,6 +301,10 @@ std::vector<torch::Tensor> SparseAccumulationOfProductsBackward::backward(
 #ifndef MOPS_CUDA_ENABLED
     C10_THROW_ERROR(ValueError, "MOPS was not compiled with CUDA support " + A.device().str());
 #else
+    c10::cuda::CUDAGuard deviceGuard{A.device()};
+    cudaStream_t currstream = c10::cuda::getCurrentCUDAStream();
+    void* stream = reinterpret_cast<void*>(currstream);
+
     AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "sparse_accumulation_of_products_vjp_vjp", [&]() {
         auto mops_grad_grad_output = mops::Tensor<scalar_t, 2>{nullptr, {0, 0}};
         if (grad_output.requires_grad()) {
@@ -317,7 +346,8 @@ std::vector<torch::Tensor> SparseAccumulationOfProductsBackward::backward(
             details::torch_to_mops_1d<scalar_t>(C),
             details::torch_to_mops_1d<int32_t>(indices_A),
             details::torch_to_mops_1d<int32_t>(indices_B),
-            details::torch_to_mops_1d<int32_t>(indices_output)
+            details::torch_to_mops_1d<int32_t>(indices_output),
+            stream
         );
     });
 #endif
```
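Because all three entry points now read `getCurrentCUDAStream()` instead of hard-coding the default stream, a caller can redirect the ops onto a side stream with a stream guard. A hedged libtorch-side sketch of such a caller (the Python equivalent would use `torch.cuda.stream(...)`):

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void enqueue_on_side_stream() {
    // Borrow a stream from PyTorch's per-device stream pool.
    c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
    {
        // While the guard is alive, getCurrentCUDAStream() returns `side`,
        // so the mops ops above will launch their kernels on it.
        c10::cuda::CUDAStreamGuard guard(side);
        // ... call the SAP forward/backward ops here ...
    } // previous current stream restored

    side.synchronize(); // wait for the side-stream work to finish
}
```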
Review comment: What does this `deviceGuard` do? I see that it's not being used explicitly.

Reply: It sets the current CUDA device to be the same one as `A.device()`.
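The guard works entirely through side effects in its constructor and destructor, which is why nothing references it after construction. A simplified sketch of the idea (not the actual c10 implementation, which also handles device indices and multiple backends):

```cpp
#include <cuda_runtime.h>

// Illustration of what an RAII device guard does.
class DeviceGuardSketch {
    int previous_ = -1;
public:
    explicit DeviceGuardSketch(int target) {
        cudaGetDevice(&previous_); // remember whichever device was current
        cudaSetDevice(target);     // subsequent launches go to `target`
    }
    ~DeviceGuardSketch() {
        cudaSetDevice(previous_);  // restore the old device on scope exit
    }
};
```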
Review comment: There's no easy way that I can see for us to test whether a kernel has launched on a specific stream from PyTorch. We can probably do this with the CUDA API, but that seems a bit overkill.
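For reference, a sketch of the kind of "overkill" CUDA-API check the comment alludes to: launch a deliberately slow kernel on one non-blocking stream, then query both streams while it is still running. This is timing-sensitive, so it is illustrative rather than a robust test:

```cpp
#include <cassert>
#include <cuda_runtime.h>

// Spins on the GPU clock so the stream stays busy long enough to query;
// the write to `out` keeps the loop from being optimized away.
__global__ void spin_kernel(long long cycles, long long* out) {
    long long start = clock64();
    while (clock64() - start < cycles) { }
    *out = clock64() - start;
}

int main() {
    cudaStream_t busy, idle;
    cudaStreamCreateWithFlags(&busy, cudaStreamNonBlocking);
    cudaStreamCreateWithFlags(&idle, cudaStreamNonBlocking);

    long long* out;
    cudaMalloc(&out, sizeof(long long));

    spin_kernel<<<1, 1, 0, busy>>>(200'000'000LL, out); // launch on `busy` only

    // While the kernel spins, the stream that received the launch reports
    // pending work and the other does not. (Assumes the kernel is still
    // running when we query -- hence "a bit overkill" and fragile.)
    assert(cudaStreamQuery(busy) == cudaErrorNotReady);
    assert(cudaStreamQuery(idle) == cudaSuccess);

    cudaStreamSynchronize(busy);
    cudaFree(out);
    cudaStreamDestroy(busy);
    cudaStreamDestroy(idle);
    return 0;
}
```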