Skip to content

Commit ddda937

Browse files
authored
Rollup merge of #143273 - 1c3t3a:enum-check-negative, r=SparrowLii
Make the enum check work for negative discriminants The discriminant check was not working correctly for negative numbers. This change fixes that by masking out the relevant bits correctly. Fixes #143218.
2 parents c3e3f43 + 8a0d8dd commit ddda937

File tree

3 files changed

+93
-4
lines changed

3 files changed

+93
-4
lines changed

‎compiler/rustc_mir_transform/src/check_enums.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ enum EnumCheckType<'tcx> {
120120
},
121121
}
122122

123+
#[derive(Debug, Copy, Clone)]
123124
struct TyAndSize<'tcx> {
124125
pub ty: Ty<'tcx>,
125126
pub size: Size,
@@ -337,7 +338,7 @@ fn insert_direct_enum_check<'tcx>(
337338
let invalid_discr_block_data = BasicBlockData::new(None, false);
338339
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
339340
let block_data = &mut basic_blocks[current_block];
340-
let discr = insert_discr_cast_to_u128(
341+
let discr_place = insert_discr_cast_to_u128(
341342
tcx,
342343
local_decls,
343344
block_data,
@@ -348,13 +349,34 @@ fn insert_direct_enum_check<'tcx>(
348349
source_info,
349350
);
350351

352+
// Mask out the bits of the discriminant type.
353+
let mask = discr.size.unsigned_int_max();
354+
let discr_masked =
355+
local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
356+
let rvalue = Rvalue::BinaryOp(
357+
BinOp::BitAnd,
358+
Box::new((
359+
Operand::Copy(discr_place),
360+
Operand::Constant(Box::new(ConstOperand {
361+
span: source_info.span,
362+
user_ty: None,
363+
const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
364+
})),
365+
)),
366+
);
367+
block_data
368+
.statements
369+
.push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
370+
351371
// Branch based on the discriminant value.
352372
block_data.terminator = Some(Terminator {
353373
source_info,
354374
kind: TerminatorKind::SwitchInt {
355-
discr: Operand::Copy(discr),
375+
discr: Operand::Copy(discr_masked),
356376
targets: SwitchTargets::new(
357-
discriminants.into_iter().map(|discr| (discr, new_block)),
377+
discriminants
378+
.into_iter()
379+
.map(|discr_val| (discr.size.truncate(discr_val), new_block)),
358380
invalid_discr_block,
359381
),
360382
},
@@ -371,7 +393,7 @@ fn insert_direct_enum_check<'tcx>(
371393
})),
372394
expected: true,
373395
target: new_block,
374-
msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
396+
msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
375397
// This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
376398
// We never want to insert an unwind into unsafe code, because unwinding could
377399
// make a failing UB check turn into much worse UB when we start unwinding.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//@ run-fail
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0xfd
4+
5+
#[allow(dead_code)]
6+
enum Foo {
7+
A = -2,
8+
B = -1,
9+
C = 1,
10+
}
11+
12+
fn main() {
13+
let _val: Foo = unsafe { std::mem::transmute::<i8, Foo>(-3) };
14+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
#[allow(dead_code)]
5+
#[derive(Debug, PartialEq)]
6+
enum Foo {
7+
A = -12121,
8+
B = -2,
9+
C = -1,
10+
D = 1,
11+
E = 2,
12+
F = 12121,
13+
}
14+
15+
#[allow(dead_code)]
16+
#[repr(i64)]
17+
#[derive(Debug, PartialEq)]
18+
enum Bar {
19+
A = i64::MIN,
20+
B = -2,
21+
C = -1,
22+
D = 1,
23+
E = 2,
24+
F = i64::MAX,
25+
}
26+
27+
fn main() {
28+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-12121) };
29+
assert_eq!(val, Foo::A);
30+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-2) };
31+
assert_eq!(val, Foo::B);
32+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-1) };
33+
assert_eq!(val, Foo::C);
34+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(1) };
35+
assert_eq!(val, Foo::D);
36+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(2) };
37+
assert_eq!(val, Foo::E);
38+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(12121) };
39+
assert_eq!(val, Foo::F);
40+
41+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(i64::MIN) };
42+
assert_eq!(val, Bar::A);
43+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(-2) };
44+
assert_eq!(val, Bar::B);
45+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(-1) };
46+
assert_eq!(val, Bar::C);
47+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(1) };
48+
assert_eq!(val, Bar::D);
49+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(2) };
50+
assert_eq!(val, Bar::E);
51+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(i64::MAX) };
52+
assert_eq!(val, Bar::F);
53+
}

0 commit comments

Comments
 (0)