Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_context.pxd
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

Expand Down
1 change: 0 additions & 1 deletion cuda_core/cuda/core/_memory/_device_memory_resource.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ from cuda.core._memory._ipc cimport IPCDataForMR
cdef class DeviceMemoryResource(_MemPool):
cdef:
int _dev_id
object _peer_accessible_by


cpdef DMR_mempool_get_access(DeviceMemoryResource, int)
107 changes: 10 additions & 97 deletions cuda_core/cuda/core/_memory/_device_memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ from cuda.core._utils.cuda_utils cimport (
check_or_create_options,
HANDLE_RETURN,
)
from cpython.mem cimport PyMem_Malloc, PyMem_Free

from dataclasses import dataclass
import multiprocessing
import platform # no-cython-lint
import uuid

from cuda.core._memory._peer_access_utils import plan_peer_access_update
from cuda.core._memory._peer_access_utils import PeerAccessibleBySetProxy, replace_peer_accessible_by
from cuda.core._utils.cuda_utils import check_multiprocessing_start_method

__all__ = ['DeviceMemoryResource', 'DeviceMemoryResourceOptions']
Expand Down Expand Up @@ -131,7 +129,6 @@ cdef class DeviceMemoryResource(_MemPool):

def __cinit__(self, *args, **kwargs):
self._dev_id = cydriver.CU_DEVICE_INVALID
self._peer_accessible_by = None

def __init__(self, device_id: Device | int, options=None):
_DMR_init(self, device_id, options)
Expand Down Expand Up @@ -191,7 +188,6 @@ cdef class DeviceMemoryResource(_MemPool):
_ipc.MP_from_allocation_handle(cls, alloc_handle))
from .._device import Device
mr._dev_id = Device(device_id).device_id
mr._peer_accessible_by = ()
return mr

@property
Expand All @@ -217,30 +213,23 @@ cdef class DeviceMemoryResource(_MemPool):
pool. Access can be modified at any time and affects all allocations
from this memory pool.

Returns a tuple of sorted device IDs that currently have peer access to
allocations from this memory pool.

When setting, accepts a sequence of :obj:`~_device.Device` objects or device IDs.
Setting to an empty sequence revokes all peer access.

For non-owned pools (the default or current device pool), the state
is always queried from the driver to reflect changes made by other
wrappers or direct driver calls.
Returns a set-like proxy of :obj:`~_device.Device` objects that manages
peer access. Inputs are accepted as either :obj:`~_device.Device`
objects or device-ordinal :class:`int` values.

