Skip to content

Commit 8e99c76

Browse files
committed
[const-prop] Support propagating into Assert's cond Operand
1 parent 6afcb56 commit 8e99c76

File tree

3 files changed

+81
-69
lines changed

3 files changed

+81
-69
lines changed

src/librustc_mir/transform/const_prop.rs

Lines changed: 79 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -656,75 +656,87 @@ impl<'b, 'a, 'tcx> MutVisitor<'tcx> for ConstPropagator<'b, 'a, 'tcx> {
656656
location: Location,
657657
) {
658658
self.super_terminator(terminator, location);
659-
let source_info = terminator.source_info;;
660-
if let TerminatorKind::Assert { expected, msg, cond, .. } = &terminator.kind {
661-
if let Some(value) = self.eval_operand(&cond, source_info) {
662-
trace!("assertion on {:?} should be {:?}", value, expected);
663-
let expected = ScalarMaybeUndef::from(Scalar::from_bool(*expected));
664-
if expected != self.ecx.read_scalar(value).unwrap() {
665-
// poison all places this operand references so that further code
666-
// doesn't use the invalid value
667-
match cond {
668-
Operand::Move(ref place) | Operand::Copy(ref place) => {
669-
let mut place = place;
670-
while let Place::Projection(ref proj) = *place {
671-
place = &proj.base;
672-
}
673-
if let Place::Base(PlaceBase::Local(local)) = *place {
674-
self.places[local] = None;
675-
}
676-
},
677-
Operand::Constant(_) => {}
659+
let source_info = terminator.source_info;
660+
match &mut terminator.kind {
661+
TerminatorKind::Assert { expected, msg, ref mut cond, .. } => {
662+
if let Some(value) = self.eval_operand(&cond, source_info) {
663+
trace!("assertion on {:?} should be {:?}", value, expected);
664+
let expected = ScalarMaybeUndef::from(Scalar::from_bool(*expected));
665+
let value_const = self.ecx.read_scalar(value).unwrap();
666+
if expected != value_const {
667+
// poison all places this operand references so that further code
668+
// doesn't use the invalid value
669+
match cond {
670+
Operand::Move(ref place) | Operand::Copy(ref place) => {
671+
let mut place = place;
672+
while let Place::Projection(ref proj) = *place {
673+
place = &proj.base;
674+
}
675+
if let Place::Base(PlaceBase::Local(local)) = *place {
676+
self.places[local] = None;
677+
}
678+
},
679+
Operand::Constant(_) => {}
680+
}
681+
let span = terminator.source_info.span;
682+
let hir_id = self
683+
.tcx
684+
.hir()
685+
.as_local_hir_id(self.source.def_id())
686+
.expect("some part of a failing const eval must be local");
687+
use rustc::mir::interpret::InterpError::*;
688+
let msg = match msg {
689+
Overflow(_) |
690+
OverflowNeg |
691+
DivisionByZero |
692+
RemainderByZero => msg.description().to_owned(),
693+
BoundsCheck { ref len, ref index } => {
694+
let len = self
695+
.eval_operand(len, source_info)
696+
.expect("len must be const");
697+
let len = match self.ecx.read_scalar(len) {
698+
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
699+
bits, ..
700+
})) => bits,
701+
other => bug!("const len not primitive: {:?}", other),
702+
};
703+
let index = self
704+
.eval_operand(index, source_info)
705+
.expect("index must be const");
706+
let index = match self.ecx.read_scalar(index) {
707+
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
708+
bits, ..
709+
})) => bits,
710+
other => bug!("const index not primitive: {:?}", other),
711+
};
712+
format!(
713+
"index out of bounds: \
714+
the len is {} but the index is {}",
715+
len,
716+
index,
717+
)
718+
},
719+
// Need proper const propagator for these
720+
_ => return,
721+
};
722+
self.tcx.lint_hir(
723+
::rustc::lint::builtin::CONST_ERR,
724+
hir_id,
725+
span,
726+
&msg,
727+
);
728+
} else {
729+
if let ScalarMaybeUndef::Scalar(scalar) = value_const {
730+
*cond = self.operand_from_scalar(
731+
scalar,
732+
self.tcx.types.bool,
733+
source_info.span,
734+
);
735+
}
678736
}
679-
let span = terminator.source_info.span;
680-
let hir_id = self
681-
.tcx
682-
.hir()
683-
.as_local_hir_id(self.source.def_id())
684-
.expect("some part of a failing const eval must be local");
685-
use rustc::mir::interpret::InterpError::*;
686-
let msg = match msg {
687-
Overflow(_) |
688-
OverflowNeg |
689-
DivisionByZero |
690-
RemainderByZero => msg.description().to_owned(),
691-
BoundsCheck { ref len, ref index } => {
692-
let len = self
693-
.eval_operand(len, source_info)
694-
.expect("len must be const");
695-
let len = match self.ecx.read_scalar(len) {
696-
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
697-
bits, ..
698-
})) => bits,
699-
other => bug!("const len not primitive: {:?}", other),
700-
};
701-
let index = self
702-
.eval_operand(index, source_info)
703-
.expect("index must be const");
704-
let index = match self.ecx.read_scalar(index) {
705-
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
706-
bits, ..
707-
})) => bits,
708-
other => bug!("const index not primitive: {:?}", other),
709-
};
710-
format!(
711-
"index out of bounds: \
712-
the len is {} but the index is {}",
713-
len,
714-
index,
715-
)
716-
},
717-
// Need proper const propagator for these
718-
_ => return,
719-
};
720-
self.tcx.lint_hir(
721-
::rustc::lint::builtin::CONST_ERR,
722-
hir_id,
723-
span,
724-
&msg,
725-
);
726737
}
727-
}
738+
},
739+
_ => {}
728740
}
729741
}
730742
}

src/test/mir-opt/const_prop/array_index.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ fn main() {
2323
// bb0: {
2424
// ...
2525
// _5 = const true;
26-
// assert(move _5, "index out of bounds: the len is move _4 but the index is _3") -> bb1;
26+
// assert(const true, "index out of bounds: the len is move _4 but the index is _3") -> bb1;
2727
// }
2828
// bb1: {
2929
// _1 = _2[_3];

src/test/mir-opt/const_prop/checked_add.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ fn main() {
1616
// bb0: {
1717
// ...
1818
// _2 = (const 2u32, const false);
19-
// assert(!move (_2.1: bool), "attempt to add with overflow") -> bb1;
19+
// assert(!const false, "attempt to add with overflow") -> bb1;
2020
// }
2121
// END rustc.main.ConstProp.after.mir

0 commit comments

Comments
 (0)