Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 62 additions & 14 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1726,20 +1726,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
rhs: Self::Value,
) -> (Self::Value, Self::Value) {
// adopted partially from https://github.com/ziglang/zig/blob/master/src/codegen/spirv.zig
let is_add = match oop {
OverflowOp::Add => true,
OverflowOp::Sub => false,
OverflowOp::Mul => {
// NOTE(eddyb) this needs to be `undef`, not `false`/`true`, because
// we don't want the user's boolean constants to keep the zombie alive.
let bool = SpirvType::Bool.def(self.span(), self);
let overflowed = self.undef(bool);

let result = (self.mul(lhs, rhs), overflowed);
self.zombie(result.1.def(self), "checked mul is not supported yet");
return result;
}
};
let signed = match ty.kind() {
ty::Int(_) => true,
ty::Uint(_) => false,
Expand All @@ -1752,6 +1738,68 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
)),
};
let is_add = match oop {
OverflowOp::Add => true,
OverflowOp::Sub => false,
OverflowOp::Mul => {
let int_ty = lhs.ty;
let bits = match self.lookup_type(int_ty) {
SpirvType::Integer(width, _) => width,
other => self.fatal(format!(
"checked mul on non-integer type: {}",
other.debug(int_ty, self)
)),
};

// OpUMulExtended / OpSMulExtended produce an OpTypeStruct{T, T}
// holding the low and high halves of the full 2*bits-wide product.
// The struct is purely intermediate — we extract both halves
// immediately and never expose it to Rust code.
let pair_ty = SpirvType::Adt {
def_id: None,
size: None,
align: Align::from_bytes(0).unwrap(),
field_types: &[int_ty, int_ty],
field_offsets: &[],
field_names: None,
}
.def(self.span(), self);

let extended = if signed {
self.emit()
.s_mul_extended(pair_ty, None, lhs.def(self), rhs.def(self))
} else {
self.emit()
.u_mul_extended(pair_ty, None, lhs.def(self), rhs.def(self))
}
.unwrap();

let low = self
.emit()
.composite_extract(int_ty, None, extended, [0].iter().cloned())
.unwrap()
.with_type(int_ty);
let high = self
.emit()
.composite_extract(int_ty, None, extended, [1].iter().cloned())
.unwrap()
.with_type(int_ty);

let overflowed = if signed {
// For signed multiplication, no overflow occurs iff the high
// half is the sign extension of the low half, i.e. the
// arithmetic-shift of `low` by `bits-1` (replicating the MSB).
let shift_amount = self.constant_int(int_ty, u128::from(bits - 1));
let expected_high = self.ashr(low, shift_amount);
self.icmp(IntPredicate::IntNE, high, expected_high)
} else {
let zero = self.constant_int(int_ty, 0);
self.icmp(IntPredicate::IntNE, high, zero)
};

return (low, overflowed);
}
};

let result = if is_add {
self.add(lhs, rhs)
Expand Down
22 changes: 22 additions & 0 deletions tests/compiletests/ui/dis/checked_mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Verifies that checked multiplication lowers to `OpUMulExtended` /
// `OpSMulExtended` and detects overflow correctly (issue #537).

// build-pass
// compile-flags: -C llvm-args=--disassemble-fn=checked_mul::checked_mul

use spirv_std::spirv;

fn checked_mul(a: u32, b: u32, c: i32, d: i32) -> u32 {
let (ur, uo) = a.overflowing_mul(b);
let (sr, so) = c.overflowing_mul(d);
ur ^ (sr as u32) ^ (uo as u32) ^ (so as u32)
}

#[spirv(fragment)]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] u_in: &[u32; 2],
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] s_in: &[i32; 2],
out: &mut u32,
) {
*out = checked_mul(u_in[0], u_in[1], s_in[0], s_in[1]);
}
32 changes: 32 additions & 0 deletions tests/compiletests/ui/dis/checked_mul.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
%1 = OpFunction %2 None %3
%4 = OpFunctionParameter %2
%5 = OpFunctionParameter %2
%6 = OpFunctionParameter %7
%8 = OpFunctionParameter %7
%9 = OpLabel
OpLine %10 1239 4
%11 = OpUMulExtended %12 %4 %5
%13 = OpCompositeExtract %2 %11 0
%14 = OpCompositeExtract %2 %11 1
%15 = OpINotEqual %16 %14 %17
OpLine %10 390 4
%18 = OpSMulExtended %19 %6 %8
%20 = OpCompositeExtract %7 %18 0
%21 = OpCompositeExtract %7 %18 1
%22 = OpShiftRightArithmetic %7 %20 %23
%24 = OpINotEqual %16 %21 %22
OpLine %25 12 9
%26 = OpBitcast %2 %20
OpLine %25 12 4
%27 = OpBitwiseXor %2 %13 %26
OpLine %25 12 23
%28 = OpSelect %2 %15 %29 %17
OpLine %25 12 4
%30 = OpBitwiseXor %2 %27 %28
OpLine %25 12 37
%31 = OpSelect %2 %24 %29 %17
OpLine %25 12 4
%32 = OpBitwiseXor %2 %30 %31
OpNoLine
OpReturnValue %32
OpFunctionEnd
160 changes: 160 additions & 0 deletions tests/compiletests/ui/lang/core/ops/checked_mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Tests that checked / overflowing / unchecked multiplication compile, including
// the `unchecked_mul` precondition check that internally calls `overflowing_mul`
// (see https://github.com/Rust-GPU/rust-gpu/issues/537).

