diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 5eb4716b98..029ac46d66 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -63,6 +63,7 @@ decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr; // Graph decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr; +decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy = nullptr; // Linker decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr; @@ -952,6 +953,28 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) return GraphHandle(box, &box->resource); } +// ============================================================================ +// Graph Exec Handles +// ============================================================================ + +namespace { +struct GraphExecBox { + CUgraphExec resource; +}; +} // namespace + +GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec) { + auto box = std::shared_ptr( + new GraphExecBox{graph_exec}, + [](const GraphExecBox* b) { + GILReleaseGuard gil; + p_cuGraphExecDestroy(b->resource); + delete b; + } + ); + return GraphExecHandle(box, &box->resource); +} + namespace { struct GraphNodeBox { mutable CUgraphNode resource; diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index 2e6ebb6271..14bd2a0bc4 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -94,6 +94,7 @@ extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel; // Graph extern decltype(&cuGraphDestroy) p_cuGraphDestroy; +extern decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy; // Linker extern decltype(&cuLinkDestroy) p_cuLinkDestroy; @@ -148,6 +149,7 @@ using MemoryPoolHandle = std::shared_ptr; using LibraryHandle = std::shared_ptr; using KernelHandle = std::shared_ptr; using GraphHandle = std::shared_ptr; +using GraphExecHandle = std::shared_ptr; using GraphNodeHandle = std::shared_ptr; using GraphicsResourceHandle = std::shared_ptr; using NvrtcProgramHandle = std::shared_ptr; @@ -403,6 +405,14 @@ GraphHandle create_graph_handle(CUgraph graph); // but h_parent will be prevented from destruction while this handle exists. GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent); +// ============================================================================ +// Graph exec handle functions +// ============================================================================ + +// Wrap an externally-created CUgraphExec with RAII cleanup. +// When the last reference is released, cuGraphExecDestroy is called automatically. +GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec); + // ============================================================================ // Graph node handle functions // ============================================================================ @@ -529,6 +539,10 @@ inline CUgraph as_cu(const GraphHandle& h) noexcept { return h ? *h : nullptr; } +inline CUgraphExec as_cu(const GraphExecHandle& h) noexcept { + return h ? *h : nullptr; +} + inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept { return h ? *h : nullptr; } @@ -587,6 +601,10 @@ inline std::intptr_t as_intptr(const GraphHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } +inline std::intptr_t as_intptr(const GraphExecHandle& h) noexcept { + return reinterpret_cast(as_cu(h)); +} + inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } @@ -677,6 +695,10 @@ inline PyObject* as_py(const GraphHandle& h) noexcept { return detail::make_py("cuda.bindings.driver", "CUgraph", as_intptr(h)); } +inline PyObject* as_py(const GraphExecHandle& h) noexcept { + return detail::make_py("cuda.bindings.driver", "CUgraphExec", as_intptr(h)); +} + inline PyObject* as_py(const GraphNodeHandle& h) noexcept { if (!as_intptr(h)) { Py_RETURN_NONE; diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx index c0d7f09ee4..c4fba83006 100644 --- a/cuda_core/cuda/core/_device.pyx +++ b/cuda_core/cuda/core/_device.pyx @@ -1361,7 +1361,7 @@ class Device: self._check_context_initialized() handle_return(runtime.cudaDeviceSynchronize()) - def create_graph_builder(self) -> "GraphBuilder": + def create_graph_builder(self) -> GraphBuilder: """Create a new :obj:`~graph.GraphBuilder` object. Returns @@ -1373,7 +1373,7 @@ class Device: from cuda.core.graph._graph_builder import GraphBuilder self._check_context_initialized() - return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True) + return GraphBuilder._init(self.create_stream()) cdef inline int Device_ensure_cuda_initialized() except? -1: diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 0d7d20e574..a059465403 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -27,6 +27,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle ctypedef shared_ptr[const cydriver.CUgraph] GraphHandle + ctypedef shared_ptr[const cydriver.CUgraphExec] GraphExecHandle ctypedef shared_ptr[const cydriver.CUgraphNode] GraphNodeHandle ctypedef shared_ptr[const cydriver.CUgraphicsResource] GraphicsResourceHandle ctypedef shared_ptr[const cynvrtc.nvrtcProgram] NvrtcProgramHandle @@ -52,6 +53,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil cydriver.CUgraph as_cu(GraphHandle h) noexcept nogil + cydriver.CUgraphExec as_cu(GraphExecHandle h) noexcept nogil cydriver.CUgraphNode as_cu(GraphNodeHandle h) noexcept nogil cydriver.CUgraphicsResource as_cu(GraphicsResourceHandle h) noexcept nogil cynvrtc.nvrtcProgram as_cu(NvrtcProgramHandle h) noexcept nogil @@ -68,6 +70,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": intptr_t as_intptr(LibraryHandle h) noexcept nogil intptr_t as_intptr(KernelHandle h) noexcept nogil intptr_t as_intptr(GraphHandle h) noexcept nogil + intptr_t as_intptr(GraphExecHandle h) noexcept nogil intptr_t as_intptr(GraphNodeHandle h) noexcept nogil intptr_t as_intptr(GraphicsResourceHandle h) noexcept nogil intptr_t as_intptr(NvrtcProgramHandle h) noexcept nogil @@ -85,6 +88,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": object as_py(LibraryHandle h) object as_py(KernelHandle h) object as_py(GraphHandle h) + object as_py(GraphExecHandle h) object as_py(GraphNodeHandle h) object as_py(GraphicsResourceHandle h) object as_py(NvrtcProgramHandle h) @@ -183,6 +187,9 @@ cdef LibraryHandle get_kernel_library(const KernelHandle& h) noexcept nogil cdef GraphHandle create_graph_handle(cydriver.CUgraph graph) except+ nogil cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil +# Graph exec handles +cdef GraphExecHandle create_graph_exec_handle(cydriver.CUgraphExec graph_exec) except+ nogil + # Graph node handles cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index d30993cc5e..2291b1ec20 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -27,6 +27,7 @@ from ._resource_handles cimport ( LibraryHandle, KernelHandle, GraphHandle, + GraphExecHandle, GraphicsResourceHandle, NvrtcProgramHandle, NvvmProgramHandle, @@ -154,6 +155,10 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": GraphHandle create_graph_handle_ref "cuda_core::create_graph_handle_ref" ( cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil + # Graph exec handles + GraphExecHandle create_graph_exec_handle "cuda_core::create_graph_exec_handle" ( + cydriver.CUgraphExec graph_exec) except+ nogil + # Graph node handles GraphNodeHandle create_graph_node_handle "cuda_core::create_graph_node_handle" ( cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil @@ -265,6 +270,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # Graph void* p_cuGraphDestroy "reinterpret_cast(cuda_core::p_cuGraphDestroy)" + void* p_cuGraphExecDestroy "reinterpret_cast(cuda_core::p_cuGraphExecDestroy)" # Linker void* p_cuLinkDestroy "reinterpret_cast(cuda_core::p_cuLinkDestroy)" @@ -334,6 +340,7 @@ p_cuLibraryGetKernel = _get_driver_fn("cuLibraryGetKernel") # Graph p_cuGraphDestroy = _get_driver_fn("cuGraphDestroy") +p_cuGraphExecDestroy = _get_driver_fn("cuGraphExecDestroy") # Linker p_cuLinkDestroy = _get_driver_fn("cuLinkDestroy") diff --git a/cuda_core/cuda/core/_stream.pyx b/cuda_core/cuda/core/_stream.pyx index fdb617f032..a2bf0e025c 100644 --- a/cuda_core/cuda/core/_stream.pyx +++ b/cuda_core/cuda/core/_stream.pyx @@ -360,7 +360,7 @@ cdef class Stream: return Stream._init(obj=_stream_holder()) - def create_graph_builder(self) -> "GraphBuilder": + def create_graph_builder(self) -> GraphBuilder: """Create a new :obj:`~graph.GraphBuilder` object. The new graph builder will be associated with this stream. @@ -373,7 +373,7 @@ cdef class Stream: """ from cuda.core.graph._graph_builder import GraphBuilder - return GraphBuilder._init(stream=self, is_stream_owner=False) + return GraphBuilder._init(self) # c-only python objects, not public diff --git a/cuda_core/cuda/core/graph/_graph_builder.pxd b/cuda_core/cuda/core/graph/_graph_builder.pxd new file mode 100644 index 0000000000..660ebe8ec7 --- /dev/null +++ b/cuda_core/cuda/core/graph/_graph_builder.pxd @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.bindings cimport cydriver + +from cuda.core._resource_handles cimport GraphExecHandle, GraphHandle, StreamHandle +from cuda.core._stream cimport Stream + + +cdef class GraphBuilder: + cdef: + GraphHandle _h_graph + StreamHandle _h_stream + int _kind + int _state + Stream _stream # cached to avoid reconstruction from _h_stream handle + object __weakref__ + + +cdef class Graph: + cdef: + GraphExecHandle _h_graph_exec + object __weakref__ + + @staticmethod + cdef Graph _init(cydriver.CUgraphExec graph_exec) diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyx b/cuda_core/cuda/core/graph/_graph_builder.pyx index 526c95e04a..d9e777b2ca 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/graph/_graph_builder.pyx @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import weakref from dataclasses import dataclass from libc.stdint cimport intptr_t @@ -11,7 +10,10 @@ from cuda.bindings cimport cydriver from cuda.core.graph._graph_definition cimport GraphCondition from cuda.core.graph._utils cimport _attach_host_callback_to_graph -from cuda.core._resource_handles cimport as_cu +from cuda.core._resource_handles cimport ( + as_cu, as_py, + create_graph_exec_handle, create_graph_handle, create_graph_handle_ref, +) from cuda.core._stream cimport Stream from cuda.core._utils.cuda_utils cimport HANDLE_RETURN from cuda.core._utils.version cimport cy_binding_version, cy_driver_version @@ -147,7 +149,8 @@ class GraphCompleteOptions: use_node_priority: bool = False -def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph": +def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph: + cdef cydriver.CUgraphExec c_exec params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS() if options: flags = 0 @@ -162,7 +165,9 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY params.flags = flags - graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))) + py_exec = handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)) + c_exec = int(py_exec) + graph = Graph._init(c_exec) if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR: raise RuntimeError( "Instantiation failed for an unexpected reason which is described in the return value of the function." @@ -185,7 +190,40 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> return graph -class GraphBuilder: +# Distinguishes the three kinds of GraphBuilder, which differ in how they +# begin/end stream capture and whether they own the resulting CUgraph. +# Each kind progresses through _CaptureState as follows: +# +# PRIMARY: NOT_STARTED -> CAPTURING -> ENDED +# FORKED: CAPTURING (never transitions; joined and closed) +# CONDITIONAL_BODY: NOT_STARTED -> CAPTURING -> ENDED +# +cdef enum _BuilderKind: + # PRIMARY: The top-level builder created by Device or Stream. Owns the + # captured CUgraph via an owning GraphHandle. Progresses through all three + # capture states; responsible for ending capture if destroyed early. + PRIMARY = 0 + # FORKED: Created by split(). Captures on a private stream forked from the + # primary. Starts in CAPTURING state and never transitions; the user joins + # it back to the primary via join(), which closes the builder. Must NOT + # call cuStreamEndCapture (the driver requires all forked streams to be + # joined first). + FORKED = 1 + # CONDITIONAL_BODY: Created by if_then/if_else/switch/while_loop. Captures + # into a non-owned body graph via cuStreamBeginCaptureToGraph. The body + # graph's lifetime is tied to a parent graph. Progresses through all three + # capture states like PRIMARY. + CONDITIONAL_BODY = 2 + + +# Tracks the capture lifecycle of a GraphBuilder. +cdef enum _CaptureState: + CAPTURE_NOT_STARTED = 0 + CAPTURING = 1 + CAPTURE_ENDED = 2 + + +cdef class GraphBuilder: """A graph under construction by stream capture. A graph groups a set of CUDA kernels and other CUDA operations together and executes @@ -198,63 +236,48 @@ class GraphBuilder: """ - class _MembersNeededForFinalize: - __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream") - - def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required): - self.stream = stream_obj - self.is_stream_owner = is_stream_owner - self.graph = None - self.conditional_graph = conditional_graph - self.is_join_required = is_join_required - weakref.finalize(graph_builder_obj, self.close) - - def close(self): - if self.stream: - if not self.is_join_required: - capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0] - if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: - # Note how this condition only occures for the primary graph builder - # This is because calling cuStreamEndCapture streams that were split off of the primary - # would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED. - # Therefore, it is currently a requirement that users join all split graph builders - # before a graph builder can be clearly destroyed. - handle_return(driver.cuStreamEndCapture(self.stream.handle)) - if self.is_stream_owner: - self.stream.close() - self.stream = None - if self.graph: - handle_return(driver.cuGraphDestroy(self.graph)) - self.graph = None - self.conditional_graph = None - - __slots__ = ("__weakref__", "_building_ended", "_mnff") - def __init__(self): raise NotImplementedError( "directly creating a Graph object can be ambiguous. Please either " "call Device.create_graph_builder() or stream.create_graph_builder()" ) - @classmethod - def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False): - self = cls.__new__(cls) - self._mnff = GraphBuilder._MembersNeededForFinalize( - self, stream, is_stream_owner, conditional_graph, is_join_required - ) + def __dealloc__(self): + # Note: _stream could be set to None by cyclic-GC tp_clear before + # __dealloc__, but _h_stream is guaranteed to be valid. + if self._h_stream and self._state == CAPTURING and self._kind != FORKED: + with nogil: + cydriver.cuStreamEndCapture(as_cu(self._h_stream), NULL) - self._building_ended = False + @staticmethod + def _init(Stream stream): + cdef GraphBuilder self = GraphBuilder.__new__(GraphBuilder) + # _h_graph set by begin_building + self._h_stream = stream._h_stream + self._kind = PRIMARY + self._state = CAPTURE_NOT_STARTED + self._stream = stream return self + def close(self): + """Destroy the graph builder.""" + if self._h_stream and self._state == CAPTURING and self._kind != FORKED: + with nogil: + HANDLE_RETURN(cydriver.cuStreamEndCapture(as_cu(self._h_stream), NULL)) + self._h_graph.reset() + self._h_stream.reset() + self._state = CAPTURE_ENDED + self._stream = None + @property def stream(self) -> Stream: """Returns the stream associated with the graph builder.""" - return self._mnff.stream + return self._stream @property def is_join_required(self) -> bool: """Returns True if this graph builder must be joined before building is ended.""" - return self._mnff.is_join_required + return self._kind == FORKED def begin_building(self, mode="relaxed") -> GraphBuilder: """Begins the building process. @@ -272,64 +295,68 @@ class GraphBuilder: Default set to use relaxed. """ - if self._building_ended: - raise RuntimeError("Cannot resume building after building has ended.") - if mode not in ("global", "thread_local", "relaxed"): - raise ValueError(f"Unsupported build mode: {mode}") + if self._state != CAPTURE_NOT_STARTED: + if self._state == CAPTURING: + raise RuntimeError("Graph builder is already building.") + else: + raise RuntimeError("Cannot resume building after building has ended.") + cdef cydriver.CUstreamCaptureMode c_mode if mode == "global": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_GLOBAL elif mode == "thread_local": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL elif mode == "relaxed": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED + c_mode = cydriver.CU_STREAM_CAPTURE_MODE_RELAXED else: raise ValueError(f"Unsupported build mode: {mode}") - if self._mnff.conditional_graph: - handle_return( - driver.cuStreamBeginCaptureToGraph( - self._mnff.stream.handle, - self._mnff.conditional_graph, - None, # dependencies - None, # dependencyData - 0, # numDependencies - capture_mode, - ) - ) + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + cdef cydriver.CUgraph c_graph + if self._kind == CONDITIONAL_BODY: + c_graph = as_cu(self._h_graph) + with nogil: + HANDLE_RETURN(cydriver.cuStreamBeginCaptureToGraph( + c_stream, c_graph, NULL, NULL, 0, c_mode)) else: - handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode)) + with nogil: + HANDLE_RETURN(cydriver.cuStreamBeginCapture(c_stream, c_mode)) + _get_capture_info(c_stream, NULL, &c_graph) + self._h_graph = create_graph_handle(c_graph) + self._state = CAPTURING return self @property def is_building(self) -> bool: """Returns True if the graph builder is currently building.""" - capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0] - if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + cdef cydriver.CUstreamCaptureStatus status + with nogil: + _get_capture_info(c_stream, &status, NULL) + if status == cydriver.CU_STREAM_CAPTURE_STATUS_NONE: return False - elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + elif status == cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: return True - elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED: + elif status == cydriver.CU_STREAM_CAPTURE_STATUS_INVALIDATED: raise RuntimeError( "Build process encountered an error and has been invalidated. Build process must now be ended." ) else: - raise NotImplementedError(f"Unsupported capture status type received: {capture_status}") + raise NotImplementedError(f"Unsupported capture status type received: {status}") def end_building(self) -> GraphBuilder: """Ends the building process.""" if not self.is_building: raise RuntimeError("Graph builder is not building.") - if self._mnff.conditional_graph: - self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) - else: - self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) + cdef cydriver.CUstream c_stream = as_cu(self._h_stream) + with nogil: + HANDLE_RETURN(cydriver.cuStreamEndCapture(c_stream, NULL)) # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to # resume the build process after the first call to end_building() - self._building_ended = True + self._state = CAPTURE_ENDED return self - def complete(self, options: GraphCompleteOptions | None = None) -> "Graph": + def complete(self, options: GraphCompleteOptions | None = None) -> Graph: """Completes the graph builder and returns the built :obj:`~graph.Graph` object. Parameters @@ -343,10 +370,10 @@ class GraphBuilder: The newly built graph. """ - if not self._building_ended: + if self._state != CAPTURE_ENDED: raise RuntimeError("Graph has not finished building.") - return _instantiate_graph(self._mnff.graph, options) + return _instantiate_graph(as_py(self._h_graph), options) def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None): """Generates a DOT debug file for the graph builder. @@ -359,10 +386,14 @@ class GraphBuilder: Customizable dataclass for the debug print options. """ - if not self._building_ended: + if self._state != CAPTURE_ENDED: raise RuntimeError("Graph has not finished building.") - flags = options._to_flags() if options else 0 - handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags)) + cdef unsigned int c_flags = options._to_flags() if options else 0 + cdef cydriver.CUgraph c_graph = as_cu(self._h_graph) + cdef bytes b_path = path.encode() if isinstance(path, str) else path + cdef const char* c_path = b_path + with nogil: + HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(c_graph, c_path, c_flags)) def split(self, count: int) -> tuple[GraphBuilder, ...]: """Splits the original graph builder into multiple graph builders. @@ -385,14 +416,12 @@ class GraphBuilder: if count < 2: raise ValueError(f"Invalid split count: expecting >= 2, got {count}") - event = self._mnff.stream.record() + event = self._stream.record() result = [self] for i in range(count - 1): - stream = self._mnff.stream.device.create_stream() + stream = self._stream.device.create_stream() stream.wait(event) - result.append( - GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) - ) + result.append(_init_forked(stream)) event.close() return tuple(result) @@ -440,7 +469,7 @@ class GraphBuilder: return self.stream.__cuda_stream__() def _get_conditional_context(self) -> driver.CUcontext: - return self._mnff.stream.context.handle + return self._stream.context.handle def create_condition(self, default_value=None) -> GraphCondition: """Create a condition variable for use with conditional nodes. @@ -471,7 +500,7 @@ class GraphBuilder: default_value = 0 flags = 0 - status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)) + status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._stream.handle)) if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: raise RuntimeError("Cannot create a condition when graph is not being built") @@ -480,42 +509,6 @@ class GraphBuilder: ) return GraphCondition._from_handle(int(raw_handle)) - def _cond_with_params(self, node_params) -> tuple: - # Get current capture info to ensure we're in a valid state - status, _, graph, *deps_info, num_dependencies = handle_return( - driver.cuStreamGetCaptureInfo(self._mnff.stream.handle) - ) - if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: - raise RuntimeError("Cannot add conditional node when not actively capturing") - - # Add the conditional node to the graph - deps_info_update = [ - [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] - ] + [None] * (len(deps_info) - 1) - - # Update the stream's capture dependencies - handle_return( - driver.cuStreamUpdateCaptureDependencies( - self._mnff.stream.handle, - *deps_info_update, # dependencies, edgeData - 1, # numDependencies - driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, - ) - ) - - # Create new graph builders for each condition - return tuple( - [ - GraphBuilder._init( - stream=self._mnff.stream.device.create_stream(), - is_stream_owner=True, - conditional_graph=node_params.conditional.phGraph_out[i], - is_join_required=False, - ) - for i in range(node_params.conditional.size) - ] - ) - def if_then(self, condition: GraphCondition) -> GraphBuilder: """Adds an if condition branch and returns a new graph builder for it. @@ -550,7 +543,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF node_params.conditional.size = 1 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params)[0] + return _cond_with_params(self, node_params)[0] def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]: """Adds an if-else condition branch and returns new graph builders for both branches. @@ -586,7 +579,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF node_params.conditional.size = 2 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params) + return _cond_with_params(self, node_params) def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]: """Adds a switch condition branch and returns new graph builders for all cases. @@ -625,7 +618,7 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH node_params.conditional.size = count node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params) + return _cond_with_params(self, node_params) def while_loop(self, condition: GraphCondition) -> GraphBuilder: """Adds a while loop and returns a new graph builder for it. @@ -661,18 +654,9 @@ class GraphBuilder: node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE node_params.conditional.size = 1 node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params)[0] - - def close(self): - """Destroy the graph builder. - - Closes the associated stream if we own it. Borrowed stream - object will instead have their references released. - - """ - self._mnff.close() + return _cond_with_params(self, node_params)[0] - def embed(self, child: GraphBuilder): + def embed(self, GraphBuilder child): """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node. Parameters @@ -680,13 +664,13 @@ class GraphBuilder: child : :obj:`~graph.GraphBuilder` The child graph builder. Must have finished building. """ - if not child._building_ended: + if child._state != CAPTURE_ENDED: raise ValueError("Child graph has not finished building.") if not self.is_building: raise ValueError("Parent graph is not being built.") - stream_handle = self._mnff.stream.handle + stream_handle = self._stream.handle _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return( driver.cuStreamGetCaptureInfo(stream_handle) ) @@ -698,7 +682,7 @@ class GraphBuilder: [ handle_return( driver.cuGraphAddChildGraphNode( - graph_out, *deps_info_trimmed, num_dependencies_out, child._mnff.graph + graph_out, *deps_info_trimmed, num_dependencies_out, as_py(child._h_graph) ) ) ] @@ -740,18 +724,13 @@ class GraphBuilder: pointer (caller manages lifetime). If bytes-like, the data is copied and its lifetime is tied to the graph. """ - cdef Stream stream = self._mnff.stream + cdef Stream stream = self._stream cdef cydriver.CUstream c_stream = as_cu(stream._h_stream) cdef cydriver.CUstreamCaptureStatus capture_status cdef cydriver.CUgraph c_graph = NULL with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( - c_stream, &capture_status, NULL, &c_graph, NULL, NULL, NULL)) - ELSE: - HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( - c_stream, &capture_status, NULL, &c_graph, NULL, NULL)) + _get_capture_info(c_stream, &capture_status, &c_graph) if capture_status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: raise RuntimeError("Cannot add callback when graph is not being built") @@ -764,7 +743,69 @@ class GraphBuilder: HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, c_fn, c_user_data)) -class Graph: +cdef inline GraphBuilder _init_forked(Stream stream): + cdef GraphBuilder gb = GraphBuilder.__new__(GraphBuilder) + # _h_graph not used for FORKED builders. Captures to primary graph. + gb._h_stream = stream._h_stream + gb._kind = FORKED + gb._state = CAPTURING + gb._stream = stream + return gb + + +cdef inline GraphBuilder _init_conditional(Stream stream, cydriver.CUgraph cond_graph, GraphBuilder parent): + cdef GraphBuilder gb = GraphBuilder.__new__(GraphBuilder) + gb._h_graph = create_graph_handle_ref(cond_graph, parent._h_graph) + gb._h_stream = stream._h_stream + gb._kind = CONDITIONAL_BODY + gb._state = CAPTURE_NOT_STARTED + gb._stream = stream + return gb + + +cdef inline int _get_capture_info( + cydriver.CUstream stream, + cydriver.CUstreamCaptureStatus* status, + cydriver.CUgraph* graph) except?-1 nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( + stream, status, NULL, graph, NULL, NULL, NULL)) + ELSE: + return HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( + stream, status, NULL, graph, NULL, NULL)) + + +cdef inline tuple _cond_with_params(GraphBuilder gb, node_params): + status, _, graph, *deps_info, num_dependencies = handle_return( + driver.cuStreamGetCaptureInfo(gb._stream.handle) + ) + if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + raise RuntimeError("Cannot add conditional node when not actively capturing") + + deps_info_update = [ + [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] + ] + [None] * (len(deps_info) - 1) + + handle_return( + driver.cuStreamUpdateCaptureDependencies( + gb._stream.handle, + *deps_info_update, # dependencies, edgeData + 1, # numDependencies + driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, + ) + ) + + return tuple( + _init_conditional( + gb._stream.device.create_stream(), + int(node_params.conditional.phGraph_out[i]), + gb, + ) + for i in range(node_params.conditional.size) + ) + + +cdef class Graph: """An executable graph. A graph groups a set of CUDA kernels and other CUDA operations together and executes @@ -775,32 +816,18 @@ class Graph: """ - class _MembersNeededForFinalize: - __slots__ = "graph" - - def __init__(self, graph_obj, graph): - self.graph = graph - weakref.finalize(graph_obj, self.close) - - def close(self): - if self.graph: - handle_return(driver.cuGraphExecDestroy(self.graph)) - self.graph = None - - __slots__ = ("__weakref__", "_mnff") - def __init__(self): raise RuntimeError("directly constructing a Graph instance is not supported") - @classmethod - def _init(cls, graph): - self = cls.__new__(cls) - self._mnff = Graph._MembersNeededForFinalize(self, graph) + @staticmethod + cdef Graph _init(cydriver.CUgraphExec graph_exec): + cdef Graph self = Graph.__new__(Graph) + self._h_graph_exec = create_graph_exec_handle(graph_exec) return self def close(self): """Destroy the graph.""" - self._mnff.close() + self._h_graph_exec.reset() @property def handle(self) -> driver.CUgraphExec: @@ -812,7 +839,7 @@ class Graph: handle, call ``int()`` on the returned object. """ - return self._mnff.graph + return as_py(self._h_graph_exec) def update(self, source: "GraphBuilder | GraphDefinition") -> None: """Update the graph using a new graph definition. @@ -829,12 +856,12 @@ class Graph: from cuda.core.graph import GraphDefinition cdef cydriver.CUgraph cu_graph - cdef cydriver.CUgraphExec cu_exec = int(self._mnff.graph) + cdef cydriver.CUgraphExec cu_exec = as_cu(self._h_graph_exec) if isinstance(source, GraphBuilder): - if not source._building_ended: + if (source)._state != CAPTURE_ENDED: raise ValueError("Graph has not finished building.") - cu_graph = int(source._mnff.graph) + cu_graph = as_cu((source)._h_graph) elif isinstance(source, GraphDefinition): cu_graph = int(source.handle) else: @@ -860,7 +887,10 @@ class Graph: The stream in which to upload the graph """ - handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle)) + cdef cydriver.CUgraphExec c_exec = as_cu(self._h_graph_exec) + cdef cydriver.CUstream c_stream = int(stream.handle) + with nogil: + HANDLE_RETURN(cydriver.cuGraphUpload(c_exec, c_stream)) def launch(self, stream: Stream): """Launches the graph in a stream. @@ -871,4 +901,7 @@ class Graph: The stream in which to launch the graph """ - handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle)) + cdef cydriver.CUgraphExec c_exec = as_cu(self._h_graph_exec) + cdef cydriver.CUstream c_stream = int(stream.handle) + with nogil: + HANDLE_RETURN(cydriver.cuGraphLaunch(c_exec, c_stream))