From ec5749f80b1fec12bc7e028004801405c573ff4a Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Mon, 4 May 2026 01:07:19 -0700 Subject: [PATCH] Implement checked mul via Op{U,S}MulExtended (#537). --- .../src/builder/builder_methods.rs | 76 +++++++-- tests/compiletests/ui/dis/checked_mul.rs | 22 +++ tests/compiletests/ui/dis/checked_mul.stderr | 32 ++++ .../ui/lang/core/ops/checked_mul.rs | 160 ++++++++++++++++++ tests/difftests/tests/Cargo.lock | 17 ++ tests/difftests/tests/Cargo.toml | 2 + .../checked_mul/checked_mul-cpu/Cargo.toml | 10 ++ .../checked_mul/checked_mul-cpu/src/main.rs | 20 +++ .../checked_mul/checked_mul-shader/Cargo.toml | 15 ++ .../checked_mul/checked_mul-shader/src/lib.rs | 46 +++++ .../checked_mul-shader/src/main.rs | 48 ++++++ 11 files changed, 434 insertions(+), 14 deletions(-) create mode 100644 tests/compiletests/ui/dis/checked_mul.rs create mode 100644 tests/compiletests/ui/dis/checked_mul.stderr create mode 100644 tests/compiletests/ui/lang/core/ops/checked_mul.rs create mode 100644 tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-cpu/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-cpu/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/src/lib.rs create mode 100644 tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/src/main.rs diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index d86db1cbd00..b2078c9864d 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -1726,20 +1726,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { rhs: Self::Value, ) -> (Self::Value, Self::Value) { // adopted partially from https://github.com/ziglang/zig/blob/master/src/codegen/spirv.zig - let is_add = match oop { - OverflowOp::Add => true, - OverflowOp::Sub => false, - OverflowOp::Mul => { - // NOTE(eddyb) this needs to be `undef`, not `false`/`true`, because - // we don't want the user's boolean constants to keep the zombie alive. - let bool = SpirvType::Bool.def(self.span(), self); - let overflowed = self.undef(bool); - - let result = (self.mul(lhs, rhs), overflowed); - self.zombie(result.1.def(self), "checked mul is not supported yet"); - return result; - } - }; let signed = match ty.kind() { ty::Int(_) => true, ty::Uint(_) => false, @@ -1752,6 +1738,68 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } )), }; + let is_add = match oop { + OverflowOp::Add => true, + OverflowOp::Sub => false, + OverflowOp::Mul => { + let int_ty = lhs.ty; + let bits = match self.lookup_type(int_ty) { + SpirvType::Integer(width, _) => width, + other => self.fatal(format!( + "checked mul on non-integer type: {}", + other.debug(int_ty, self) + )), + }; + + // OpUMulExtended / OpSMulExtended produce an OpTypeStruct{T, T} + // holding the low and high halves of the full 2*bits-wide product. + // The struct is purely intermediate — we extract both halves + // immediately and never expose it to Rust code. + let pair_ty = SpirvType::Adt { + def_id: None, + size: None, + align: Align::from_bytes(0).unwrap(), + field_types: &[int_ty, int_ty], + field_offsets: &[], + field_names: None, + } + .def(self.span(), self); + + let extended = if signed { + self.emit() + .s_mul_extended(pair_ty, None, lhs.def(self), rhs.def(self)) + } else { + self.emit() + .u_mul_extended(pair_ty, None, lhs.def(self), rhs.def(self)) + } + .unwrap(); + + let low = self + .emit() + .composite_extract(int_ty, None, extended, [0].iter().cloned()) + .unwrap() + .with_type(int_ty); + let high = self + .emit() + .composite_extract(int_ty, None, extended, [1].iter().cloned()) + .unwrap() + .with_type(int_ty); + + let overflowed = if signed { + // For signed multiplication, no overflow occurs iff the high + // half is the sign extension of the low half, i.e. the + // arithmetic-shift of `low` by `bits-1` (replicating the MSB). + let shift_amount = self.constant_int(int_ty, u128::from(bits - 1)); + let expected_high = self.ashr(low, shift_amount); + self.icmp(IntPredicate::IntNE, high, expected_high) + } else { + let zero = self.constant_int(int_ty, 0); + self.icmp(IntPredicate::IntNE, high, zero) + }; + + return (low, overflowed); + } + }; let result = if is_add { self.add(lhs, rhs) diff --git a/tests/compiletests/ui/dis/checked_mul.rs b/tests/compiletests/ui/dis/checked_mul.rs new file mode 100644 index 00000000000..120119ca871 --- /dev/null +++ b/tests/compiletests/ui/dis/checked_mul.rs @@ -0,0 +1,22 @@ +// Verifies that checked multiplication lowers to `OpUMulExtended` / +// `OpSMulExtended` and detects overflow correctly (issue #537). + +// build-pass +// compile-flags: -C llvm-args=--disassemble-fn=checked_mul::checked_mul + +use spirv_std::spirv; + +fn checked_mul(a: u32, b: u32, c: i32, d: i32) -> u32 { + let (ur, uo) = a.overflowing_mul(b); + let (sr, so) = c.overflowing_mul(d); + ur ^ (sr as u32) ^ (uo as u32) ^ (so as u32) +} + +#[spirv(fragment)] +pub fn main( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] u_in: &[u32; 2], + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] s_in: &[i32; 2], + out: &mut u32, +) { + *out = checked_mul(u_in[0], u_in[1], s_in[0], s_in[1]); +} diff --git a/tests/compiletests/ui/dis/checked_mul.stderr b/tests/compiletests/ui/dis/checked_mul.stderr new file mode 100644 index 00000000000..74b23841068 --- /dev/null +++ b/tests/compiletests/ui/dis/checked_mul.stderr @@ -0,0 +1,32 @@ +%1 = OpFunction %2 None %3 +%4 = OpFunctionParameter %2 +%5 = OpFunctionParameter %2 +%6 = OpFunctionParameter %7 +%8 = OpFunctionParameter %7 +%9 = OpLabel +OpLine %10 1239 4 +%11 = OpUMulExtended %12 %4 %5 +%13 = OpCompositeExtract %2 %11 0 +%14 = OpCompositeExtract %2 %11 1 +%15 = OpINotEqual %16 %14 %17 +OpLine %10 390 4 +%18 = OpSMulExtended %19 %6 %8 +%20 = OpCompositeExtract %7 %18 0 +%21 = OpCompositeExtract %7 %18 1 +%22 = OpShiftRightArithmetic %7 %20 %23 +%24 = OpINotEqual %16 %21 %22 +OpLine %25 12 9 +%26 = OpBitcast %2 %20 +OpLine %25 12 4 +%27 = OpBitwiseXor %2 %13 %26 +OpLine %25 12 23 +%28 = OpSelect %2 %15 %29 %17 +OpLine %25 12 4 +%30 = OpBitwiseXor %2 %27 %28 +OpLine %25 12 37 +%31 = OpSelect %2 %24 %29 %17 +OpLine %25 12 4 +%32 = OpBitwiseXor %2 %30 %31 +OpNoLine +OpReturnValue %32 +OpFunctionEnd diff --git a/tests/compiletests/ui/lang/core/ops/checked_mul.rs b/tests/compiletests/ui/lang/core/ops/checked_mul.rs new file mode 100644 index 00000000000..f77678a31ed --- /dev/null +++ b/tests/compiletests/ui/lang/core/ops/checked_mul.rs @@ -0,0 +1,160 @@ +// Tests that checked / overflowing / unchecked multiplication compile, including +// the `unchecked_mul` precondition check that internally calls `overflowing_mul` +// (see https://github.com/Rust-GPU/rust-gpu/issues/537). + +// build-pass +// compile-flags: -C target-feature=+Int8,+Int16,+Int64 + +use spirv_std::spirv; + +#[spirv(fragment)] +pub fn checked_mul_u8( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u8, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u8, + out: &mut u32, +) { + let (r, o) = a.overflowing_mul(*b); + *out = (r as u32) | ((o as u32) << 8); + if let Some(v) = a.checked_mul(*b) { + *out ^= v as u32; + } +} + +#[spirv(fragment)] +pub fn checked_mul_u16( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u16, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u16, + out: &mut u32, +) { + let (r, o) = a.overflowing_mul(*b); + *out = (r as u32) | ((o as u32) << 16); + if let Some(v) = a.checked_mul(*b) { + *out ^= v as u32; + } +} + +#[spirv(fragment)] +pub fn checked_mul_u32( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32, + out: &mut u32, +) { + let (r, o) = a.overflowing_mul(*b); + *out = r ^ (o as u32); + if let Some(v) = a.checked_mul(*b) { + *out ^= v; + } +} + +#[spirv(fragment)] +pub fn checked_mul_u64( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u64, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u64, + out: &mut u32, +) { + let (r, o) = a.overflowing_mul(*b); + *out = (r as u32) ^ (o as u32); + if let Some(v) = a.checked_mul(*b) { + *out ^= v as u32; + } +} + +#[spirv(fragment)] +pub fn checked_mul_i8( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i8, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i8, + out: &mut u32, +) { + let (r, o) = a.overflowing_mul(*b); + *out = (r as u32) ^ ((o as u32) << 8); + if let Some(v) = a.checked_mul(*b) { + *out ^= v as u32; + } +} + +#[spirv(fragment)] +pub fn checked_mul_i16( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i16, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i16, + out: &mut u32, +) { + let (r, o) = a.overflowing_mul(*b); + *out = (r as u32) ^ ((o as u32) << 16); + if let Some(v) = a.checked_mul(*b) { + *out ^= v as u32; + } +} + +#[spirv(fragment)] +pub fn checked_mul_i32( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i32, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i32, + out: &mut u32, +) { + let (r, o) = a.overflowing_mul(*b); + *out = (r as u32) ^ (o as u32); + if let Some(v) = a.checked_mul(*b) { + *out ^= v as u32; + } +} + +#[spirv(fragment)] +pub fn checked_mul_i64( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i64, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i64, + out: &mut u32, +) { + let (r, o) = a.overflowing_mul(*b); + *out = (r as u32) ^ (o as u32); + if let Some(v) = a.checked_mul(*b) { + *out ^= v as u32; + } +} + +// Issue #537 specifically: `unchecked_mul`'s precondition check uses +// `overflowing_mul`, which previously zombied with "checked mul is not +// supported yet". +#[spirv(fragment)] +pub fn unchecked_mul_u32( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32, + out: &mut u32, +) { + *out = unsafe { a.unchecked_mul(*b) }; +} + +#[spirv(fragment)] +pub fn unchecked_mul_i32( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i32, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i32, + out: &mut u32, +) { + *out = unsafe { a.unchecked_mul(*b) } as u32; +} + +// The original issue used `usize::unchecked_mul()` (e.g. via `Layout::repeat`). +#[spirv(fragment)] +pub fn unchecked_mul_usize( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32, + out: &mut u32, +) { + let a = *a as usize; + let b = *b as usize; + *out = unsafe { a.unchecked_mul(b) } as u32; +} + +#[spirv(fragment)] +pub fn checked_mul_usize( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32, + #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32, + out: &mut u32, +) { + let a = *a as usize; + let b = *b as usize; + let (r, o) = a.overflowing_mul(b); + *out = (r as u32) ^ (o as u32); + if let Some(v) = a.checked_mul(b) { + *out ^= v as u32; + } +} diff --git a/tests/difftests/tests/Cargo.lock b/tests/difftests/tests/Cargo.lock index 7104d6799c9..bdad5be558c 100644 --- a/tests/difftests/tests/Cargo.lock +++ b/tests/difftests/tests/Cargo.lock @@ -343,6 +343,23 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "checked_mul-cpu" +version = "0.0.0" +dependencies = [ + "checked_mul-shader", + "difftest", +] + +[[package]] +name = "checked_mul-shader" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + [[package]] name = "codespan-reporting" version = "0.12.0" diff --git a/tests/difftests/tests/Cargo.toml b/tests/difftests/tests/Cargo.toml index dda56eb0024..19a351c6a94 100644 --- a/tests/difftests/tests/Cargo.toml +++ b/tests/difftests/tests/Cargo.toml @@ -34,6 +34,8 @@ members = [ "lang/core/ops/matrix_ops/matrix_ops-rust", "lang/core/ops/matrix_ops/matrix_ops-wgsl", "lang/core/ops/bitwise_ops/bitwise_ops-rust", + "lang/core/ops/checked_mul/checked_mul-cpu", + "lang/core/ops/checked_mul/checked_mul-shader", "lang/core/ops/const_fold_int/const-expr-cpu", "lang/core/ops/const_fold_int/const-expr-shader", "lang/core/ops/const_fold_int/const-fold-cpu", diff --git a/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-cpu/Cargo.toml b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-cpu/Cargo.toml new file mode 100644 index 00000000000..ac2b88f5f63 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-cpu/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "checked_mul-cpu" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +checked_mul-shader = { path = "../checked_mul-shader" } +difftest.workspace = true diff --git a/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-cpu/src/main.rs b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-cpu/src/main.rs new file mode 100644 index 00000000000..42a1008e2ec --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-cpu/src/main.rs @@ -0,0 +1,20 @@ +use checked_mul_shader::{INPUTS, OUTPUT_LEN, PAIR_COUNT}; +use difftest::config::Config; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let mut output = vec![0u32; OUTPUT_LEN]; + for i in 0..PAIR_COUNT { + let a = INPUTS[2 * i]; + let b = INPUTS[2 * i + 1]; + let (ur, uo) = a.overflowing_mul(b); + output[4 * i] = ur; + output[4 * i + 1] = uo as u32; + let (sr, so) = (a as i32).overflowing_mul(b as i32); + output[4 * i + 2] = sr as u32; + output[4 * i + 3] = so as u32; + } + + config.write_result(&output).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/Cargo.toml b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/Cargo.toml new file mode 100644 index 00000000000..efaab765953 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "checked_mul-shader" +edition.workspace = true + +[lints] +workspace = true + +# GPU deps +[dependencies] +spirv-std.workspace = true + +# CPU deps (for the test harness) +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true diff --git a/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/src/lib.rs b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/src/lib.rs new file mode 100644 index 00000000000..333732d9c17 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/src/lib.rs @@ -0,0 +1,46 @@ +#![no_std] + +use spirv_std::spirv; + +// 8 (a, b) input pairs flattened as `[a0, b0, a1, b1, ...]`. Each pair is +// multiplied as both `u32` and `i32` (reinterpreting the bits) and the four +// resulting words are written contiguously: `[u_low, u_overflow, s_low, s_overflow]`. +// +// Cases include a mix of non-overflowing and overflowing products in both +// signed and unsigned interpretations, including signed `i32::MIN * i32::MIN` +// (which overflows). +pub const PAIR_COUNT: usize = 8; +pub const INPUT_LEN: usize = PAIR_COUNT * 2; +pub const OUTPUT_LEN: usize = PAIR_COUNT * 4; + +#[cfg(not(target_arch = "spirv"))] +#[rustfmt::skip] +pub const INPUTS: [u32; INPUT_LEN] = [ + 0, 0, + 1, 1, + 1000, 1000, + 0x0000_FFFF, 0x0000_FFFF, + 0x0001_0000, 0x0001_0000, + 0xFFFF_FFFF, 2, + 0x8000_0000, 0x8000_0000, + 0x1234_5678, 0xABCD_EF01, +]; + +#[spirv(compute(threads(1)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32; INPUT_LEN], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32; OUTPUT_LEN], +) { + let mut i = 0; + while i < PAIR_COUNT { + let a = input[2 * i]; + let b = input[2 * i + 1]; + let (ur, uo) = a.overflowing_mul(b); + output[4 * i] = ur; + output[4 * i + 1] = uo as u32; + let (sr, so) = (a as i32).overflowing_mul(b as i32); + output[4 * i + 2] = sr as u32; + output[4 * i + 3] = so as u32; + i += 1; + } +} diff --git a/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/src/main.rs b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/src/main.rs new file mode 100644 index 00000000000..64c412b7b5b --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/checked_mul/checked_mul-shader/src/main.rs @@ -0,0 +1,48 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use checked_mul_shader::{INPUTS, OUTPUT_LEN}; + use difftest::config::Config; + use difftest::scaffold::compute::{ + AshBackend, BufferConfig, BufferUsage, ComputeShaderTest, RustComputeShader, + }; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Skip on macOS — Ash tests have known MoltenVK configuration issues there. + #[cfg(target_os = "macos")] + { + use difftest::scaffold::Skip; + Skip::new("Ash tests are skipped on macOS due to MoltenVK configuration issues") + .run_test(&config) + .unwrap(); + return; + } + + #[cfg(not(target_os = "macos"))] + { + let input_bytes: Vec = bytemuck::cast_slice(&INPUTS).to_vec(); + let output_size = (OUTPUT_LEN * std::mem::size_of::()) as u64; + + let buffers = vec![ + BufferConfig { + size: input_bytes.len() as u64, + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: output_size, + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + // Use the Ash backend so naga doesn't reject `OpUMulExtended` / + // `OpSMulExtended` on its way to a non-Vulkan backend. + let shader = RustComputeShader::default(); + let test = ComputeShaderTest::::new(shader, [1, 1, 1], buffers).unwrap(); + test.run_test(&config).unwrap(); + } +} + +#[cfg(target_arch = "spirv")] +fn main() {}