// build-pass
// compile-flags: -C target-feature=+Int8,+Int16,+Int64

use spirv_std::spirv;

#[spirv(fragment)]
pub fn checked_mul_u8(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u8,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u8,
out: &mut u32,
) {
let (r, o) = a.overflowing_mul(*b);
*out = (r as u32) | ((o as u32) << 8);
if let Some(v) = a.checked_mul(*b) {
*out ^= v as u32;
}
}

#[spirv(fragment)]
pub fn checked_mul_u16(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u16,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u16,
out: &mut u32,
) {
let (r, o) = a.overflowing_mul(*b);
*out = (r as u32) | ((o as u32) << 16);
if let Some(v) = a.checked_mul(*b) {
*out ^= v as u32;
}
}

#[spirv(fragment)]
pub fn checked_mul_u32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32,
out: &mut u32,
) {
let (r, o) = a.overflowing_mul(*b);
*out = r ^ (o as u32);
if let Some(v) = a.checked_mul(*b) {
*out ^= v;
}
}

#[spirv(fragment)]
pub fn checked_mul_u64(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u64,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u64,
out: &mut u32,
) {
let (r, o) = a.overflowing_mul(*b);
*out = (r as u32) ^ (o as u32);
if let Some(v) = a.checked_mul(*b) {
*out ^= v as u32;
}
}

#[spirv(fragment)]
pub fn checked_mul_i8(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i8,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i8,
out: &mut u32,
) {
let (r, o) = a.overflowing_mul(*b);
*out = (r as u32) ^ ((o as u32) << 8);
if let Some(v) = a.checked_mul(*b) {
*out ^= v as u32;
}
}

#[spirv(fragment)]
pub fn checked_mul_i16(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i16,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i16,
out: &mut u32,
) {
let (r, o) = a.overflowing_mul(*b);
*out = (r as u32) ^ ((o as u32) << 16);
if let Some(v) = a.checked_mul(*b) {
*out ^= v as u32;
}
}

#[spirv(fragment)]
pub fn checked_mul_i32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i32,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i32,
out: &mut u32,
) {
let (r, o) = a.overflowing_mul(*b);
*out = (r as u32) ^ (o as u32);
if let Some(v) = a.checked_mul(*b) {
*out ^= v as u32;
}
}

#[spirv(fragment)]
pub fn checked_mul_i64(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i64,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i64,
out: &mut u32,
) {
let (r, o) = a.overflowing_mul(*b);
*out = (r as u32) ^ (o as u32);
if let Some(v) = a.checked_mul(*b) {
*out ^= v as u32;
}
}

// Issue #537 specifically: `unchecked_mul`'s precondition check uses
// `overflowing_mul`, which previously zombied with "checked mul is not
// supported yet".
#[spirv(fragment)]
pub fn unchecked_mul_u32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32,
out: &mut u32,
) {
*out = unsafe { a.unchecked_mul(*b) };
}

#[spirv(fragment)]
pub fn unchecked_mul_i32(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i32,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i32,
out: &mut u32,
) {
*out = unsafe { a.unchecked_mul(*b) } as u32;
}

// The original issue used `usize::unchecked_mul()` (e.g. via `Layout::repeat`).
#[spirv(fragment)]
pub fn unchecked_mul_usize(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32,
out: &mut u32,
) {
let a = *a as usize;
let b = *b as usize;
*out = unsafe { a.unchecked_mul(b) } as u32;
}

#[spirv(fragment)]
pub fn checked_mul_usize(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32,
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32,
out: &mut u32,
) {
let a = *a as usize;
let b = *b as usize;
let (r, o) = a.overflowing_mul(b);
*out = (r as u32) ^ (o as u32);
if let Some(v) = a.checked_mul(b) {
*out ^= v as u32;
}
}
17 changes: 17 additions & 0 deletions tests/difftests/tests/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions tests/difftests/tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ members = [
"lang/core/ops/matrix_ops/matrix_ops-rust",
"lang/core/ops/matrix_ops/matrix_ops-wgsl",
"lang/core/ops/bitwise_ops/bitwise_ops-rust",
"lang/core/ops/checked_mul/checked_mul-cpu",
"lang/core/ops/checked_mul/checked_mul-shader",
"lang/core/ops/const_fold_int/const-expr-cpu",
"lang/core/ops/const_fold_int/const-expr-shader",
"lang/core/ops/const_fold_int/const-fold-cpu",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "checked_mul-cpu"
edition.workspace = true

[lints]
workspace = true

[dependencies]
checked_mul-shader = { path = "../checked_mul-shader" }
difftest.workspace = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use checked_mul_shader::{INPUTS, OUTPUT_LEN, PAIR_COUNT};
use difftest::config::Config;

fn main() {
let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap();

let mut output = vec![0u32; OUTPUT_LEN];
for i in 0..PAIR_COUNT {
let a = INPUTS[2 * i];
let b = INPUTS[2 * i + 1];
let (ur, uo) = a.overflowing_mul(b);
output[4 * i] = ur;
output[4 * i + 1] = uo as u32;
let (sr, so) = (a as i32).overflowing_mul(b as i32);
output[4 * i + 2] = sr as u32;
output[4 * i + 3] = so as u32;
}

config.write_result(&output).unwrap();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[package]
name = "checked_mul-shader"
edition.workspace = true

[lints]
workspace = true

# GPU deps
[dependencies]
spirv-std.workspace = true

# CPU deps (for the test harness)
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
difftest.workspace = true
bytemuck.workspace = true
Loading
Loading