From b1c1a801c42ea645fdf64567dc5367946f576c1b Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 2 Oct 2024 21:16:42 +0200 Subject: [PATCH] Add support for cuda streams. --- candle-core/src/cuda_backend/device.rs | 14 ++++++++++++++ candle-core/src/device.rs | 4 ++++ candle-core/src/dummy_cuda_backend.rs | 6 ++++++ 3 files changed, 24 insertions(+) diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 0aa58cacde..89fe44a6e6 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -144,6 +144,20 @@ impl CudaDevice { } } +impl CudaDevice { + pub fn new_with_stream(ordinal: usize) -> Result { + let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?; + let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?; + let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?; + Ok(Self { + id: DeviceId::new(), + device, + blas: Arc::new(blas), + curand: Arc::new(Mutex::new(CudaRng(curand))), + }) + } +} + impl BackendDevice for CudaDevice { type Storage = CudaStorage; diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 91e569372d..c4a8e9361e 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -130,6 +130,10 @@ impl Device { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) } + pub fn new_cuda_with_stream(ordinal: usize) -> Result { + Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?)) + } + pub fn new_metal(ordinal: usize) -> Result { Ok(Self::Metal(crate::MetalDevice::new(ordinal)?)) } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 68eef1efed..b4f2e8aa00 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -14,6 +14,12 @@ macro_rules! fail { }; } +impl CudaDevice { + pub fn new_with_stream(_: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } +} + impl crate::backend::BackendStorage for CudaStorage { type Device = CudaDevice;