Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

short_vec::decode_len() returns wrong size for aliased values #11624

Merged
merged 3 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions perf/src/sigverify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ fn do_get_packet_offsets(
}

// read the length of Transaction.signatures (serialized with short_vec)
let (sig_len_untrusted, sig_size) = decode_len(&packet.data)?;
let (sig_len_untrusted, sig_size) =
decode_len(&packet.data).map_err(|_| PacketError::InvalidShortVec)?;

// Using msg_start_offset which is based on sig_len_untrusted introduces uncertainty.
// Ultimately, the actual sigverify will determine the uncertainty.
Expand All @@ -156,8 +157,8 @@ fn do_get_packet_offsets(
}

// read the length of Message.account_keys (serialized with short_vec)
let (pubkey_len, pubkey_len_size) =
decode_len(&packet.data[message_account_keys_len_offset..])?;
let (pubkey_len, pubkey_len_size) = decode_len(&packet.data[message_account_keys_len_offset..])
.map_err(|_| PacketError::InvalidShortVec)?;

if (message_account_keys_len_offset + pubkey_len * size_of::<Pubkey>() + pubkey_len_size)
> packet.meta.size
Expand Down
70 changes: 57 additions & 13 deletions sdk/src/short_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,26 @@ impl Serialize for ShortU16 {
}
}

enum VisitResult {
Done(usize, usize),
More(usize, usize),
Err,
}

fn visit_byte(elem: u8, len: usize, size: usize) -> VisitResult {
let len = len | (elem as usize & 0x7f) << (size * 7);
let size = size + 1;
let more = elem as usize & 0x80 == 0x80;

if size > size_of::<u16>() + 1 {
VisitResult::Err
} else if more {
VisitResult::More(len, size)
} else {
VisitResult::Done(len, size)
}
}

struct ShortLenVisitor;

impl<'de> Visitor<'de> for ShortLenVisitor {
Expand All @@ -58,15 +78,16 @@ impl<'de> Visitor<'de> for ShortLenVisitor {
.next_element()?
.ok_or_else(|| de::Error::invalid_length(size, &self))?;

len |= (elem as usize & 0x7f) << (size * 7);
size += 1;

if elem as usize & 0x80 == 0 {
break;
}

if size > size_of::<u16>() + 1 {
return Err(de::Error::invalid_length(size, &self));
match visit_byte(elem, len, size) {
VisitResult::Done(l, _) => {
len = l;
break;
}
VisitResult::More(l, s) => {
len = l;
size = s;
}
VisitResult::Err => return Err(de::Error::invalid_length(size + 1, &self)),
}
}

Expand Down Expand Up @@ -178,10 +199,20 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
}

/// Return the decoded value and how many bytes it consumed.
pub fn decode_len(bytes: &[u8]) -> Result<(usize, usize), Box<bincode::ErrorKind>> {
let short_len: ShortU16 = bincode::deserialize(bytes)?;
let num_bytes = bincode::serialized_size(&short_len)?;
Ok((short_len.0 as usize, num_bytes as usize))
pub fn decode_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
let mut len = 0;
let mut size = 0;
for byte in bytes.iter() {
match visit_byte(*byte, len, size) {
VisitResult::More(l, s) => {
len = l;
size = s;
}
VisitResult::Done(len, size) => return Ok((len, size)),
VisitResult::Err => return Err(()),
}
}
Err(())
}

#[cfg(test)]
Expand Down Expand Up @@ -246,4 +277,17 @@ mod tests {
let s = serde_json::to_string(&vec).unwrap();
assert_eq!(s, "[[3],0,1,2]");
}

#[test]
fn test_decode_len_aliased_values() {
let one1 = [0x01];
let one2 = [0x81, 0x00];
garious marked this conversation as resolved.
Show resolved Hide resolved
let one3 = [0x81, 0x80, 0x00];
let one4 = [0x81, 0x80, 0x80, 0x00];

assert_eq!(decode_len(&one1).unwrap(), (1, 1));
assert_eq!(decode_len(&one2).unwrap(), (1, 2));
assert_eq!(decode_len(&one3).unwrap(), (1, 3));
assert!(decode_len(&one4).is_err());
}
}