Commit c4b8770

Improve performance of byte_pair_merge (#31)
The improvements to `byte_pair_merge` are:

- Changing the `parts` vector to avoid repetition of data. This vector used to store ranges for which the invariant `parts[i].end == parts[i + 1].start` holds, which makes the vector twice as big as it needs to be. Keeping this vector small improves CPU-cache efficiency.
- Using `usize::MAX` as a sentinel in lieu of `Option` for the computation of the minimum rank. This change removes branching from the loop that computes the minimum rank, generating assembly that uses conditional moves instead (see the sketch just after this list). Ideally, we could keep the `Option` and inform it of the sentinel, much like `Option<NonZeroUsize>`. As far as I could tell, specifying custom sentinels for `Option` has an old Rust [RFC](rust-lang/rfcs#41) that has stalled, so we don't get to have nice things.
- Minimizing the number of lookups into `ranks` by looking up ranks once and iteratively updating them after each merge. This reduces the number of rank lookups from `n*m` to `n + O(m)`.
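To make the sentinel point concrete, here is a minimal, self-contained sketch (not code from this commit; the helper name `min_rank_index` and the sample data are invented) of scanning a `(start, rank)` vector for the lowest-ranked pair, with `usize::MAX` standing in for "no pair". Because the hot loop only compares plain integers, the compiler can lower it to conditional moves instead of branching on `Option`:

```rust
/// Illustrative helper (not part of src/lib.rs): find the index of the part
/// whose following byte pair has the lowest rank, or None if nothing can merge.
fn min_rank_index(parts: &[(usize, usize)]) -> Option<usize> {
    // usize::MAX doubles as "no valid rank", so the loop body stays branch-friendly.
    let mut min_rank: (usize, usize) = (usize::MAX, 0);
    for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
    }
    (min_rank.0 != usize::MAX).then_some(min_rank.1)
}

fn main() {
    // Hypothetical parts for a 4-byte piece: (start offset, rank of the pair
    // starting there); the final entry is the end sentinel with no valid rank.
    let parts = [
        (0, 7),
        (1, usize::MAX),
        (2, 3),
        (3, usize::MAX),
        (4, usize::MAX),
    ];
    assert_eq!(min_rank_index(&parts), Some(2)); // pair starting at offset 2 has rank 3
}
```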
1 parent 7830ed5 commit c4b8770

File tree: 1 file changed, +79 -24 lines

src/lib.rs

Lines changed: 79 additions & 24 deletions

@@ -11,11 +11,58 @@ use pyo3::types::{PyBytes, PyList, PyTuple};
 use pyo3::PyResult;
 use rustc_hash::FxHashMap as HashMap;

-fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
-    let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();
+fn _byte_pair_merge<T>(
+    piece: &[u8],
+    ranks: &HashMap<Vec<u8>, usize>,
+    f: impl Fn(std::ops::Range<usize>) -> T,
+) -> Vec<T> {
+    // This is a vector of (start, rank).
+    // The rank is of the byte pair starting at position start.
+    // The rank of the last item in the vector is not a valid value.
+    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
+
+    // NOTE: using a macro here because a closure fails to get inlined
+    // according to optimization remarks.
+    // A closure also cannot capture a reference to `piece` without
+    // the borrow checker complaining about the mutable borrows during
+    // the assignments later in this code.
+    macro_rules! get_rank {
+        ($start_idx:expr, $skip:expr) => {{
+            let start_idx: usize = $start_idx;
+            let skip: usize = $skip;
+            if (start_idx + skip + 2) < parts.len() {
+                ranks
+                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
+                    .map(|r| *r)
+            } else {
+                None
+            }
+        }};
+        ($idx:expr) => {{
+            get_rank!($idx, 0)
+        }};
+    }
+
+    // We look up the ranks once in the beginning and iteratively update
+    // them during each merge, which reduces the number of rank lookups.
+    for i in 0..parts.len() - 2 {
+        match get_rank!(i) {
+            Some(rank) => {
+                // usize::MAX is a sentinel value and cannot be a valid rank
+                debug_assert!(rank != usize::MAX);
+                parts[i].1 = rank;
+            }
+            None => {
+                continue;
+            }
+        };
+    }

-    // If you have n parts and m merges, this does O(mn) work
-    // We could do something with a heap and do O(m log n) work
+    // If you have n parts and m merges, this does O(mn) work.
+    // We could do something with a heap and do O(m log n) work.
+    // It is important to consider that n is often small (<100), and as such
+    // the cache-locality benefits outweigh the algorithmic complexity downsides
+    // of the `parts` vector data structure above.

     // Note that we hash bytes, not token pairs. As long as we train BPE the way we
     // currently do, this is equivalent. An easy way to break this would be to decouple
@@ -24,45 +71,53 @@ fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::o
         if parts.len() == 1 {
             break;
         }
-        let mut min_rank: Option<(usize, usize)> = None;
-        for i in 0..parts.len() - 1 {
-            let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
-                *r
-            } else {
-                continue;
-            };
-            if min_rank.is_none() || rank < min_rank.unwrap().0 {
-                min_rank = Some((rank, i));
+
+        // usize::MAX is a sentinel rank value allowing us to
+        // take the min more quickly
+        let mut min_rank: (usize, usize) = (usize::MAX, 0);
+        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
+            if rank < min_rank.0 {
+                min_rank = (rank, i);
             }
         }
-        if let Some((_, i)) = min_rank {
-            parts[i] = parts[i].start..parts[i + 1].end;
+
+        if min_rank.0 != usize::MAX {
+            let i = min_rank.1;
+
+            // NOTE: We are about to remove parts[i + 1]. We do not do it
+            // yet because there are cache-locality benefits to updating
+            // parts[i] and parts[i-1] before removing, which could thrash
+            // the cache. Thus, we update the rank calculation by skipping over
+            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
+            parts[i].1 = get_rank!(i, 1).unwrap_or(usize::MAX);
+            if i > 0 {
+                parts[i - 1].1 = get_rank!(i - 1, 1).unwrap_or(usize::MAX);
+            }
+
             parts.remove(i + 1);
         } else {
             break;
         }
     }
-    parts
+    let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
+    for i in 0..parts.len() - 1 {
+        out.push(f(parts[i].0..parts[i + 1].0));
+    }
+    out
 }

 pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
     if piece.len() == 1 {
         return vec![ranks[piece]];
     }
-    _byte_pair_merge(piece, ranks)
-        .iter()
-        .map(|p| ranks[&piece[p.start..p.end]])
-        .collect()
+    _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
 }

 pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
     if piece.len() == 1 {
         return vec![piece];
     }
-    _byte_pair_merge(piece, ranks)
-        .iter()
-        .map(|p| &piece[p.start..p.end])
-        .collect()
+    _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
 }

 // Various performance notes:
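For illustration, here is a hypothetical test that could sit at the bottom of src/lib.rs and exercise the refactored functions; the rank values below are invented for the example and are not taken from any real encoding:

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use rustc_hash::FxHashMap as HashMap;

    #[test]
    fn merges_lowest_ranked_pair_first() {
        // Invented rank table: each single byte plus one mergeable pair.
        let mut ranks: HashMap<Vec<u8>, usize> = HashMap::default();
        ranks.insert(b"a".to_vec(), 0);
        ranks.insert(b"b".to_vec(), 1);
        ranks.insert(b"c".to_vec(), 2);
        ranks.insert(b"ab".to_vec(), 3);

        // "ab" is the only pair with a rank, so it merges; "c" stays a single byte.
        assert_eq!(byte_pair_encode(b"abc", &ranks), vec![3, 2]);
        assert_eq!(
            byte_pair_split(b"abc", &ranks),
            vec![b"ab".as_slice(), b"c".as_slice()]
        );
    }
}
```

With the ranges now materialized through the `f` callback, `byte_pair_encode` and `byte_pair_split` share the merge loop and differ only in how each final range is turned into output.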
