From 1d634c715904e838e926cc0b0d451fb13c78a789 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 27 Apr 2026 19:15:23 -0700 Subject: [PATCH 1/7] Implement SpecializationConstant class --- dpctl/_backend.pxd | 4 + dpctl/_sycl_platform.pyx | 2 +- dpctl/program/__init__.py | 2 + dpctl/program/_program.pyx | 170 ++++++++++++++++++ .../dpctl_sycl_kernel_bundle_interface.h | 7 + 5 files changed, 184 insertions(+), 1 deletion(-) diff --git a/dpctl/_backend.pxd b/dpctl/_backend.pxd index 93d9b5ef97..d2f1a0decb 100644 --- a/dpctl/_backend.pxd +++ b/dpctl/_backend.pxd @@ -431,6 +431,10 @@ cdef extern from "syclinterface/dpctl_sycl_context_interface.h": cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h": + ctypedef struct _spec_const "DPCTLSpecConst": + uint32_t id + size_t size + const void *value cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_CreateFromSpirv( const DPCTLSyclContextRef Ctx, const DPCTLSyclDeviceRef Dev, diff --git a/dpctl/_sycl_platform.pyx b/dpctl/_sycl_platform.pyx index 41eff7b5d3..ba78226e50 100644 --- a/dpctl/_sycl_platform.pyx +++ b/dpctl/_sycl_platform.pyx @@ -236,7 +236,7 @@ cdef class SyclPlatform(_SyclPlatform): and filter string for each device is printed. Args: - verbosity (Literal[0, 1, 2], optional):. + verbosity (Literal[0, 1, 2], optional): The verbosity controls how much information is printed by the function. Value ``0`` is the lowest level set by default and ``2`` is the highest level to print the most verbose output. diff --git a/dpctl/program/__init__.py b/dpctl/program/__init__.py index 71302e4186..bb16f9611b 100644 --- a/dpctl/program/__init__.py +++ b/dpctl/program/__init__.py @@ -22,6 +22,7 @@ """ from ._program import ( + SpecializationConstant, SyclKernel, SyclKernelBundle, SyclKernelBundleCompilationError, @@ -41,6 +42,7 @@ "SyclKernelBundleCompilationError", "SyclProgram", "SyclProgramCompilationError", + "SpecializationConstant", ] diff --git a/dpctl/program/_program.pyx b/dpctl/program/_program.pyx index 8737be4762..7cecc3fa4d 100644 --- a/dpctl/program/_program.pyx +++ b/dpctl/program/_program.pyx @@ -26,7 +26,17 @@ an OpenCL source string or a SPIR-V binary file. """ +from cpython.buffer cimport ( + Py_buffer, + PyBUF_ANY_CONTIGUOUS, + PyBUF_SIMPLE, + PyBuffer_Release, + PyObject_CheckBuffer, + PyObject_GetBuffer, +) +from cpython.bytes cimport PyBytes_FromStringAndSize from libc.stdint cimport uint32_t +from libc.string cimport memcmp import warnings @@ -51,14 +61,20 @@ from dpctl._backend cimport ( # noqa: E211, E402; DPCTLSyclDeviceRef, DPCTLSyclKernelBundleRef, DPCTLSyclKernelRef, + _spec_const, ) +import numbers + +import numpy as np + __all__ = [ "create_kernel_bundle_from_source", "create_kernel_bundle_from_spirv", "SyclKernel", "SyclKernelBundle", "SyclKernelBundleCompilationError", + "SpecializationConstant", ] cdef class SyclKernelBundleCompilationError(Exception): @@ -252,6 +268,160 @@ cdef api SyclKernelBundle SyclKernelBundle_Make(DPCTLSyclKernelBundleRef KBRef): return SyclKernelBundle._create(copied_KBRef) +cdef class SpecializationConstant: + """ + SpecializationConstant(spec_id, *args) + + Python class representing SYCL specialization constants that can be used + when creating a :class:`dpctl.program.SyclKernelBundle` from SPIR-V. + + There are multiple ways to create a :class:`.SpecializationConstant`: + + - ``SpecializationConstant(spec_id, obj)`` + If the constructor is invoked with a single variadic argument, the + argument is expected to either expose the Python buffer protocol or be + coercible to a NumPy array. If the argument is coercible to a NumPy array + or is one, it must have a supported data type (bool, integral, floating + point, or void). The specialization constant will be constructed from the + data in the buffer + + - ``SpecializationConstant(spec_id, dtype, obj)`` + If the constructor is invoked with two variadic arguments, and the first + argument is a string, it is interpreted as a NumPy ``dtype`` string and the + second argument will be coerced to a NumPy array with that data type. + The data type specified by the first argument must be a supported data + type (bool, integral, floating point, or void). + + - ``SpecializationConstant(spec_id, nbytes, raw_ptr)`` + If the constructor is invoked with two variadic arguments where both are + integers, the first argument is interpreted as the number of bytes and + the second argument is interpreted as a pointer to the data. + + Note that when constructing from a buffer, the + :class:`.SpecializationConstant`, shares memory with the original object. + Modifications to the original object's data after construction will be + reflected when the :class:`.SpecializationConstant` is used to create a + :class:`.SyclKernelBundle`. This is not the case when constructing from a + raw pointer, as the data is copied. + + Args: + spec_id (int): + The SPIR-V specialization ID. + args: + Variadic argument, see class documentation. + + Raises: + TypeError: In case of incorrect arguments given to constructor, + failure to coerce to a buffer, or unsupported data type when + coercing to a buffer. + ValueError: If the provided object fails to construct a buffer. + """ + + cdef _spec_const _spec_const + cdef Py_buffer _buffer + + def __cinit__(self, spec_id, *args): + cdef int ret_code = 0 + cdef object target_obj = None + + if not isinstance(spec_id, numbers.Integral): + raise TypeError( + "Specialization constant ID must be of type `int`, got " + f"{type(spec_id)}" + ) + + if len(args) == 0 or len(args) > 2: + raise TypeError( + f"Constructor takes 2 or 3 arguments, got {len(args)}." + ) + + self._spec_const.id = spec_id + + if len(args) == 2: + if ( + isinstance(args[0], numbers.Integral) and + isinstance(args[1], numbers.Integral) + ): + target_obj = PyBytes_FromStringAndSize( + args[1], args[0] + ) + elif isinstance(args[0], str): + target_obj = np.ascontiguousarray(args[1], dtype=args[0]) + + elif len(args) == 1: + target_obj = args[0] + if not PyObject_CheckBuffer(target_obj): + # attempt to coerce to a numpy array + target_obj = np.ascontiguousarray(target_obj) + else: + raise TypeError( + "Invalid arguments." + ) + + if isinstance(target_obj, np.ndarray): + if target_obj.dtype.kind not in ("b", "i", "u", "f", "c", "V"): + raise TypeError( + "Coercion of input to buffer resulted in an unsupported " + f"data type '{target_obj.dtype}'. When coercing objects, " + "`SpecializationConstant` expects the data to coerce to a " + "supported type: bool, integral, real or complex floating " + "point, or void. To pass arbitrary data, use a " + "`memoryview` or `bytes` object, or pass the pointer and " + "size directly." + ) + + ret_code = PyObject_GetBuffer( + target_obj, &(self._buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS + ) + if ret_code != 0: + raise ValueError( + "Failed to get buffer view for the provided object." + ) + self._spec_const.value = self._buffer.buf + self._spec_const.size = self._buffer.len + + def __dealloc__(self): + PyBuffer_Release(&(self._buffer)) + + def __repr__(self): + return f"SpecializationConstant({self._spec_const.id})" + + def __eq__(self, other): + if not isinstance(other, SpecializationConstant): + return False + cdef SpecializationConstant _other = other + if ( + self._spec_const.id != _other._spec_const.id or + self._spec_const.size != _other._spec_const.size or + self._spec_const.value != _other._spec_const.value + ): + return False + return memcmp( + self._spec_const.value, + _other._spec_const.value, + self._spec_const.size + ) == 0 + + @property + def id(self): + """Returns the specialization ID for this specialization constant.""" + return self._spec_const.id + + @property + def size(self): + """ + Returns the size in bytes of the data for this specialization constant. + """ + return self._spec_const.size + + cdef size_t addressof(self): + """ + Returns the address of the _spec_const for this + :class:`.SpecializationConstant` cast to ``size_t``. + """ + return &(self._spec_const) + + cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts=""): """ Creates a Sycl interoperability kernel bundle from an OpenCL source diff --git a/libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h b/libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h index 07a76c3fd8..de7543016e 100644 --- a/libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h +++ b/libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h @@ -35,6 +35,13 @@ DPCTL_C_EXTERN_C_BEGIN +typedef struct DPCTLSpecConstTy +{ + uint32_t id; + size_t size; + const void *value; +} DPCTLSpecConst; + /** * @defgroup KernelBundleInterface Kernel_bundle class C wrapper */ From 977f32646fc6cb7807b649360c9c5ce93262b16f Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 28 Apr 2026 14:55:36 -0700 Subject: [PATCH 2/7] hook specialization constants into kernel bundle interface --- dpctl/_backend.pxd | 4 +- dpctl/program/_program.pxd | 5 +- dpctl/program/_program.pyx | 53 +++++++++++++-- .../dpctl_sycl_kernel_bundle_interface.h | 6 +- .../dpctl_sycl_kernel_bundle_interface.cpp | 66 +++++++++++++++++-- 5 files changed, 118 insertions(+), 16 deletions(-) diff --git a/dpctl/_backend.pxd b/dpctl/_backend.pxd index d2f1a0decb..d23edaf506 100644 --- a/dpctl/_backend.pxd +++ b/dpctl/_backend.pxd @@ -440,7 +440,9 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h": const DPCTLSyclDeviceRef Dev, const void *IL, size_t Length, - const char *CompileOpts) + const char *CompileOpts, + size_t NumSpecConsts, + const _spec_const *SpecConsts) cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_CreateFromOCLSource( const DPCTLSyclContextRef Ctx, const DPCTLSyclDeviceRef Dev, diff --git a/dpctl/program/_program.pxd b/dpctl/program/_program.pxd index 435ef68521..41f781cecd 100644 --- a/dpctl/program/_program.pxd +++ b/dpctl/program/_program.pxd @@ -63,7 +63,10 @@ cpdef create_kernel_bundle_from_source ( SyclQueue q, unicode source, unicode copts=* ) cpdef create_kernel_bundle_from_spirv ( - SyclQueue q, const unsigned char[:] IL, unicode copts=* + SyclQueue q, + const unsigned char[:] IL, + unicode copts=*, + list specializations=*, ) cpdef create_program_from_source (SyclQueue q, unicode source, unicode copts=*) cpdef create_program_from_spirv ( diff --git a/dpctl/program/_program.pyx b/dpctl/program/_program.pyx index 7cecc3fa4d..50ad8aac02 100644 --- a/dpctl/program/_program.pyx +++ b/dpctl/program/_program.pyx @@ -36,6 +36,7 @@ from cpython.buffer cimport ( ) from cpython.bytes cimport PyBytes_FromStringAndSize from libc.stdint cimport uint32_t +from libc.stdlib cimport free, malloc from libc.string cimport memcmp import warnings @@ -469,7 +470,10 @@ cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts=""): cpdef create_kernel_bundle_from_spirv( - SyclQueue q, const unsigned char[:] IL, str copts="" + SyclQueue q, + const unsigned char[:] IL, + str copts="", + list specializations=None, ): """ Creates a Sycl interoperability kernel bundle from an SPIR-V binary. @@ -487,7 +491,9 @@ cpdef create_kernel_bundle_from_spirv( copts (str, optional) Optional compilation flags that will be used when compiling the kernel bundle. Default: ``""``. - + specializations (list, optional) + A list of :class:`.SpecializationConstant` objects to be used + when creating the kernel bundle. Default: ``None``. Returns: kernel_bundle (:class:`.SyclKernelBundle`) A :class:`.SyclKernelBundle` object wrapping the @@ -506,11 +512,44 @@ cpdef create_kernel_bundle_from_spirv( cdef size_t length = IL.shape[0] cdef bytes bCOpts = copts.encode("utf8") cdef const char *COpts = bCOpts - KBref = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, dIL, length, COpts - ) - if KBref is NULL: - raise SyclKernelBundleCompilationError() + cdef size_t num_spconsts + cdef _spec_const *spconsts + cdef SpecializationConstant spconst + + if specializations is not None: + num_spconsts = len(specializations) + spconsts = <_spec_const *>( + malloc(num_spconsts * sizeof(_spec_const)) + ) + if spconsts == NULL: + raise MemoryError( + "Failed to allocate memory for specialization constants." + ) + for i, spconst in enumerate(specializations): + if not isinstance(spconst, SpecializationConstant): + free(spconsts) + raise TypeError( + "All items in specializations must be of type " + f"`SpecializationConstant`, got {type(spconst)}" + ) + spconsts[i] = spconst._spec_const + else: + num_spconsts = 0 + spconsts = NULL + try: + KBref = DPCTLKernelBundle_CreateFromSpirv( + CRef, + DRef, + dIL, + length, COpts, + num_spconsts, + spconsts, + ) + if KBref is NULL: + raise SyclKernelBundleCompilationError() + finally: + if spconsts != NULL: + free(spconsts) return SyclKernelBundle._create(KBref) diff --git a/libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h b/libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h index de7543016e..3909a1c3d1 100644 --- a/libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h +++ b/libsyclinterface/include/syclinterface/dpctl_sycl_kernel_bundle_interface.h @@ -58,6 +58,8 @@ typedef struct DPCTLSpecConstTy * @param Length The size of the IL binary in bytes. * @param CompileOpts Optional compiler flags used when compiling the * SPIR-V binary. + * @param NumSpecConsts The number of specialization constants. + * @param SpecConsts An array of specialization constants. * @return A new SyclKernelBundleRef pointer if the kernel_bundle creation * succeeded, else returns NULL. * @ingroup KernelBundleInterface @@ -68,7 +70,9 @@ DPCTLKernelBundle_CreateFromSpirv(__dpctl_keep const DPCTLSyclContextRef Ctx, __dpctl_keep const DPCTLSyclDeviceRef Dev, __dpctl_keep const void *IL, size_t Length, - const char *CompileOpts); + const char *CompileOpts, + size_t NumSpecConsts, + const DPCTLSpecConst *SpecConsts); /*! * @brief Create a Sycl kernel bundle from an OpenCL kernel source string. diff --git a/libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp b/libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp index 78c714ecbb..73e5d9ef87 100644 --- a/libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp @@ -31,6 +31,7 @@ #include "dpctl_error_handlers.h" #include "dpctl_sycl_type_casters.hpp" #include /* OpenCL headers */ +#include #include #include #include @@ -170,6 +171,21 @@ std::string _GetErrorCode_ocl_impl(cl_int code) } } +typedef cl_int (*clSetProgramSpecializationConstantFT)(cl_program, + cl_uint, + size_t, + const void *); +const char *clSetProgramSpecializationConstant_Name = + "clSetProgramSpecializationConstant"; +clSetProgramSpecializationConstantFT get_clSetProgramSpecializationConstant() +{ + static auto st_clSetProgramSpecializationConstantF = + cl_loader::get().getSymbol( + clSetProgramSpecializationConstant_Name); + + return st_clSetProgramSpecializationConstantF; +} + DPCTLSyclKernelBundleRef _CreateKernelBundle_common_ocl_impl(cl_program clProgram, const context &ctx, @@ -235,7 +251,9 @@ _CreateKernelBundleWithIL_ocl_impl(const context &ctx, const device &dev, const void *IL, size_t il_length, - const char *CompileOpts) + const char *CompileOpts, + size_t NumSpecConsts, + const DPCTLSpecConst *SpecConsts) { auto clCreateProgramWithILF = get_clCreateProgramWithIL(); if (clCreateProgramWithILF == nullptr) { @@ -257,6 +275,22 @@ _CreateKernelBundleWithIL_ocl_impl(const context &ctx, return nullptr; } + if (SpecConsts != nullptr && NumSpecConsts > 0) { + auto clSetProgramSpecConstF = get_clSetProgramSpecializationConstant(); + if (clSetProgramSpecConstF) { + for (size_t i = 0; i < NumSpecConsts; ++i) { + clSetProgramSpecConstF(clProgram, SpecConsts[i].id, + SpecConsts[i].size, SpecConsts[i].value); + } + } + else { + error_handler("clSetProgramSpecializationConstant is not available " + "in the OpenCL implementation.", + __FILE__, __func__, __LINE__); + return nullptr; + } + } + return _CreateKernelBundle_common_ocl_impl(clProgram, ctx, dev, CompileOpts); } @@ -428,7 +462,9 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx, const device &SyclDev, const void *IL, size_t il_length, - const char *CompileOpts) + const char *CompileOpts, + size_t NumSpecConsts, + const DPCTLSpecConst *SpecConsts) { auto zeModuleCreateFn = get_zeModuleCreate(); if (zeModuleCreateFn == nullptr) { @@ -444,8 +480,22 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx, ZeDevice = get_native(SyclDev); // Specialization constants are not supported by DPCTL at the moment + std::vector spec_ids; + std::vector spec_values; + + if (SpecConsts != nullptr && NumSpecConsts > 0) { + spec_ids.reserve(NumSpecConsts); + spec_values.reserve(NumSpecConsts); + for (size_t i = 0; i < NumSpecConsts; ++i) { + spec_ids.push_back(SpecConsts[i].id); + spec_values.push_back(SpecConsts[i].value); + } + } ze_module_constants_t ZeSpecConstants = {}; - ZeSpecConstants.numConstants = 0; + ZeSpecConstants.numConstants = static_cast(NumSpecConsts); + ZeSpecConstants.pConstantIds = spec_ids.empty() ? nullptr : spec_ids.data(); + ZeSpecConstants.pConstantValues = + spec_values.empty() ? nullptr : spec_values.data(); // Populate the Level Zero module descriptions ze_module_desc_t ZeModuleDesc = {}; @@ -583,7 +633,9 @@ DPCTLKernelBundle_CreateFromSpirv(__dpctl_keep const DPCTLSyclContextRef CtxRef, __dpctl_keep const DPCTLSyclDeviceRef DevRef, __dpctl_keep const void *IL, size_t length, - const char *CompileOpts) + const char *CompileOpts, + size_t NumSpecConsts, + const DPCTLSpecConst *SpecConsts) { DPCTLSyclKernelBundleRef KBRef = nullptr; if (!CtxRef) { @@ -611,12 +663,14 @@ DPCTLKernelBundle_CreateFromSpirv(__dpctl_keep const DPCTLSyclContextRef CtxRef, switch (BE) { case backend::opencl: KBRef = _CreateKernelBundleWithIL_ocl_impl(*SyclCtx, *SyclDev, IL, - length, CompileOpts); + length, CompileOpts, + NumSpecConsts, SpecConsts); break; case backend::ext_oneapi_level_zero: #ifdef DPCTL_ENABLE_L0_PROGRAM_CREATION KBRef = _CreateKernelBundleWithIL_ze_impl(*SyclCtx, *SyclDev, IL, - length, CompileOpts); + length, CompileOpts, + NumSpecConsts, SpecConsts); break; #endif default: From 0a680788ce0f96e6cc4a7130250b8ce172cf20e3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 28 Apr 2026 15:29:29 -0700 Subject: [PATCH 3/7] add test for SpecializationConstant use in kernel --- .../specialization_constant_kernel.spv | Bin 0 -> 2288 bytes dpctl/tests/test_sycl_program.py | 38 ++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 dpctl/tests/input_files/specialization_constant_kernel.spv diff --git a/dpctl/tests/input_files/specialization_constant_kernel.spv b/dpctl/tests/input_files/specialization_constant_kernel.spv new file mode 100644 index 0000000000000000000000000000000000000000..d696fa755f7e56d61ceeab707c468f0c937b8ed6 GIT binary patch literal 2288 zcmZ{kTTc^F5XVoUih_cOco&O!Ll6NGwRj6qN~MBQgNZMjrQPTzWlP${fbqd6KZmcr z`Q=R11mpj=yT@!uJlaW#ZFsI!sm~{7}LVC%h-&*Aspt z;roNebek?ydAZ@NEU!D>?zUGA9ka^wQVn0C{MJ)n$ew|y1_)$XNJ^) zF(*aN+t-!Vnbuyle%INUnyq=E>({4e)*CxP&8w}u!8>ooYX)BZjUT>WSuU+sCY>2S z^tMW6f2vU4oGKIwNYj(^Zp*J$)X734XoYc7p`a_Gxwx-SoxvAH+}l1Uj_*{z;PUlL zM)^#n9H+JIH$OP@JAOSZ1;u(})2)|+4~?oD`i-EZF%@?-x}$Au#s3IChs+WEjH%WI z*#xua^BO*%Z!#|aZdGaHkI~8J-DcC>Bkq_v{x8mP-g$x7bVIM^w8O%fqN15&RMfV= z1GDfAi^74N(~+EwWcK1>q{n#vw_Y-@tpoP94l;EvM|HAhK>WIh9VrTV^l=e5p`Q{@ z`_a?>ymap2(DP&Qo&GqtAemUcC$*Li=w<2Y+%@Us4l0&@o{2co67u<0a+IXPCRcg`94JH@#lo|DeJ)I%STr%U{z ztl)0@9Fmt^!-mXk=;R+0CkGpPpE$b!21C->J!}yE6XL{UA6BFwjCAVYBj#LU<9=$e zN5~NVOX6{#v3^y0e815SbeP*o+`KrwfI38db~}12kYIE4Igpr$CwDLh|7syNkbkN z;p2QuX1pUA-0*RqtuZMXd=J)`k_{in)|i$|4Sc+p@4AogaaTlraQvl8{9k}M{D;00 zweN|D;XBNT&WP}le_uome6u2azl9-oo#dI~Xr*X;2wkE Date: Tue, 28 Apr 2026 16:25:17 -0700 Subject: [PATCH 4/7] fix libsyclinterface tests --- .../test_sycl_kernel_bundle_interface.cpp | 24 +++++++++++-------- .../tests/test_sycl_queue_submit.cpp | 6 +++-- ...t_sycl_queue_submit_local_accessor_arg.cpp | 6 +++-- .../test_sycl_queue_submit_raw_kernel_arg.cpp | 6 +++-- ...ycl_queue_submit_work_group_memory_arg.cpp | 6 +++-- 5 files changed, 30 insertions(+), 18 deletions(-) diff --git a/libsyclinterface/tests/test_sycl_kernel_bundle_interface.cpp b/libsyclinterface/tests/test_sycl_kernel_bundle_interface.cpp index a835c277b9..5793e983d9 100644 --- a/libsyclinterface/tests/test_sycl_kernel_bundle_interface.cpp +++ b/libsyclinterface/tests/test_sycl_kernel_bundle_interface.cpp @@ -69,7 +69,8 @@ struct TestDPCTLSyclKernelBundleInterface spirvFile.seekg(0, std::ios::beg); spirvFile.read(spirvBuffer.data(), spirvFileSize); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer.data(), spirvFileSize, nullptr); + CRef, DRef, spirvBuffer.data(), spirvFileSize, nullptr, 0, + nullptr); } } @@ -132,18 +133,21 @@ TEST_P(TestDPCTLSyclKernelBundleInterface, ChkCreateFromSpirvNull) const void *null_spirv = nullptr; DPCTLSyclKernelBundleRef KBRef = nullptr; // Null context - EXPECT_NO_FATAL_FAILURE(KBRef = DPCTLKernelBundle_CreateFromSpirv( - Null_CRef, Null_DRef, null_spirv, 0, nullptr)); + EXPECT_NO_FATAL_FAILURE( + KBRef = DPCTLKernelBundle_CreateFromSpirv( + Null_CRef, Null_DRef, null_spirv, 0, nullptr, 0, nullptr)); ASSERT_TRUE(KBRef == nullptr); // Null device - EXPECT_NO_FATAL_FAILURE(KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, Null_DRef, null_spirv, 0, nullptr)); + EXPECT_NO_FATAL_FAILURE( + KBRef = DPCTLKernelBundle_CreateFromSpirv(CRef, Null_DRef, null_spirv, + 0, nullptr, 0, nullptr)); ASSERT_TRUE(KBRef == nullptr); // Null IL - EXPECT_NO_FATAL_FAILURE(KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, null_spirv, 0, nullptr)); + EXPECT_NO_FATAL_FAILURE( + KBRef = DPCTLKernelBundle_CreateFromSpirv(CRef, DRef, null_spirv, 0, + nullptr, 0, nullptr)); ASSERT_TRUE(KBRef == nullptr); } @@ -350,8 +354,8 @@ TEST_F(TestKernelBundleUnsupportedBackend, CheckCreateFromSpirv) spirvFile.close(); DPCTLSyclKernelBundleRef KBRef = nullptr; - EXPECT_NO_FATAL_FAILURE( - KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer.data(), spirvFileSize, nullptr)); + EXPECT_NO_FATAL_FAILURE(KBRef = DPCTLKernelBundle_CreateFromSpirv( + CRef, DRef, spirvBuffer.data(), spirvFileSize, + nullptr, 0, nullptr)); ASSERT_TRUE(KBRef == nullptr); } diff --git a/libsyclinterface/tests/test_sycl_queue_submit.cpp b/libsyclinterface/tests/test_sycl_queue_submit.cpp index ab5b6bef82..f2fc2b2140 100644 --- a/libsyclinterface/tests/test_sycl_queue_submit.cpp +++ b/libsyclinterface/tests/test_sycl_queue_submit.cpp @@ -242,7 +242,8 @@ struct TestQueueSubmit : public ::testing::Test auto CRef = DPCTLQueue_GetContext(QRef); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr); + CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr, 0, + nullptr); DPCTLDevice_Delete(DRef); DPCTLDeviceSelector_Delete(DSRef); } @@ -282,7 +283,8 @@ struct TestQueueSubmitFP64 : public ::testing::Test auto CRef = DPCTLQueue_GetContext(QRef); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr); + CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr, 0, + nullptr); DPCTLDeviceSelector_Delete(DSRef); } diff --git a/libsyclinterface/tests/test_sycl_queue_submit_local_accessor_arg.cpp b/libsyclinterface/tests/test_sycl_queue_submit_local_accessor_arg.cpp index f6110375bb..8f1b97d1d4 100644 --- a/libsyclinterface/tests/test_sycl_queue_submit_local_accessor_arg.cpp +++ b/libsyclinterface/tests/test_sycl_queue_submit_local_accessor_arg.cpp @@ -237,7 +237,8 @@ struct TestQueueSubmitWithLocalAccessor : public ::testing::Test auto CRef = DPCTLQueue_GetContext(QRef); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr); + CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr, 0, + nullptr); DPCTLDevice_Delete(DRef); DPCTLDeviceSelector_Delete(DSRef); } @@ -276,7 +277,8 @@ struct TestQueueSubmitWithLocalAccessorFP64 : public ::testing::Test auto CRef = DPCTLQueue_GetContext(QRef); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr); + CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr, 0, + nullptr); DPCTLDeviceSelector_Delete(DSRef); } diff --git a/libsyclinterface/tests/test_sycl_queue_submit_raw_kernel_arg.cpp b/libsyclinterface/tests/test_sycl_queue_submit_raw_kernel_arg.cpp index f40bc20066..04d3958d8d 100644 --- a/libsyclinterface/tests/test_sycl_queue_submit_raw_kernel_arg.cpp +++ b/libsyclinterface/tests/test_sycl_queue_submit_raw_kernel_arg.cpp @@ -262,7 +262,8 @@ struct TestQueueSubmitWithRawKernelArg : public ::testing::Test auto CRef = DPCTLQueue_GetContext(QRef); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr); + CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr, 0, + nullptr); DPCTLDevice_Delete(DRef); DPCTLDeviceSelector_Delete(DSRef); } @@ -301,7 +302,8 @@ struct TestQueueSubmitWithRawKernelArgFP64 : public ::testing::Test auto CRef = DPCTLQueue_GetContext(QRef); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr); + CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr, 0, + nullptr); DPCTLDeviceSelector_Delete(DSRef); } diff --git a/libsyclinterface/tests/test_sycl_queue_submit_work_group_memory_arg.cpp b/libsyclinterface/tests/test_sycl_queue_submit_work_group_memory_arg.cpp index d0f44b7275..d1d1f69bfa 100644 --- a/libsyclinterface/tests/test_sycl_queue_submit_work_group_memory_arg.cpp +++ b/libsyclinterface/tests/test_sycl_queue_submit_work_group_memory_arg.cpp @@ -262,7 +262,8 @@ struct TestQueueSubmitWithWorkGroupMemory : public ::testing::Test auto CRef = DPCTLQueue_GetContext(QRef); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr); + CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr, 0, + nullptr); DPCTLDevice_Delete(DRef); DPCTLDeviceSelector_Delete(DSRef); } @@ -301,7 +302,8 @@ struct TestQueueSubmitWithWorkGroupMemoryFP64 : public ::testing::Test auto CRef = DPCTLQueue_GetContext(QRef); KBRef = DPCTLKernelBundle_CreateFromSpirv( - CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr); + CRef, DRef, spirvBuffer_.data(), spirvFileSize_, nullptr, 0, + nullptr); DPCTLDeviceSelector_Delete(DSRef); } From d8b0afbeb07d819bee1fd8766e0fb0e43e6e4aaf Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 30 Apr 2026 11:35:16 -0700 Subject: [PATCH 5/7] add test for composite specialization constant also removes "v" as a permitted specialization constant intermediate data type, as composite specialization constants are broken into multiple specialization constants, so structs end up passed as a single constant while the program expects multiple, and therefore, doesn't work as intended --- dpctl/program/_program.pyx | 12 ++--- .../specialization_constant_composite.spv | Bin 0 -> 3432 bytes dpctl/tests/test_sycl_program.py | 41 ++++++++++++++++++ 3 files changed, 47 insertions(+), 6 deletions(-) create mode 100644 dpctl/tests/input_files/specialization_constant_composite.spv diff --git a/dpctl/program/_program.pyx b/dpctl/program/_program.pyx index 50ad8aac02..166d46a3e3 100644 --- a/dpctl/program/_program.pyx +++ b/dpctl/program/_program.pyx @@ -282,8 +282,8 @@ cdef class SpecializationConstant: If the constructor is invoked with a single variadic argument, the argument is expected to either expose the Python buffer protocol or be coercible to a NumPy array. If the argument is coercible to a NumPy array - or is one, it must have a supported data type (bool, integral, floating - point, or void). The specialization constant will be constructed from the + or is one, it must have a supported data type (bool, integral, or + floating point). The specialization constant will be constructed from the data in the buffer - ``SpecializationConstant(spec_id, dtype, obj)`` @@ -291,7 +291,7 @@ cdef class SpecializationConstant: argument is a string, it is interpreted as a NumPy ``dtype`` string and the second argument will be coerced to a NumPy array with that data type. The data type specified by the first argument must be a supported data - type (bool, integral, floating point, or void). + type (bool, integral, or floating point). - ``SpecializationConstant(spec_id, nbytes, raw_ptr)`` If the constructor is invoked with two variadic arguments where both are @@ -360,13 +360,13 @@ cdef class SpecializationConstant: ) if isinstance(target_obj, np.ndarray): - if target_obj.dtype.kind not in ("b", "i", "u", "f", "c", "V"): + if target_obj.dtype.kind not in ("b", "i", "u", "f", "c"): raise TypeError( "Coercion of input to buffer resulted in an unsupported " f"data type '{target_obj.dtype}'. When coercing objects, " "`SpecializationConstant` expects the data to coerce to a " - "supported type: bool, integral, real or complex floating " - "point, or void. To pass arbitrary data, use a " + "supported type: bool, integral, or real or complex " + "floating point. To pass arbitrary data, use a " "`memoryview` or `bytes` object, or pass the pointer and " "size directly." ) diff --git a/dpctl/tests/input_files/specialization_constant_composite.spv b/dpctl/tests/input_files/specialization_constant_composite.spv new file mode 100644 index 0000000000000000000000000000000000000000..6b69617a7dc599495bc6a43862832b1ef77eab0d GIT binary patch literal 3432 zcma)-ZBHCk6vt=FQYyt_ZE1aIk=oLNR4fnDx)numffb~brJzmPVb~pD?LOe{fPgir zFZ6TxR$uu={5U3RQse(Ob5Ge2ns}4{p2vI6x#ylc+j8P+o0m#?XS|c%4_>m)dMCWJ zSZJ%GCmlWQ=ygYLIQqOR@&Dt3=e2vMy!_&dpIcbY&ifnN{zlM<{7MjQ`ITBRl)Wvs z&&@9R;of#wi2R_@*r|kyby^uyqW8AL>eT#TBPx2Ttwoi1F*?0&RqA=K3;x69{K9Bs zzfd0WSBA!lVHA|g!=p>JooX>GE(O)ia4xJ@!}4k=+M1iqE#zJB28(Ql@IvMwINzW|f2Uldn^iL4BaSFWD3Ox2RtF_sh>I$(%@g{Kj^v zzUxo!l*&=AI#aH#2jyIKw^j(EQmvZPG3ED4lH)kveCdGxw5qqKL=~OGCF)F zQ|GwY+ZE*)d(xduCaBkgef&DSS6}$~{$^MW>p>J2{bPF(=Y3QDjkr^T%RALlb#pPO z2bC}i>kV&4(Ku(n@PE$A9nYvw?9s@GPscp-bt#sUe!n!H>xaeRZY&Nwao&mJq`h;Z z*9B&XJjgc%`iOj9wAqf_Y`^Se?iBx7QS!2obE3#eEotuy$6LNd$$dhH&Wr!*5qV8A z`7C~LL@peWzmeRW|A~_=f7!{FzahCf|3Q=7rL6c(2rT3?qV$G^{FZ1Bg4Td;3IBG% zYB1e39NZdk`+ASa1CrTi_K(FA&thc7Ydzy<*u|jVEM#hGU@(8IqW5D(tYGA*8%S>fNV6RIbTiwDIwzib(!#WK z-2c453=ofgGv{9l=*<^B?{z_-N2aeS68`s`?I*=VUKEIdp1r*7`vQHmy}!A8$vfMU zsO4W4KOvYecT3zA0sjXAK0msAtK!Lt-OlxucxG`@;QYw=Sxn}getj*_uax=)PrshY zAHRnJ`^N<9J$Hn@N3faWPEF4}H_rk0_gr71JA!>5eCb$^1-=Q7MAw@5Ks8OGEgs1N|PyYdNu{1IO literal 0 HcmV?d00001 diff --git a/dpctl/tests/test_sycl_program.py b/dpctl/tests/test_sycl_program.py index 7b5f8db20e..61b9fee843 100644 --- a/dpctl/tests/test_sycl_program.py +++ b/dpctl/tests/test_sycl_program.py @@ -300,3 +300,44 @@ def test_create_kernel_bundle_with_spec_const(): ht_e.wait() assert np.all(y == 43) + + +def test_create_kernel_bundle_with_composite_spec_const(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Could not create default queue") + + # composite specialization constants are separated into individual + # specialization constants with unique spec_ids + sp1 = dpctl_prog.SpecializationConstant(0, "i4", 10) + sp2 = dpctl_prog.SpecializationConstant(1, "f4", 2.5) + sp3 = dpctl_prog.SpecializationConstant(2, "?", 1) + + spirv_file = get_spirv_abspath("specialization_constant_composite.spv") + with open(spirv_file, "br") as spv: + spv_bytes = spv.read() + + kb = dpctl_prog.create_kernel_bundle_from_spirv( + q, spv_bytes, specializations=[sp1, sp2, sp3] + ) + kernel = kb.get_sycl_kernel("_ZTS21StructSpecConstKernel") + + n = 128 + x = np.ones(n, dtype="f4") + y = np.zeros_like(x) + + x_usm = dpctl.memory.MemoryUSMDevice(x.nbytes, queue=q) + y_usm = dpctl.memory.MemoryUSMDevice(y.nbytes, queue=q) + + e1 = q.memcpy_async(x_usm, x, x.nbytes) + e2 = q.submit(kernel, [x_usm, y_usm], [n], dEvents=[e1]) + e3 = q.memcpy_async(y, y_usm, y.nbytes, [e2]) + + ht_e = q._submit_keep_args_alive([x_usm], [e3]) + + e3.wait() + ht_e.wait() + + # 1.0 * 10 + 2.5 = 12.5 + assert np.all(y == 12.5) From d2c6bb6707ff954fade21c0897e74cf14d68c009 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 4 May 2026 20:27:13 -0700 Subject: [PATCH 6/7] add program.utils namespace with SPIRV parser --- dpctl/program/__init__.py | 5 ++ dpctl/program/utils/__init__.py | 25 ++++++++ dpctl/program/utils/_utils.py | 106 ++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 dpctl/program/utils/__init__.py create mode 100644 dpctl/program/utils/_utils.py diff --git a/dpctl/program/__init__.py b/dpctl/program/__init__.py index bb16f9611b..e1e625ff7a 100644 --- a/dpctl/program/__init__.py +++ b/dpctl/program/__init__.py @@ -45,6 +45,11 @@ "SpecializationConstant", ] +# add submodules +__all__ += [ + "utils", +] + def __getattr__(name): if name == "SyclProgram": diff --git a/dpctl/program/utils/__init__.py b/dpctl/program/utils/__init__.py new file mode 100644 index 0000000000..474f154f95 --- /dev/null +++ b/dpctl/program/utils/__init__.py @@ -0,0 +1,25 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A collection of utility functions for dpctl.program module. +""" + +from ._utils import parse_spirv_specializations + +__all__ = [ + "parse_spirv_specializations", +] diff --git a/dpctl/program/utils/_utils.py b/dpctl/program/utils/_utils.py new file mode 100644 index 0000000000..2a16802dea --- /dev/null +++ b/dpctl/program/utils/_utils.py @@ -0,0 +1,106 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements various utilities for the dpctl.program module.""" + +from enum import IntEnum + +import numpy as np + + +class SpirvOpCode(IntEnum): + OpName = 5 + OpTypeBool = 20 + OpTypeInt = 21 + OpTypeFloat = 22 + OpSpecConstantTrue = 48 + OpSpecConstantFalse = 49 + OpSpecConstant = 50 + OpDecorate = 71 + + +class SpirvDecoration(IntEnum): + SpecId = 1 + + +def parse_spirv_specializations( + spv_bytes: bytes | bytearray | memoryview, +) -> dict[int, dict[str, str]]: + words = np.frombuffer(spv_bytes, dtype=np.uint32) + + # verify magic number + if len(words) < 5 or words[0] != 0x07230203: + raise ValueError("Invalid SPIR-V binary") + + types = {} + ids = {} + names = {} + constants = {} + + i = 5 # skip 5 word header + while i < len(words): + word = words[i] + opcode = word & 0xFFFF + word_count = word >> 16 + + if word_count == 0: + raise ValueError(f"Invalid SPIR-V instruction at word index {i}") + + if opcode == SpirvOpCode.OpTypeBool: + result_id = int(words[i + 1]) + types[result_id] = "?" + elif opcode == SpirvOpCode.OpTypeInt: + result_id = int(words[i + 1]) + width = int(words[i + 2]) + signed = int(words[i + 3]) + prefix = "i" if signed else "u" + types[result_id] = f"{prefix}{width // 8}" + elif opcode == SpirvOpCode.OpTypeFloat: + result_id = int(words[i + 1]) + width = int(words[i + 2]) + types[result_id] = f"f{width // 8}" + elif opcode in ( + SpirvOpCode.OpSpecConstant, + SpirvOpCode.OpSpecConstantTrue, + SpirvOpCode.OpSpecConstantFalse, + ): + type_id = int(words[i + 1]) + result_id = int(words[i + 2]) + constants[result_id] = type_id + elif opcode == SpirvOpCode.OpDecorate: + target_id = int(words[i + 1]) + decoration = int(words[i + 2]) + if decoration == SpirvDecoration.SpecId: + ids[target_id] = int(words[i + 3]) + elif opcode == SpirvOpCode.OpName: + target_id = int(words[i + 1]) + name_bytes = words[i + 2 : i + word_count].tobytes() + names[target_id] = name_bytes.split(b"\x00", 1)[0].decode("utf-8") + + i += word_count + + result = {} + for target_id, spec_id in ids.items(): + type_id = constants.get(target_id) + dtype_str = types.get(type_id, "unknown_type") + name = names.get(target_id, f"unnamed_spec_const_{spec_id}") + + result[spec_id] = { + "name": name, + "dtype": dtype_str, + } + + return result From 043aa83af7e37e87334270d5963b8b78403bdf80 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 4 May 2026 23:51:48 -0700 Subject: [PATCH 7/7] Refactor specialiazation constants to use dataclass also adds spec_id, itemsize, and default_value fields --- dpctl/program/utils/_utils.py | 122 +++++++++++++++--- .../specialization_constant_composite.spv | Bin 3432 -> 3504 bytes dpctl/tests/test_sycl_program.py | 47 +++++++ setup.py | 1 + 4 files changed, 152 insertions(+), 18 deletions(-) diff --git a/dpctl/program/utils/_utils.py b/dpctl/program/utils/_utils.py index 2a16802dea..86df0855ad 100644 --- a/dpctl/program/utils/_utils.py +++ b/dpctl/program/utils/_utils.py @@ -16,6 +16,7 @@ """Implements various utilities for the dpctl.program module.""" +from dataclasses import dataclass from enum import IntEnum import numpy as np @@ -29,6 +30,7 @@ class SpirvOpCode(IntEnum): OpSpecConstantTrue = 48 OpSpecConstantFalse = 49 OpSpecConstant = 50 + OpFunction = 54 OpDecorate = 71 @@ -36,9 +38,52 @@ class SpirvDecoration(IntEnum): SpecId = 1 +@dataclass(frozen=True) +class SpecializationConstantInfo: + """Data class representing specialization constant information.""" + + spec_id: int + dtype: str + name: str + itemsize: int + default_value: int | float | bool | None + + def parse_spirv_specializations( spv_bytes: bytes | bytearray | memoryview, -) -> dict[int, dict[str, str]]: +) -> tuple[SpecializationConstantInfo]: + """ + Parses SPIR-V byte stream to extract information about specializations, + including the specialization IDs, types, names, and default values. + + Note that the dtype information may be imprecise, as the compiler may + choose to, for example, represent a bool as char, or may represent both + signed and unsigned integers as unsigned integer bit buckets of the same + length. + + Args: + spv_bytes (bytes | bytearray | memoryview): + the SPIR-V byte stream. + + Returns: + tuple[SpecializationConstantInfo]: + a tuple of parsed constants and their information represented by + `SpecializationConstantInfo` objects, sorted by their + specialization IDs. The length of the tuple is equal to the number + of specialization constants found. Each + `SpecializationConstantInfo` object contains the following + attributes: + + - `spec_id` (int): The specialization ID. + - `dtype` (str): A NumPy style string representing the data type. + - `itemsize` (int): The size of the specialization constant in + bytes. + - `name` (str): The variable name. If not preserved in the binary, + a default name in the format `unnamed_spec_const_{spec_id}` is + used. + - `default_value` (int | float | bool | None): The default value of + the specialization constant. If not specified, `None` is used. + """ words = np.frombuffer(spv_bytes, dtype=np.uint32) # verify magic number @@ -49,6 +94,7 @@ def parse_spirv_specializations( ids = {} names = {} constants = {} + defaults = {} i = 5 # skip 5 word header while i < len(words): @@ -59,27 +105,45 @@ def parse_spirv_specializations( if word_count == 0: raise ValueError(f"Invalid SPIR-V instruction at word index {i}") - if opcode == SpirvOpCode.OpTypeBool: + if opcode == SpirvOpCode.OpFunction: + # everything following is not relevant to specialization constant + # parsing, so we can stop parsing at this point + break + elif opcode == SpirvOpCode.OpTypeBool: result_id = int(words[i + 1]) - types[result_id] = "?" + types[result_id] = {"dtype": "?", "itemsize": 1} elif opcode == SpirvOpCode.OpTypeInt: result_id = int(words[i + 1]) width = int(words[i + 2]) signed = int(words[i + 3]) prefix = "i" if signed else "u" - types[result_id] = f"{prefix}{width // 8}" + types[result_id] = { + "dtype": f"{prefix}{width // 8}", + "itemsize": width // 8, + } elif opcode == SpirvOpCode.OpTypeFloat: result_id = int(words[i + 1]) width = int(words[i + 2]) - types[result_id] = f"f{width // 8}" - elif opcode in ( - SpirvOpCode.OpSpecConstant, - SpirvOpCode.OpSpecConstantTrue, - SpirvOpCode.OpSpecConstantFalse, - ): + types[result_id] = { + "dtype": f"f{width // 8}", + "itemsize": width // 8, + } + elif opcode == SpirvOpCode.OpSpecConstant: + type_id = int(words[i + 1]) + result_id = int(words[i + 2]) + constants[result_id] = type_id + literal_words = words[i + 3 : i + word_count] + defaults[result_id] = literal_words.tobytes() + elif opcode == SpirvOpCode.OpSpecConstantTrue: type_id = int(words[i + 1]) result_id = int(words[i + 2]) constants[result_id] = type_id + defaults[result_id] = True + elif opcode == SpirvOpCode.OpSpecConstantFalse: + type_id = int(words[i + 1]) + result_id = int(words[i + 2]) + constants[result_id] = type_id + defaults[result_id] = False elif opcode == SpirvOpCode.OpDecorate: target_id = int(words[i + 1]) decoration = int(words[i + 2]) @@ -92,15 +156,37 @@ def parse_spirv_specializations( i += word_count - result = {} + # a spec ID may appear multiple times in the same binary with different + # target IDs. We only need to keep one, so skip duplicates + unique_ids = set() + result = [] for target_id, spec_id in ids.items(): + if spec_id in unique_ids: + continue + unique_ids.add(spec_id) type_id = constants.get(target_id) - dtype_str = types.get(type_id, "unknown_type") + type_info = types.get(type_id, {"dtype": "unknown_type", "itemsize": 0}) name = names.get(target_id, f"unnamed_spec_const_{spec_id}") - result[spec_id] = { - "name": name, - "dtype": dtype_str, - } - - return result + dtype_str = type_info["dtype"] + raw_default = defaults.get(target_id) + default_value = None + if isinstance(raw_default, bytes): + try: + default_value = np.frombuffer(raw_default, dtype=dtype_str)[ + 0 + ].item() + except Exception: + default_value = None + + result.append( + SpecializationConstantInfo( + spec_id=spec_id, + dtype=dtype_str, + name=name, + itemsize=type_info["itemsize"], + default_value=default_value, + ) + ) + + return tuple(sorted(result, key=lambda x: x.spec_id)) diff --git a/dpctl/tests/input_files/specialization_constant_composite.spv b/dpctl/tests/input_files/specialization_constant_composite.spv index 6b69617a7dc599495bc6a43862832b1ef77eab0d..c262f97ff6dfe48767c2ce52a31aea146a0d6b3e 100644 GIT binary patch literal 3504 zcmZ{lYjaao6oz+5C|D3gZeFm4q7(!wy$Bct3pJG3T1rzuQRnnDIc=lKNlZ>FMNs&l zzr(Nk$v@)HaYmiN@qJGAqM1(FGjG;)uf6tKYoBCyoY>_u8Mn{%xhJl>4!9jICl_o4czV*)Ay1zj?)v|E#JPUA+m+|;2Bq1B;!LpI3YNo85;VeOHE1-ek?gxt`}EEE zAbQY>DoGG_I%|zcvG(Z1jOeXa6i?0!cao}8ZdqmGQtWpF%G9}6h2Z`|d3L`=N5|)zYjHKI&WG_zG#$0$sD7`OtWMu7&6dvxQdlp=>&;4-)S7Wg*OcE&=^5MlmTM2G&ra2Lw`irVk_~hBLP3`=6xv)D zd%vP=WA9d{Pzc-Ya09=+?&arxL9i0VQ9DecYOrlB(sN&zeipD)p3)|+u<*xL2N_Aq7Mn-%t<(aQzsobsi3*-5ESR8s|ao~ybb{Z$=4vL-> zm?6$VJ}Xd1Bu^dcC53>Txbkx|^86S)92g~Gh;)!MQRq>o*F{Z?8J>h56$FLqClP@)A5d-;7 zDsB2(&CH(~v5v~0oH-AgmrS3@2l=2ZyF`yk%Y5102V~{Vp@XMh$n4)M${uvc`$d^~ z@*0rL8$<`Nzaol1`s4CE;AP@qJEHh+isl9KcuSO6)SE9Ovr3QfL64od1ER#mj-2;0 zcOZ`wqVNFwJEHibM}AL~I}rDfyvDq~;Pn>wBgyRR&PmRF>}7Hx&NVN0?~rqglBowW z_YFnazvyvUd*}f^8NuW}FP~OD$YorhpTyw}vNE!9|IdUmMJ2~sfqB7)Il;!9ulgmedCbIz0`Uih4c(nty)007YR9Zz@v_wpJu^Geqvsoh{+j6Z9zC`B zM;Y?JLvNvfsiWlliNM+Pdr~+mU?Z;^0=Z!;3fTTte(2#}6Nm?IceyE^m{Y=khj%z+ z`bVxMfjhZOHszVUtEvHaofXI@BixXV{@)Uq0pd|_W`9mVk8e&u&l|p-@|*gmtoYA+ z+s}%LTo#Cdp1u1-7laAHV*c*?a@X7Lh}!x0#9#06rMJZWLcpK5iO*Ai-hJ_$iQVq? zrFdqMdvSkc{Oo+I*H_X}uZ-#iPraVVAHS~!_A^sfdwPU^P_UV!r>3XRy>meSo~h!O z1p7WLNyiEWz6pz>l^(t-o}NddeB<~&EDL-es8LpQMc65jXI`*(wJIJP%kry<=S^T^ zw(Tz8h{tAggluo)Tk)J}b*g)rxq`R3!mlNeht1WZ`0oVnM11RO+w1Yo3HI(f;yH`l zh6MbXw?x2Z{x)xG-e&W*E}7Wq?ah8K-eU4*iOC#35U^QHo5KxnV-D?oKNN2TdJ&KQWv+ A5C8xG literal 3432 zcma)-ZBHCk6vt=FQYyt_ZE1aIk=oLNR4fnDx)numffb~brJzmPVb~pD?LOe{fPgir zFZ6TxR$uu={5U3RQse(Ob5Ge2ns}4{p2vI6x#ylc+j8P+o0m#?XS|c%4_>m)dMCWJ zSZJ%GCmlWQ=ygYLIQqOR@&Dt3=e2vMy!_&dpIcbY&ifnN{zlM<{7MjQ`ITBRl)Wvs z&&@9R;of#wi2R_@*r|kyby^uyqW8AL>eT#TBPx2Ttwoi1F*?0&RqA=K3;x69{K9Bs zzfd0WSBA!lVHA|g!=p>JooX>GE(O)ia4xJ@!}4k=+M1iqE#zJB28(Ql@IvMwINzW|f2Uldn^iL4BaSFWD3Ox2RtF_sh>I$(%@g{Kj^v zzUxo!l*&=AI#aH#2jyIKw^j(EQmvZPG3ED4lH)kveCdGxw5qqKL=~OGCF)F zQ|GwY+ZE*)d(xduCaBkgef&DSS6}$~{$^MW>p>J2{bPF(=Y3QDjkr^T%RALlb#pPO z2bC}i>kV&4(Ku(n@PE$A9nYvw?9s@GPscp-bt#sUe!n!H>xaeRZY&Nwao&mJq`h;Z z*9B&XJjgc%`iOj9wAqf_Y`^Se?iBx7QS!2obE3#eEotuy$6LNd$$dhH&Wr!*5qV8A z`7C~LL@peWzmeRW|A~_=f7!{FzahCf|3Q=7rL6c(2rT3?qV$G^{FZ1Bg4Td;3IBG% zYB1e39NZdk`+ASa1CrTi_K(FA&thc7Ydzy<*u|jVEM#hGU@(8IqW5D(tYGA*8%S>fNV6RIbTiwDIwzib(!#WK z-2c453=ofgGv{9l=*<^B?{z_-N2aeS68`s`?I*=VUKEIdp1r*7`vQHmy}!A8$vfMU zsO4W4KOvYecT3zA0sjXAK0msAtK!Lt-OlxucxG`@;QYw=Sxn}getj*_uax=)PrshY zAHRnJ`^N<9J$Hn@N3faWPEF4}H_rk0_gr71JA!>5eCb$^1-=Q7MAw@5Ks8OGEgs1N|PyYdNu{1IO diff --git a/dpctl/tests/test_sycl_program.py b/dpctl/tests/test_sycl_program.py index 61b9fee843..564f40bed9 100644 --- a/dpctl/tests/test_sycl_program.py +++ b/dpctl/tests/test_sycl_program.py @@ -23,6 +23,7 @@ import dpctl import dpctl.program as dpctl_prog +from dpctl.program.utils import parse_spirv_specializations def get_spirv_abspath(fn): @@ -341,3 +342,49 @@ def test_create_kernel_bundle_with_composite_spec_const(): # 1.0 * 10 + 2.5 = 12.5 assert np.all(y == 12.5) + + +def test_spirv_specializations_parser(): + spirv_file = get_spirv_abspath("specialization_constant_kernel.spv") + with open(spirv_file, "rb") as spv: + spv_bytes = spv.read() + spec_consts = parse_spirv_specializations(spv_bytes) + assert len(spec_consts) == 1 + assert spec_consts[0].dtype == "u4" + + spirv_file = get_spirv_abspath("specialization_constant_composite.spv") + with open(spirv_file, "rb") as spv: + spv_bytes = spv.read() + + spec_consts = parse_spirv_specializations(spv_bytes) + assert len(spec_consts) == 3 + spec_const0, spec_const1, spec_const2 = spec_consts + assert spec_const0.dtype == "u4" + assert spec_const0.itemsize == 4 + assert spec_const0.name == "unnamed_spec_const_0" + assert spec_const0.default_value == 1 + + assert spec_const1.dtype == "f4" + assert spec_const1.itemsize == 4 + assert spec_const1.name == "unnamed_spec_const_1" + assert spec_const1.default_value == 0 + + # compiler translates bool to char + assert spec_const2.dtype == "u1" + assert spec_const2.itemsize == 1 + assert spec_const2.name == "unnamed_spec_const_2" + assert spec_const2.default_value == 0 + + +def test_spirv_specializations_parser_no_spec_consts(): + spirv_file = get_spirv_abspath("multi_kernel.spv") + with open(spirv_file, "rb") as spv: + spv_bytes = spv.read() + spec_consts = parse_spirv_specializations(spv_bytes) + assert not spec_consts + + +def test_spirv_specializations_parser_invalid_spirv(): + invalid_spv = b"\x00\x01\x02\x03\x04\x05" + with pytest.raises(ValueError): + parse_spirv_specializations(invalid_spv) diff --git a/setup.py b/setup.py index 1b10322ef8..2c44bd4a11 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "dpctl", "dpctl.memory", "dpctl.program", + "dpctl.program.utils", "dpctl.utils", ], package_data={