Skip to content

RUST-871 Serialize directly to BSON bytes in insert operations #406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 5, 2021
2 changes: 1 addition & 1 deletion benchmarks/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ fn parse_ids(matches: ArgMatches) -> Vec<bool> {
ids
}

#[cfg_attr(feature = "tokio-runtime", tokio::main)]
#[cfg_attr(feature = "tokio-runtime", tokio::main(flavor = "current_thread"))]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the benchmarks are single threaded, this ends up giving a nice boost by avoiding the tokio overhead of managing a thread pool.

#[cfg_attr(feature = "async-std-runtime", async_std::main)]
async fn main() {
let matches = App::new("RustDriverBenchmark")
Expand Down
333 changes: 169 additions & 164 deletions src/bson_util/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use std::{convert::TryFrom, io::Read, time::Duration};
use std::{
convert::{TryFrom, TryInto},
io::{Read, Write},
time::Duration,
};

use serde::{de::Error, ser, Deserialize, Deserializer, Serialize, Serializer};
use bson::spec::ElementType;
use serde::{de::Error as SerdeDeError, ser, Deserialize, Deserializer, Serialize, Serializer};

use crate::{
bson::{doc, Binary, Bson, Document, JavaScriptCodeWithScope, Regex},
error::{ErrorKind, Result},
bson::{doc, Bson, Document},
error::{Error, ErrorKind, Result},
runtime::{SyncLittleEndianRead, SyncLittleEndianWrite},
};

Expand Down Expand Up @@ -164,128 +169,30 @@ where
.ok_or_else(|| D::Error::custom(format!("could not deserialize u64 from {:?}", bson)))
}

pub fn doc_size_bytes(doc: &Document) -> u64 {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were no longer needed as we now just use the length of the serialized vec.

//
// * i32 length prefix (4 bytes)
// * for each element:
// * type (1 byte)
// * number of UTF-8 bytes in key
// * null terminator for the key (1 byte)
// * size of the value
// * null terminator (1 byte)
4 + doc
.into_iter()
.map(|(key, val)| 1 + key.len() as u64 + 1 + size_bytes(val))
.sum::<u64>()
+ 1
}

pub fn size_bytes(val: &Bson) -> u64 {
match val {
Bson::Double(_) => 8,
//
// * length prefix (4 bytes)
// * number of UTF-8 bytes
// * null terminator (1 byte)
Bson::String(s) => 4 + s.len() as u64 + 1,
// An array is serialized as a document with the keys "0", "1", "2", etc., so the size of
// an array is:
//
// * length prefix (4 bytes)
// * for each element:
// * type (1 byte)
// * number of decimal digits in key
// * null terminator for the key (1 byte)
// * size of value
// * null terminator (1 byte)
Bson::Array(arr) => {
4 + arr
.iter()
.enumerate()
.map(|(i, val)| 1 + num_decimal_digits(i) + 1 + size_bytes(val))
.sum::<u64>()
+ 1
}
Bson::Document(doc) => doc_size_bytes(doc),
Bson::Boolean(_) => 1,
Bson::Null => 0,
// for $pattern and $opts:
// * number of UTF-8 bytes
// * null terminator (1 byte)
Bson::RegularExpression(Regex { pattern, options }) => {
pattern.len() as u64 + 1 + options.len() as u64 + 1
}
//
// * length prefix (4 bytes)
// * number of UTF-8 bytes
// * null terminator (1 byte)
Bson::JavaScriptCode(code) => 4 + code.len() as u64 + 1,
//
// * i32 length prefix (4 bytes)
// * i32 length prefix for code (4 bytes)
// * number of UTF-8 bytes in code
// * null terminator for code (1 byte)
// * length of document
Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope { code, scope }) => {
4 + 4 + code.len() as u64 + 1 + doc_size_bytes(scope)
}
Bson::Int32(_) => 4,
Bson::Int64(_) => 8,
Bson::Timestamp(_) => 8,
//
// * i32 length prefix (4 bytes)
// * subtype (1 byte)
// * number of bytes
Bson::Binary(Binary { bytes, .. }) => 4 + 1 + bytes.len() as u64,
Bson::ObjectId(_) => 12,
Bson::DateTime(_) => 8,
//
// * i32 length prefix (4 bytes)
// * subtype (1 byte)
// * number of UTF-8 bytes
Bson::Symbol(s) => 4 + 1 + s.len() as u64,
Bson::Decimal128(..) => 128 / 8,
Bson::Undefined | Bson::MaxKey | Bson::MinKey => 0,
// DbPointer doesn't have public details exposed by the BSON library, but it comprises of a
// namespace and an ObjectId. Since our methods to calculate the size of BSON values are
// only used to estimate the cutoff for batches when making a large insert, we can just
// assume the largest possible size for a namespace, which is 120 bytes. Therefore, the size
// is:
//
// * i32 length prefix (4 bytes)
// * namespace (120 bytes)
// * null terminator (1 byte)
// * objectid (12 bytes)
Bson::DbPointer(..) => 4 + 120 + 1 + 12,
}
}

