Skip to content

Commit 68a69d8

Browse files
committed
fix 128bits cttz intrinsic UB
1 parent 3b0e6af commit 68a69d8

File tree

1 file changed

+137
-110
lines changed

1 file changed

+137
-110
lines changed

‎src/intrinsic/mod.rs

Lines changed: 137 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -410,39 +410,14 @@ impl<'a, 'gcc, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'gcc, 'tc
410410
| sym::saturating_sub => {
411411
match int_type_width_signed(args[0].layout.ty, self) {
412412
Some((width, signed)) => match name {
413-
sym::cttz => {
414-
let func = self.current_func();
415-
let then_block = func.new_block("then");
416-
let else_block = func.new_block("else");
417-
let after_block = func.new_block("after");
418-
419-
let arg = args[0].immediate();
420-
let result = func.new_local(None, self.u32_type, "zeros");
421-
let zero = self.cx.gcc_zero(arg.get_type());
422-
let cond = self.gcc_icmp(IntPredicate::IntEQ, arg, zero);
423-
self.llbb().end_with_conditional(None, cond, then_block, else_block);
424-
425-
let zero_result = self.cx.gcc_uint(self.u32_type, width);
426-
then_block.add_assignment(None, result, zero_result);
427-
then_block.end_with_jump(None, after_block);
428-
429-
self.switch_to_block(else_block);
430-
431-
let zeros = self.count_trailing_zeroes(width, arg);
432-
self.llbb().add_assignment(None, result, zeros);
433-
self.llbb().end_with_jump(None, after_block);
434-
435-
// NOTE: since jumps were added in a place rustc does not
436-
// expect, the current block in the state need to be updated.
437-
self.switch_to_block(after_block);
438-
439-
result.to_rvalue()
440-
}
441413
sym::ctlz => self.count_leading_zeroes(width, args[0].immediate()),
442414
sym::ctlz_nonzero => {
443415
self.count_leading_zeroes_nonzero(width, args[0].immediate())
444416
}
445-
sym::cttz_nonzero => self.count_trailing_zeroes(width, args[0].immediate()),
417+
sym::cttz => self.count_trailing_zeroes(width, args[0].immediate()),
418+
sym::cttz_nonzero => {
419+
self.count_trailing_zeroes_nonzero(width, args[0].immediate())
420+
}
446421
sym::ctpop => self.pop_count(args[0].immediate()),
447422
sym::bswap => {
448423
if width == 8 {
@@ -882,6 +857,8 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
882857
}
883858

884859
fn count_leading_zeroes(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
860+
// __builtin_clz* functions are UB when called with 0, so handle that special case before using them
861+
885862
// TODO(antoyo): use width?
886863
let result_type = self.u32_type;
887864
let result = self.current_func().new_local(None, result_type, "zeros");
@@ -937,9 +914,8 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
937914
}
938915
res
939916
} else if width == 128 {
940-
// if arg is 0 return 128
941-
// else if the 64 high bits of arg are not 0, return clzll(64 high bits of arg)
942-
// else return 64 + clzll(64 low bits of arg)
917+
// __builtin_clzll is UB when called with 0, so call it on the 64 high bits if they are not 0,
918+
// else call it on the 64 low bits and add 64. In the else case, 64 low bits can't be 0 because arg is not 0.
943919

944920
let ctlz_then_block = self.current_func().new_block("ctlz_then");
945921
let ctlz_else_block = self.current_func().new_block("ctlz_else");
@@ -994,7 +970,46 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
994970
}
995971
}
996972

997-
fn count_trailing_zeroes(&mut self, _width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
973+
fn count_trailing_zeroes(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
974+
// __builtin_ctz* functions are UB when called with 0, so handle that special case before using them
975+
976+
// TODO(antoyo): use width?
977+
let result_type = self.u32_type;
978+
let result = self.current_func().new_local(None, result_type, "zeros");
979+
980+
let then_block = self.current_func().new_block("then");
981+
let else_block = self.current_func().new_block("else");
982+
let after_block = self.current_func().new_block("after");
983+
984+
let zero = self.cx.const_uint(arg.get_type(), 0);
985+
let cond = self.gcc_icmp(IntPredicate::IntEQ, arg, zero);
986+
self.llbb().end_with_conditional(None, cond, then_block, else_block);
987+
988+
let zero_result = self.cx.gcc_uint(self.u32_type, width);
989+
then_block.add_assignment(None, result, zero_result);
990+
then_block.end_with_jump(None, after_block);
991+
self.switch_to_block(else_block);
992+
993+
self.count_trailing_zeroes_nonzero_impl(
994+
width,
995+
arg,
996+
Some((result, else_block, after_block)),
997+
);
998+
// else_block.end_with_jump(None, after_block);
999+
self.switch_to_block(after_block);
1000+
result.to_rvalue()
1001+
}
1002+
1003+
fn count_trailing_zeroes_nonzero(&mut self, width: u64, arg: RValue<'gcc>) -> RValue<'gcc> {
1004+
self.count_trailing_zeroes_nonzero_impl(width, arg, None)
1005+
}
1006+
1007+
fn count_trailing_zeroes_nonzero_impl(
1008+
&mut self,
1009+
_width: u64,
1010+
arg: RValue<'gcc>,
1011+
block: Option<(LValue<'gcc>, Block<'gcc>, Block<'gcc>)>,
1012+
) -> RValue<'gcc> {
9981013
let arg_type = arg.get_type();
9991014
let result_type = self.u32_type;
10001015
let arg = if arg_type.is_signed(self.cx) {
@@ -1004,86 +1019,98 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
10041019
arg
10051020
};
10061021
let arg_type = arg.get_type();
1007-
let (count_trailing_zeroes, expected_type) =
1008-
// TODO(antoyo): write a new function Type::is_compatible_with(&Type) and use it here
1009-
// instead of using is_uint().
1010-
if arg_type.is_uchar(self.cx) || arg_type.is_ushort(self.cx) || arg_type.is_uint(self.cx) {
1011-
// NOTE: we don't need to & 0xFF for uchar because the result is undefined on zero.
1012-
("__builtin_ctz", self.cx.uint_type)
1013-
}
1014-
else if arg_type.is_ulong(self.cx) {
1015-
("__builtin_ctzl", self.cx.ulong_type)
1016-
}
1017-
else if arg_type.is_ulonglong(self.cx) {
1018-
("__builtin_ctzll", self.cx.ulonglong_type)
1019-
}
1020-
else if arg_type.is_u128(self.cx) {
1021-
// Adapted from the algorithm to count leading zeroes from: https://stackoverflow.com/a/28433850/389119
1022-
let array_type = self.context.new_array_type(None, arg_type, 3);
1023-
let result = self.current_func()
1024-
.new_local(None, array_type, "count_loading_zeroes_results");
1025-
1026-
let sixty_four = self.gcc_int(arg_type, 64);
1027-
let shift = self.gcc_lshr(arg, sixty_four);
1028-
let high = self.gcc_int_cast(shift, self.u64_type);
1029-
let low = self.gcc_int_cast(arg, self.u64_type);
1030-
1031-
let zero = self.context.new_rvalue_zero(self.usize_type);
1032-
let one = self.context.new_rvalue_one(self.usize_type);
1033-
let two = self.context.new_rvalue_from_long(self.usize_type, 2);
1034-
1035-
let ctzll = self.context.get_builtin_function("__builtin_ctzll");
1036-
1037-
let first_elem = self.context.new_array_access(self.location, result, zero);
1038-
let first_value = self.gcc_int_cast(self.context.new_call(self.location, ctzll, &[low]), arg_type);
1039-
self.llbb()
1040-
.add_assignment(self.location, first_elem, first_value);
1041-
1042-
let second_elem = self.context.new_array_access(self.location, result, one);
1043-
let second_value = self.gcc_add(self.gcc_int_cast(self.context.new_call(self.location, ctzll, &[high]), arg_type), sixty_four);
1044-
self.llbb()
1045-
.add_assignment(self.location, second_elem, second_value);
1046-
1047-
let third_elem = self.context.new_array_access(self.location, result, two);
1048-
let third_value = self.gcc_int(arg_type, 128);
1049-
self.llbb()
1050-
.add_assignment(self.location, third_elem, third_value);
1051-
1052-
let not_low = self.context.new_unary_op(self.location, UnaryOp::LogicalNegate, self.u64_type, low);
1053-
let not_high = self.context.new_unary_op(self.location, UnaryOp::LogicalNegate, self.u64_type, high);
1054-
let not_low_and_not_high = not_low & not_high;
1055-
let index = not_low + not_low_and_not_high;
1056-
// NOTE: the following cast is necessary to avoid a GIMPLE verification failure in
1057-
// gcc.
1058-
// TODO(antoyo): do the correct verification in libgccjit to avoid an error at the
1059-
// compilation stage.
1060-
let index = self.context.new_cast(self.location, index, self.i32_type);
1061-
1062-
let res = self.context.new_array_access(self.location, result, index);
1063-
1064-
return self.gcc_int_cast(res.to_rvalue(), result_type);
1022+
// TODO(antoyo): write a new function Type::is_compatible_with(&Type) and use it here
1023+
// instead of using is_uint().
1024+
let func_params = if arg_type.is_uchar(self.cx)
1025+
|| arg_type.is_ushort(self.cx)
1026+
|| arg_type.is_uint(self.cx)
1027+
{
1028+
// NOTE: we don't need to & 0xFF for uchar because the result is undefined on zero.
1029+
Some(("__builtin_ctz", self.cx.uint_type))
1030+
} else if arg_type.is_ulong(self.cx) {
1031+
Some(("__builtin_ctzl", self.cx.ulong_type))
1032+
} else if arg_type.is_ulonglong(self.cx) {
1033+
Some(("__builtin_ctzll", self.cx.ulonglong_type))
1034+
} else {
1035+
None
1036+
};
1037+
if let Some((count_trailing_zeroes, expected_type)) = func_params {
1038+
let count_trailing_zeroes = self.context.get_builtin_function(count_trailing_zeroes);
1039+
let arg = if arg_type != expected_type {
1040+
self.context.new_cast(self.location, arg, expected_type)
1041+
} else {
1042+
arg
1043+
};
1044+
let res = self.context.new_call(self.location, count_trailing_zeroes, &[arg]);
1045+
let res = self.context.new_cast(self.location, res, result_type);
1046+
if let Some((result, else_block, after_block)) = block {
1047+
else_block.add_assignment(None, result, res);
1048+
else_block.end_with_jump(None, after_block);
10651049
}
1066-
else {
1067-
let count_trailing_zeroes = self.context.get_builtin_function("__builtin_ctzll");
1068-
let arg_size = arg_type.get_size();
1069-
let casted_arg = self.context.new_cast(self.location, arg, self.ulonglong_type);
1070-
let byte_diff = self.ulonglong_type.get_size() as i64 - arg_size as i64;
1071-
let diff = self.context.new_rvalue_from_long(self.int_type, byte_diff * 8);
1072-
let mask = self.context.new_rvalue_from_long(arg_type, -1); // To get the value with all bits set.
1073-
let masked = mask & self.context.new_unary_op(self.location, UnaryOp::BitwiseNegate, arg_type, arg);
1074-
let cond = self.context.new_comparison(self.location, ComparisonOp::Equals, masked, mask);
1075-
let diff = diff * self.context.new_cast(self.location, cond, self.int_type);
1076-
let res = self.context.new_call(self.location, count_trailing_zeroes, &[casted_arg]) - diff;
1077-
return self.context.new_cast(self.location, res, result_type);
1050+
res
1051+
} else if arg_type.is_u128(self.cx) {
1052+
// __builtin_ctzll is UB when called with 0, so call it on the 64 low bits if they are not 0,
1053+
// else call it on the 64 high bits and add 64. In the else case, 64 high bits can't be 0 because arg is not 0.
1054+
1055+
let cttz_then_block = self.current_func().new_block("cttz_then");
1056+
let cttz_else_block = self.current_func().new_block("cttz_else");
1057+
let (result, block, after_block) = if let Some(block) = block {
1058+
block
1059+
} else {
1060+
(
1061+
self.current_func().new_local(None, result_type, "zeros"),
1062+
self.llbb(),
1063+
self.current_func().new_block("cttz_after"),
1064+
)
10781065
};
1079-
let count_trailing_zeroes = self.context.get_builtin_function(count_trailing_zeroes);
1080-
let arg = if arg_type != expected_type {
1081-
self.context.new_cast(self.location, arg, expected_type)
1066+
let low = self.gcc_int_cast(arg, self.u64_type);
1067+
1068+
let ctzll = self.context.get_builtin_function("__builtin_ctzll");
1069+
1070+
let zero = self.const_uint(low.get_type(), 0);
1071+
let cond = self.gcc_icmp(IntPredicate::IntNE, low, zero);
1072+
block.end_with_conditional(self.location, cond, cttz_then_block, cttz_else_block);
1073+
1074+
let trailing_zeroes =
1075+
self.gcc_int_cast(self.context.new_call(None, ctzll, &[low]), result_type);
1076+
1077+
cttz_then_block.add_assignment(None, result, trailing_zeroes);
1078+
cttz_then_block.end_with_jump(None, after_block);
1079+
self.switch_to_block(cttz_else_block);
1080+
1081+
let sixty_four = self.const_uint(arg_type, 64);
1082+
let shift = self.lshr(arg, sixty_four);
1083+
let high = self.gcc_int_cast(shift, self.u64_type);
1084+
let high_trailing_zeroes =
1085+
self.gcc_int_cast(self.context.new_call(None, ctzll, &[high]), result_type);
1086+
let sixty_four_result_type = self.const_uint(result_type, 64);
1087+
let trailing_zeroes = self.add(high_trailing_zeroes, sixty_four_result_type);
1088+
cttz_else_block.add_assignment(None, result, trailing_zeroes);
1089+
cttz_else_block.end_with_jump(None, after_block);
1090+
self.switch_to_block(after_block);
1091+
1092+
result.to_rvalue()
10821093
} else {
1083-
arg
1084-
};
1085-
let res = self.context.new_call(self.location, count_trailing_zeroes, &[arg]);
1086-
self.context.new_cast(self.location, res, result_type)
1094+
let count_trailing_zeroes = self.context.get_builtin_function("__builtin_ctzll");
1095+
let arg_size = arg_type.get_size();
1096+
let casted_arg = self.context.new_cast(self.location, arg, self.ulonglong_type);
1097+
let byte_diff = self.ulonglong_type.get_size() as i64 - arg_size as i64;
1098+
let diff = self.context.new_rvalue_from_long(self.int_type, byte_diff * 8);
1099+
let mask = self.context.new_rvalue_from_long(arg_type, -1); // To get the value with all bits set.
1100+
let masked = mask
1101+
& self.context.new_unary_op(self.location, UnaryOp::BitwiseNegate, arg_type, arg);
1102+
let cond =
1103+
self.context.new_comparison(self.location, ComparisonOp::Equals, masked, mask);
1104+
let diff = diff * self.context.new_cast(self.location, cond, self.int_type);
1105+
let res =
1106+
self.context.new_call(self.location, count_trailing_zeroes, &[casted_arg]) - diff;
1107+
let res = self.context.new_cast(self.location, res, result_type);
1108+
if let Some((result, else_block, after_block)) = block {
1109+
else_block.add_assignment(None, result, res);
1110+
else_block.end_with_jump(None, after_block);
1111+
}
1112+
res
1113+
}
10871114
}
10881115

10891116
fn pop_count(&mut self, value: RValue<'gcc>) -> RValue<'gcc> {

0 commit comments

Comments
 (0)