Description
Godbolt link to the code below: https://rust.godbolt.org/z/aKf3Wq
pub fn foo1(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in y.iter().enumerate() {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}
This code has a bounds check for chunk[c] although c < chunk_size by construction.
The same code, written in a slightly more convoluted way, gets rid of the bounds check:
pub fn foo2(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for c in 0..chunk_size {
        let y = y[c];
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}
It seems like the information that 0 <= c < y.len() gets lost for the optimizer when going via y.iter().enumerate(). So this is unrelated to chunks_exact() specifically, but I can't come up with an equivalent example without it.
edit: As noticed in #75935 (comment), this can be worked around by defining a custom slice iterator that does counting of elements instead of working with an end pointer.
The problem is that slice::iter() uses an end pointer to know when the iteration has to stop, and keeps no information around that would tell the optimizer that it is going to iterate exactly N times. It's unclear to me how this information can be preserved without changing how the iterator works, which would probably have other negative effects.
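For illustration, here is a minimal sketch of what such a counting iterator could look like. The names CountingIter and foo3 are made up for this example and are not taken from the linked comment. The iterator keeps an index into the slice instead of a start/end pointer pair, so the loop condition is expressed directly in terms of the index and the slice length. Whether the bounds check on chunk[c] actually disappears depends on the compiler and LLVM version, so check the generated assembly rather than taking it for granted.

```rust
/// Hypothetical slice iterator that tracks an index instead of an end pointer.
pub struct CountingIter<'a, T> {
    slice: &'a [T],
    idx: usize,
}

impl<'a, T> CountingIter<'a, T> {
    pub fn new(slice: &'a [T]) -> Self {
        CountingIter { slice, idx: 0 }
    }
}

impl<'a, T> Iterator for CountingIter<'a, T> {
    // Yields (index, element), like slice.iter().enumerate().
    type Item = (usize, &'a T);

    fn next(&mut self) -> Option<Self::Item> {
        let i = self.idx;
        // `get` returns None once i reaches slice.len(), which ends the iteration.
        let item = self.slice.get(i)?;
        self.idx += 1;
        Some((i, item))
    }
}

/// Same computation as foo1, but driven by the counting iterator.
pub fn foo3(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in CountingIter::new(y) {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}
```

The trade-off is that every step of the iterator now performs an index comparison and an indexed load instead of a pointer comparison and a pointer increment, which is the kind of change to slice::iter() that might have other negative effects.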
edit2: As noticed in #77822, this also happens with C/C++, and the example can be simplified a lot on the Rust side:
pub fn foo(y: &[u32]) {
    let mut x = 0;
    for (c, _y) in y.iter().enumerate() {
        assert!(c < y.len());
        x = c;
    }
    assert!(x == y.len());
}
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <vector>

void foo1(const uint32_t *y, size_t y_len) {
    const uint32_t *y_end = y + y_len;
    size_t c = 0;
    for (const uint32_t *y_iter = y; y_iter != y_end; y_iter++, c++) {
        assert(c < y_len);
    }
    assert(c == y_len);
}

void foo2(const std::vector<uint32_t>& y) {
    size_t c = 0;
    for (auto y_iter: y) {
        assert(c < y.size());
        c++;
    }
    assert(c == y.size());
}

void foo3(const std::vector<uint32_t>& y) {
    size_t c = 0;
    for (auto y_iter = y.cbegin(); y_iter != y.cend(); y_iter++, c++) {
        assert(c < y.size());
    }
    assert(c == y.size());
}
edit3: This has now also been reported to LLVM: https://bugs.llvm.org/show_bug.cgi?id=48965