Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly use the current device resource in DeviceBuffer #1514

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions python/rmm/_lib/device_buffer.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,31 @@ from libcpp.memory cimport unique_ptr

from rmm._cuda.stream cimport Stream
from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.memory_resource cimport DeviceMemoryResource
from rmm._lib.memory_resource cimport (
DeviceMemoryResource,
device_memory_resource,
)


cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
cdef cppclass device_buffer:
device_buffer()
device_buffer(size_t size, cuda_stream_view stream) except +
device_buffer(const void* source_data,
size_t size, cuda_stream_view stream) except +
device_buffer(const device_buffer buf,
cuda_stream_view stream) except +
device_buffer(
size_t size,
cuda_stream_view stream,
device_memory_resource *
) except +
device_buffer(
const void* source_data,
size_t size,
cuda_stream_view stream,
device_memory_resource *
) except +
device_buffer(
const device_buffer buf,
cuda_stream_view stream,
device_memory_resource *
) except +
void reserve(size_t new_capacity, cuda_stream_view stream) except +
void resize(size_t new_size, cuda_stream_view stream) except +
void shrink_to_fit(cuda_stream_view stream) except +
Expand Down
22 changes: 12 additions & 10 deletions python/rmm/_lib/device_buffer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ from cuda.ccudart cimport (
cudaStream_t,
)

from rmm._lib.memory_resource cimport get_current_device_resource
from rmm._lib.memory_resource cimport (
device_memory_resource,
get_current_device_resource,
)


# The DeviceMemoryResource attribute could be released prematurely
Expand Down Expand Up @@ -75,24 +78,23 @@ cdef class DeviceBuffer:
>>> db = rmm.DeviceBuffer(size=5)
"""
cdef const void* c_ptr
cdef device_memory_resource * mr_ptr
# Save a reference to the MR and stream used for allocation
self.mr = get_current_device_resource()
self.stream = stream

mr_ptr = self.mr.get_mr()
with nogil:
c_ptr = <const void*>ptr

if size == 0:
self.c_obj.reset(new device_buffer())
elif c_ptr == NULL:
self.c_obj.reset(new device_buffer(size, stream.view()))
if c_ptr == NULL or size == 0:
self.c_obj.reset(new device_buffer(size, stream.view(), mr_ptr))
else:
self.c_obj.reset(new device_buffer(c_ptr, size, stream.view()))
self.c_obj.reset(new device_buffer(c_ptr, size, stream.view(), mr_ptr))

if stream.c_is_default():
stream.c_synchronize()

# Save a reference to the MR and stream used for allocation
self.mr = get_current_device_resource()
self.stream = stream

def __len__(self):
return self.size

Expand Down
2 changes: 1 addition & 1 deletion python/rmm/_lib/memory_resource.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ cdef extern from "rmm/mr/device/device_memory_resource.hpp" \

cdef class DeviceMemoryResource:
cdef shared_ptr[device_memory_resource] c_obj
cdef device_memory_resource* get_mr(self)
cdef device_memory_resource* get_mr(self) noexcept nogil

cdef class UpstreamResourceAdaptor(DeviceMemoryResource):
cdef readonly DeviceMemoryResource upstream_mr
Expand Down
2 changes: 1 addition & 1 deletion python/rmm/_lib/memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \

cdef class DeviceMemoryResource:

cdef device_memory_resource* get_mr(self):
cdef device_memory_resource* get_mr(self) noexcept nogil:
"""Get the underlying C++ memory resource object."""
return self.c_obj.get()

Expand Down
Loading