Add CUDA process checkpointing helpers #1983
Open
kkraus14 wants to merge 9 commits into NVIDIA:main from kkraus14:kk/issue-1343-cuda-checkpointing
Commits
d8a2031 Add CUDA process checkpointing helpers (kkraus14)
4992921 Address checkpoint review feedback (kkraus14)
5afd43c Rewrite checkpoint tests: replace mocks with real GPU tests (leofang)
f67a5e6 Accept Device.uuid strings in gpu_mapping; use cuda.core APIs in tests (leofang)
245e7a4 Apply pre-commit formatting fixes (leofang)
7c7f0e5 Restore original device in self_process fixture teardown (leofang)
8192df6 Address checkpoint review follow-ups (kkraus14)
fbb8037 Skip checkpoint lifecycle/migration tests in CI (leofang)
8f798f4 Isolate checkpoint lifecycle tests (kkraus14)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,248 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import ctypes as _ctypes | ||
| from collections.abc import Mapping as _Mapping | ||
| from typing import Any as _Any | ||
|
|
||
| from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return | ||
| from cuda.core._utils.version import binding_version as _binding_version | ||
| from cuda.core._utils.version import driver_version as _driver_version | ||
| from cuda.core.typing import ProcessStateT as _ProcessStateT | ||
|
|
||
| try: | ||
| from cuda.bindings import driver as _driver | ||
| except ImportError: | ||
| from cuda import cuda as _driver | ||
|
|
||
|
|
||
|
|
||
| _PROCESS_STATE_NAME_ATTRS: tuple[tuple[str, _ProcessStateT], ...] = ( | ||
| ("CU_PROCESS_STATE_RUNNING", "running"), | ||
| ("CU_PROCESS_STATE_LOCKED", "locked"), | ||
| ("CU_PROCESS_STATE_CHECKPOINTED", "checkpointed"), | ||
| ("CU_PROCESS_STATE_FAILED", "failed"), | ||
| ) | ||
|
|
||
| _REQUIRED_BINDING_ATTRS = ( | ||
| "cuCheckpointProcessCheckpoint", | ||
| "cuCheckpointProcessGetRestoreThreadId", | ||
| "cuCheckpointProcessGetState", | ||
| "cuCheckpointProcessLock", | ||
| "cuCheckpointProcessRestore", | ||
| "cuCheckpointProcessUnlock", | ||
| "CUcheckpointGpuPair", | ||
| "CUcheckpointLockArgs", | ||
| "CUprocessState", | ||
| "CUcheckpointRestoreArgs", | ||
| ) | ||
| _REQUIRED_DRIVER_VERSION = (12, 8, 0) | ||
| _driver_capability_checked = False | ||
|
|
||
|
|
||
| class Process: | ||
| """ | ||
| CUDA process that can be locked, checkpointed, restored, and unlocked. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| pid : int | ||
| Process ID of the CUDA process. | ||
| """ | ||
|
|
||
| __slots__ = ("pid",) | ||
|
|
||
| def __init__(self, pid: int): | ||
| self.pid = _check_pid(pid) | ||
|
Member (comment on lines +53 to +56): nit: move to `@property def pid(self): return self._pid` to make it read-only. |
||
|
|
||
| @property | ||
| def state(self) -> _ProcessStateT: | ||
| """ | ||
| CUDA checkpoint state for this process. | ||
| """ | ||
| driver = _get_driver() | ||
| state = _call_driver(driver, driver.cuCheckpointProcessGetState, self.pid) | ||
| state_names = _get_process_state_names(driver) | ||
| try: | ||
| return state_names[state] | ||
| except KeyError as e: | ||
| state_value = int(state) | ||
| raise RuntimeError(f"Unknown CUDA checkpoint process state: {state_value}") from e | ||
|
|
||
| @property | ||
| def restore_thread_id(self) -> int: | ||
| """ | ||
| CUDA restore thread ID for this process. | ||
| """ | ||
| driver = _get_driver() | ||
| return _call_driver(driver, driver.cuCheckpointProcessGetRestoreThreadId, self.pid) | ||
|
|
||
| def lock(self, timeout_ms: int = 0) -> None: | ||
|
|
||
| """ | ||
| Lock this process, blocking further CUDA API calls. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| timeout_ms : int, optional | ||
| Timeout in milliseconds. A value of 0 indicates no timeout. | ||
| """ | ||
| driver = _get_driver() | ||
| args = driver.CUcheckpointLockArgs() | ||
| args.timeoutMs = _check_timeout_ms(timeout_ms) | ||
| _call_driver(driver, driver.cuCheckpointProcessLock, self.pid, args) | ||
|
|
||
| def checkpoint(self) -> None: | ||
| """ | ||
| Checkpoint the GPU memory contents of this locked process. | ||
| """ | ||
| driver = _get_driver() | ||
| _call_driver(driver, driver.cuCheckpointProcessCheckpoint, self.pid, None) | ||
|
|
||
| def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None: | ||
| """ | ||
| Restore this checkpointed process. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| gpu_mapping : mapping, optional | ||
| GPU UUID remapping from each checkpointed GPU UUID to the GPU UUID | ||
| to restore onto. For migration workflows, provide mappings for | ||
| every CUDA-visible GPU. | ||
| """ | ||
| driver = _get_driver() | ||
| args = _make_restore_args(driver, gpu_mapping) | ||
| _call_driver(driver, driver.cuCheckpointProcessRestore, self.pid, args) | ||
|
|
||
| def unlock(self) -> None: | ||
| """ | ||
| Unlock this locked process so it can resume CUDA API calls. | ||
| """ | ||
| driver = _get_driver() | ||
| _call_driver(driver, driver.cuCheckpointProcessUnlock, self.pid, None) | ||
|
|
||
|
|
||
| def _get_driver(): | ||
| global _driver_capability_checked | ||
| if _driver_capability_checked: | ||
| return _driver | ||
|
|
||
| binding_ver = _binding_version() | ||
| if not _binding_version_supports_checkpoint(binding_ver): | ||
| raise RuntimeError( | ||
| "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. " | ||
| f"Found cuda.bindings {'.'.join(str(part) for part in binding_ver[:3])}." | ||
| ) | ||
|
|
||
| missing = [name for name in _REQUIRED_BINDING_ATTRS if not hasattr(_driver, name)] | ||
| if missing: | ||
| raise RuntimeError( | ||
| f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(missing)}" | ||
| ) | ||
|
|
||
| driver_ver = _driver_version() | ||
| if driver_ver < _REQUIRED_DRIVER_VERSION: | ||
| raise RuntimeError( | ||
| "CUDA checkpointing is not supported by the installed NVIDIA driver. " | ||
| "Upgrade to a driver version with CUDA checkpoint API support." | ||
| ) | ||
|
|
||
| _driver_capability_checked = True | ||
| return _driver | ||
|
|
||
|
|
||
|
|
||
| def _binding_version_supports_checkpoint(version) -> bool: | ||
| major, minor, patch = version[:3] | ||
| return (major == 12 and (minor, patch) >= (8, 0)) or (major == 13 and (minor, patch) >= (0, 2)) or major > 13 | ||
|
|
||
|
|
||
| def _get_process_state_names(driver) -> dict[_Any, _ProcessStateT]: | ||
| return {getattr(driver.CUprocessState, attr): state_name for attr, state_name in _PROCESS_STATE_NAME_ATTRS} | ||
|
|
||
|
|
||
| def _call_driver(driver, func, *args): | ||
| try: | ||
| result = func(*args) | ||
| except RuntimeError as e: | ||
| if "cuCheckpointProcess" in str(e) and "not found" in str(e): | ||
| raise RuntimeError( | ||
| "CUDA checkpointing is not supported by the installed NVIDIA driver. " | ||
| "Upgrade to a driver version with CUDA checkpoint API support." | ||
| ) from e | ||
| raise | ||
| return _handle_return(driver, result) | ||
|
|
||
|
|
||
| def _handle_return(driver, result): | ||
|
Member: Q: Can this check be consolidated with the above one ( |
||
| err = result[0] | ||
| not_supported_errors = ( | ||
| getattr(driver.CUresult, "CUDA_ERROR_NOT_FOUND", None), | ||
| getattr(driver.CUresult, "CUDA_ERROR_NOT_SUPPORTED", None), | ||
| ) | ||
| if err in not_supported_errors: | ||
| raise RuntimeError( | ||
| "CUDA checkpointing is not supported by the installed NVIDIA driver. " | ||
| "Upgrade to a driver version with CUDA checkpoint API support." | ||
| ) | ||
|
|
||
| return _handle_cuda_return(result) | ||
|
|
||
|
|
||
| def _check_pid(pid: int) -> int: | ||
| if isinstance(pid, bool) or not isinstance(pid, int): | ||
| raise TypeError("pid must be an int") | ||
| if pid <= 0: | ||
| raise ValueError("pid must be a positive int") | ||
| return pid | ||
|
|
||
|
|
||
| def _check_timeout_ms(timeout_ms: int) -> int: | ||
| if isinstance(timeout_ms, bool) or not isinstance(timeout_ms, int): | ||
| raise TypeError("timeout_ms must be an int") | ||
| if timeout_ms < 0: | ||
| raise ValueError("timeout_ms must be >= 0") | ||
| return timeout_ms | ||
|
|
||
|
|
||
| def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None): | ||
| if gpu_mapping is None: | ||
| return None | ||
| if not isinstance(gpu_mapping, _Mapping): | ||
| raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID") | ||
|
|
||
| pairs = [] | ||
| for old_uuid, new_uuid in gpu_mapping.items(): | ||
| pair = driver.CUcheckpointGpuPair() | ||
| buffers = [] | ||
| pair.oldUuid = _as_cuuuid(driver, old_uuid, buffers) | ||
| pair.newUuid = _as_cuuuid(driver, new_uuid, buffers) | ||
| pairs.append(pair) | ||
|
|
||
| if not pairs: | ||
| return None | ||
|
|
||
| args = driver.CUcheckpointRestoreArgs() | ||
| args.gpuPairs = pairs | ||
| args.gpuPairsCount = len(pairs) | ||
|
|
||
| return args | ||
|
|
||
|
|
||
| def _as_cuuuid(driver, value, buffers): | ||
| """Convert *value* to a ``CUuuid``. | ||
|
|
||
| Accepts a ``CUuuid`` instance (returned as-is) or a UUID string in | ||
| the ``"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`` format returned by | ||
| :attr:`Device.uuid`. | ||
| """ | ||
| if isinstance(value, str): | ||
| raw = bytes.fromhex(value.replace("-", "")) | ||
| if len(raw) != 16: | ||
| raise ValueError(f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}") | ||
| buf = _ctypes.create_string_buffer(raw, 16) | ||
| buffers.append(buf) | ||
| return driver.CUuuid(_ctypes.addressof(buf)) | ||
| return value | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "Process", | ||
| ] | ||
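The input validation in this diff can be exercised without a GPU or `cuda.bindings`. The block below is a minimal sketch of those pure-Python pieces: the PID check, the string branch of `_as_cuuuid` (stopping at the raw 16 bytes rather than building a `CUuuid`), and the binding-version predicate. The names `check_pid`, `uuid_string_to_bytes`, `binding_version_supports_checkpoint`, and `ReadOnlyPidProcess` are hypothetical stand-ins, not part of the PR. The class shows one possible reading of the reviewer's read-only `pid` nit and is an assumption about how it might be applied, not the merged code.

```python
"""GPU-free sketch of the input validation in this diff (illustrative names only)."""


def check_pid(pid):
    # bool is a subclass of int, so it is rejected explicitly, as in the diff.
    if isinstance(pid, bool) or not isinstance(pid, int):
        raise TypeError("pid must be an int")
    if pid <= 0:
        raise ValueError("pid must be a positive int")
    return pid


def uuid_string_to_bytes(value):
    # Mirrors the string branch of _as_cuuuid: drop hyphens, decode the hex,
    # and require exactly 16 bytes. Constructing the CUuuid itself is omitted.
    raw = bytes.fromhex(value.replace("-", ""))
    if len(raw) != 16:
        raise ValueError(
            f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}"
        )
    return raw


def binding_version_supports_checkpoint(version):
    # Same predicate as the diff: 12.8.0 or later on the 12.x line,
    # 13.0.2 or later on the 13.x line, or any later major version.
    major, minor, patch = version[:3]
    return (
        (major == 12 and (minor, patch) >= (8, 0))
        or (major == 13 and (minor, patch) >= (0, 2))
        or major > 13
    )


class ReadOnlyPidProcess:
    # Hypothetical variant following the review nit: keep the value in _pid
    # and expose it through a property with no setter.
    __slots__ = ("_pid",)

    def __init__(self, pid):
        self._pid = check_pid(pid)

    @property
    def pid(self):
        return self._pid
```

With this shape, assigning to `pid` after construction raises `AttributeError`, while the validation behavior stays the same as in the diff.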