/// The size in bytes of the provided document's entry in a BSON array at the given index.
pub(crate) fn array_entry_size_bytes(index: usize, doc: &Document) -> u64 {
pub(crate) fn array_entry_size_bytes(index: usize, doc_len: usize) -> u64 {
//
// * type (1 byte)
// * number of decimal digits in key
// * null terminator for the key (1 byte)
// * size of value
1 + num_decimal_digits(index) + 1 + doc_size_bytes(doc)

1 + num_decimal_digits(index) + 1 + doc_len as u64
}

/// The number of digits in `n` in base 10.
/// Useful for calculating the size of an array entry in BSON.
fn num_decimal_digits(n: usize) -> u64 {
let mut digits = 1;
let mut curr = 10;

while curr < n {
curr = match curr.checked_mul(10) {
Some(val) => val,
None => break,
};
fn num_decimal_digits(mut n: usize) -> u64 {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the old implementation of this method was incorrect actually so I rewrote it.

let mut digits = 0;

loop {
n /= 10;
digits += 1;
}

digits
if n == 0 {
return digits;
}
}
}

/// Read a document's raw BSON bytes from the provided reader.
Expand All @@ -300,63 +207,161 @@ pub(crate) fn read_document_bytes<R: Read>(mut reader: R) -> Result<Vec<u8>> {
Ok(bytes)
}

/// Serialize the document to raw BSON and return a vec containing the bytes.
#[cfg(test)]
pub(crate) fn document_to_vec(doc: Document) -> Result<Vec<u8>> {
let mut v = Vec::new();
doc.to_writer(&mut v)?;
Ok(v)
/// Get the value for the provided key from a buffer containing a BSON document.
/// If the key is not present, None will be returned.
/// If the BSON is not properly formatted, an internal error would be returned.
///
/// TODO: RUST-924 replace this with raw document API usage.
pub(crate) fn raw_get(doc: &[u8], key: &str) -> Result<Option<Bson>> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately this function needs to live in the driver for now as we need to be able to comb through the serialized command for _id values, but it can go away shortly once RUST-924 is complete.

fn read_i32(reader: &mut std::io::Cursor<&[u8]>) -> Result<i32> {
reader.read_i32().map_err(deserialize_error)
}

fn read_u8(reader: &mut std::io::Cursor<&[u8]>) -> Result<u8> {
reader.read_u8().map_err(deserialize_error)
}

fn deserialize_error<T: std::error::Error>(_e: T) -> Error {
deserialize_error_no_arg()
}

fn deserialize_error_no_arg() -> Error {
Error::from(ErrorKind::Internal {
message: "failed to read from serialized document".to_string(),
})
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documents queried here should always be valid as they're coming straight out of our serializer, so any error encountered here is mapped to an internal error.


let mut reader = std::io::Cursor::new(doc);
let len: u64 = read_i32(&mut reader)?
.try_into()
.map_err(deserialize_error)?;

while reader.position() < len {
let element_start: usize = reader.position().try_into().map_err(deserialize_error)?;

// read the element type
let tag = read_u8(&mut reader)?;

// check if we reached the end of the document
if tag == 0 && reader.position() == len {
return Ok(None);
}

let element_type = ElementType::from(tag).ok_or_else(deserialize_error_no_arg)?;

// walk through the document until a null byte is encountered
while read_u8(&mut reader)? != 0 {
if reader.position() >= len {
return Err(deserialize_error_no_arg());
}
}

// parse the key
let string_end: usize = reader
.position()
.checked_sub(1) // back from null byte
.and_then(|u| usize::try_from(u).ok())
.ok_or_else(deserialize_error_no_arg)?;
let slice = &reader.get_ref()[(element_start + 1)..string_end];
let k = std::str::from_utf8(slice).map_err(deserialize_error)?;

// move to the end of the element
let skip_len = match element_type {
ElementType::Array
| ElementType::EmbeddedDocument
| ElementType::JavaScriptCodeWithScope => {
let l = read_i32(&mut reader)?;
// length includes the 4 bytes for the length, so subtrack them out
l.checked_sub(4).ok_or_else(deserialize_error_no_arg)?
}
ElementType::Binary => read_i32(&mut reader)?
.checked_add(1) // add one for subtype
.ok_or_else(deserialize_error_no_arg)?,
ElementType::Int32 => 4,
ElementType::Int64 => 8,
ElementType::String | ElementType::Symbol | ElementType::JavaScriptCode => {
read_i32(&mut reader)?
}
ElementType::Boolean => 1,
ElementType::Double => 8,
ElementType::Timestamp => 8,
ElementType::Decimal128 => 16,
ElementType::MinKey
| ElementType::MaxKey
| ElementType::Null
| ElementType::Undefined => 0,
ElementType::DateTime => 8,
ElementType::ObjectId => 12,
ElementType::DbPointer => read_i32(&mut reader)?
.checked_add(12) // add 12 for objectid
.ok_or_else(deserialize_error_no_arg)?,
ElementType::RegularExpression => {
// read two cstr's
for _i in 0..2 {
while read_u8(&mut reader)? != 0 {
if reader.position() >= len {
return Err(deserialize_error_no_arg());
}
}
}

0 // don't need to skip anymore since we already read the whole value
}
};
let skip_len: u64 = skip_len.try_into().map_err(deserialize_error)?;
reader.set_position(
reader
.position()
.checked_add(skip_len)
.ok_or_else(deserialize_error_no_arg)?,
);

if k == key {
// if this is the element we're looking for, extract it.
let element_end: usize = reader.position().try_into().map_err(deserialize_error)?;
let element_slice = &reader.get_ref()[element_start..element_end];
let element_length: i32 = element_slice.len().try_into().map_err(deserialize_error)?;

// create a new temporary document which just has the element we want and grab the value
let mut temp_doc = Vec::new();

// write the document length
let temp_len: i32 = element_length
.checked_add(4 + 1)
.ok_or_else(deserialize_error_no_arg)?;
temp_doc
.write_all(&temp_len.to_le_bytes())
.map_err(deserialize_error)?;

// add in the element
temp_doc.extend(element_slice);

// write the null byte
temp_doc.push(0);

let d = Document::from_reader(temp_doc.as_slice()).map_err(deserialize_error)?;
return Ok(Some(
d.get("_id").ok_or_else(deserialize_error_no_arg)?.clone(),
));
}
}

// read all bytes but didn't reach null byte
Err(deserialize_error_no_arg())
}

#[cfg(test)]
mod test {
use crate::bson::{
doc,
oid::ObjectId,
spec::BinarySubtype,
Binary,
Bson,
DateTime,
JavaScriptCodeWithScope,
Regex,
Timestamp,
};

use super::doc_size_bytes;
use crate::bson_util::num_decimal_digits;

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn doc_size_bytes_eq_serialized_size_bytes() {
let doc = doc! {
"double": -12.3,
"string": "foo",
"array": ["foobar", -7, Bson::Null, Bson::Timestamp(Timestamp { time: 12345, increment: 67890 }), false],
"document": {
"x": 1,
"yyz": "Rush is one of the greatest bands of all time",
},
"bool": true,
"null": Bson::Null,
"regex": Bson::RegularExpression(Regex { pattern: "foobar".into(), options: "i".into() }),
"code": Bson::JavaScriptCode("foo(x) { return x + 1; }".into()),
"code with scope": Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope {
code: "foo(x) { return x + y; }".into(),
scope: doc! { "y": -17 },
}),
"i32": 12i32,
"i64": -126i64,
"timestamp": Bson::Timestamp(Timestamp { time: 12233, increment: 34444 }),
"binary": Bson::Binary(Binary{ subtype: BinarySubtype::Generic, bytes: vec![3, 222, 11] }),
"objectid": ObjectId::from_bytes([1; 12]),
"datetime": DateTime::from_millis(4444333221),
"symbol": Bson::Symbol("foobar".into()),
};

let size_bytes = doc_size_bytes(&doc);

let mut serialized_bytes = Vec::new();
doc.to_writer(&mut serialized_bytes).unwrap();

assert_eq!(size_bytes, serialized_bytes.len() as u64);
async fn num_digits() {
assert_eq!(num_decimal_digits(0), 1);
assert_eq!(num_decimal_digits(1), 1);
assert_eq!(num_decimal_digits(10), 2);
assert_eq!(num_decimal_digits(15), 2);
assert_eq!(num_decimal_digits(100), 3);
assert_eq!(num_decimal_digits(125), 3);
}
}
Loading