Examples
--------
>>> dmr = DeviceMemoryResource(0)
>>> dmr.peer_accessible_by = [1] # Grant access to device 1
>>> assert dmr.peer_accessible_by == (1,)
>>> dmr.peer_accessible_by = [] # Revoke access
>>> dmr.peer_accessible_by = {1} # grant access to device 1
>>> assert 1 in dmr.peer_accessible_by
>>> dmr.peer_accessible_by.add(2) # update access to include device 2
>>> dmr.peer_accessible_by = [] # revoke peer access
"""
if not self._mempool_owned:
_DMR_query_peer_access(self)
return self._peer_accessible_by
return PeerAccessibleBySetProxy(self)

@peer_accessible_by.setter
def peer_accessible_by(self, devices):
_DMR_set_peer_accessible_by(self, devices)
replace_peer_accessible_by(self, devices)

@property
def is_device_accessible(self) -> bool:
Expand All @@ -253,81 +242,6 @@ cdef class DeviceMemoryResource(_MemPool):
return False


cdef inline _DMR_query_peer_access(DeviceMemoryResource self):
    """Refresh ``self._peer_accessible_by`` from the driver.

    Asks the driver for the device count, then checks this pool's access
    flags for every device other than ``self._dev_id``. The IDs of devices
    reporting read-write access are stored on ``self`` as a sorted tuple.
    """
    cdef int device_count
    cdef cydriver.CUmemAccess_flags access
    cdef cydriver.CUmemLocation loc
    cdef list granted = []

    with nogil:
        HANDLE_RETURN(cydriver.cuDeviceGetCount(&device_count))

    # The location type is the same for every probe; only the id changes.
    loc.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    for peer_id in range(device_count):
        # The owning device is not reported as a peer of itself.
        if peer_id != self._dev_id:
            loc.id = peer_id
            with nogil:
                HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&access, as_cu(self._h_pool), &loc))
            if access == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE:
                granted.append(peer_id)

    self._peer_accessible_by = tuple(sorted(granted))


# Apply a requested peer-access set to this resource's memory pool.
#
# `devices` is handed to plan_peer_access_update together with the owning
# device id and the currently recorded peers; the returned plan supplies the
# ids to grant (`to_add`), the ids to revoke (`to_remove`) and the final set
# (`target_ids`). One CUmemAccessDesc is filled per changed device and all of
# them are submitted in a single cuMemPoolSetAccess call.
#
# Raises MemoryError if the descriptor array cannot be allocated.
#
# NOTE(review): indentation is lost in this view, so whether the try/finally
# and the final assignment are nested under `if count > 0:` cannot be
# determined from here - confirm against the original file.
cdef inline _DMR_set_peer_accessible_by(DeviceMemoryResource self, devices):
from .._device import Device

this_dev = Device(self._dev_id)
# Normalises each requested entry to a device id by constructing a Device from it.
cdef object resolve_device_id = lambda dev: Device(dev).device_id
cdef object plan
cdef tuple target_ids
cdef tuple to_add
cdef tuple to_rm
# For a pool this object does not own, re-read the peer state from the
# driver before planning the update.
if not self._mempool_owned:
_DMR_query_peer_access(self)
plan = plan_peer_access_update(
owner_device_id=self._dev_id,
current_peer_ids=self._peer_accessible_by,
requested_devices=devices,
resolve_device_id=resolve_device_id,
can_access_peer=this_dev.can_access_peer,
)
target_ids = plan.target_ids
to_add = plan.to_add
to_rm = plan.to_remove
# One descriptor per device whose access changes (grants first, then revocations).
cdef size_t count = len(to_add) + len(to_rm)
cdef cydriver.CUmemAccessDesc* access_desc = NULL
cdef size_t i = 0

if count > 0:
access_desc = <cydriver.CUmemAccessDesc*>PyMem_Malloc(count * sizeof(cydriver.CUmemAccessDesc))
if access_desc == NULL:
raise MemoryError("Failed to allocate memory for access descriptors")

try:
# Devices gaining access get read-write protection.
for dev_id in to_add:
access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
access_desc[i].location.id = dev_id
i += 1

# Devices losing access get PROT_NONE.
for dev_id in to_rm:
access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE
access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
access_desc[i].location.id = dev_id
i += 1

# The driver call runs with the GIL released.
with nogil:
HANDLE_RETURN(cydriver.cuMemPoolSetAccess(as_cu(self._h_pool), access_desc, count))
finally:
# Free the descriptor array whether or not the driver call succeeded.
if access_desc != NULL:
PyMem_Free(access_desc)

# Record the planned final peer set on the resource.
self._peer_accessible_by = tuple(target_ids)


cdef inline _DMR_init(DeviceMemoryResource self, device_id, options):
from .._device import Device
cdef int dev_id = Device(device_id).device_id
Expand All @@ -351,7 +265,6 @@ cdef inline _DMR_init(DeviceMemoryResource self, device_id, options):
self._mempool_owned = False
MP_raise_release_threshold(self)
else:
self._peer_accessible_by = ()
MP_init_create_pool(
self,
cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE,
Expand Down
59 changes: 0 additions & 59 deletions cuda_core/cuda/core/_memory/_peer_access_utils.py

This file was deleted.

Loading
Loading