Skip to content

RUST-1133 Update driver to use the raw BSON API #546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 3 additions & 151 deletions src/bson_util/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
use std::{
convert::{TryFrom, TryInto},
io::{Read, Write},
time::Duration,
};
use std::{convert::TryFrom, io::Read, time::Duration};

use bson::{spec::ElementType, RawBsonRef};
use bson::RawBsonRef;
use serde::{de::Error as SerdeDeError, ser, Deserialize, Deserializer, Serialize, Serializer};

use crate::{
bson::{doc, Bson, Document},
error::{Error, ErrorKind, Result},
error::{ErrorKind, Result},
runtime::{SyncLittleEndianRead, SyncLittleEndianWrite},
};

Expand Down Expand Up @@ -191,7 +187,6 @@ where

/// The size in bytes of the provided document's entry in a BSON array at the given index.
pub(crate) fn array_entry_size_bytes(index: usize, doc_len: usize) -> u64 {
//
// * type (1 byte)
// * number of decimal digits in key
// * null terminator for the key (1 byte)
Expand Down Expand Up @@ -227,149 +222,6 @@ pub(crate) fn read_document_bytes<R: Read>(mut reader: R) -> Result<Vec<u8>> {
Ok(bytes)
}

/// Get the value for the provided key from a buffer containing a BSON document.
/// If the key is not present, None will be returned.
/// If the BSON is not properly formatted, an internal error would be returned.
///
/// TODO: RUST-924 replace this with raw document API usage.
pub(crate) fn raw_get(doc: &[u8], key: &str) -> Result<Option<Bson>> {
fn read_i32(reader: &mut std::io::Cursor<&[u8]>) -> Result<i32> {
reader.read_i32().map_err(deserialize_error)
}

fn read_u8(reader: &mut std::io::Cursor<&[u8]>) -> Result<u8> {
reader.read_u8().map_err(deserialize_error)
}

fn deserialize_error<T: std::error::Error>(_e: T) -> Error {
deserialize_error_no_arg()
}

fn deserialize_error_no_arg() -> Error {
Error::from(ErrorKind::Internal {
message: "failed to read from serialized document".to_string(),
})
}

let mut reader = std::io::Cursor::new(doc);
let len: u64 = read_i32(&mut reader)?
.try_into()
.map_err(deserialize_error)?;

while reader.position() < len {
let element_start: usize = reader.position().try_into().map_err(deserialize_error)?;

// read the element type
let tag = read_u8(&mut reader)?;

// check if we reached the end of the document
if tag == 0 && reader.position() == len {
return Ok(None);
}

let element_type = ElementType::from(tag).ok_or_else(deserialize_error_no_arg)?;

// walk through the document until a null byte is encountered
while read_u8(&mut reader)? != 0 {
if reader.position() >= len {
return Err(deserialize_error_no_arg());
}
}

// parse the key
let string_end: usize = reader
.position()
.checked_sub(1) // back from null byte
.and_then(|u| usize::try_from(u).ok())
.ok_or_else(deserialize_error_no_arg)?;
let slice = &reader.get_ref()[(element_start + 1)..string_end];
let k = std::str::from_utf8(slice).map_err(deserialize_error)?;

// move to the end of the element
let skip_len = match element_type {
ElementType::Array
| ElementType::EmbeddedDocument
| ElementType::JavaScriptCodeWithScope => {
let l = read_i32(&mut reader)?;
// length includes the 4 bytes for the length, so subtrack them out
l.checked_sub(4).ok_or_else(deserialize_error_no_arg)?
}
ElementType::Binary => read_i32(&mut reader)?
.checked_add(1) // add one for subtype
.ok_or_else(deserialize_error_no_arg)?,
ElementType::Int32 => 4,
ElementType::Int64 => 8,
ElementType::String | ElementType::Symbol | ElementType::JavaScriptCode => {
read_i32(&mut reader)?
}
ElementType::Boolean => 1,
ElementType::Double => 8,
ElementType::Timestamp => 8,
ElementType::Decimal128 => 16,
ElementType::MinKey
| ElementType::MaxKey
| ElementType::Null
| ElementType::Undefined => 0,
ElementType::DateTime => 8,
ElementType::ObjectId => 12,
ElementType::DbPointer => read_i32(&mut reader)?
.checked_add(12) // add 12 for objectid
.ok_or_else(deserialize_error_no_arg)?,
ElementType::RegularExpression => {
// read two cstr's
for _i in 0..2 {
while read_u8(&mut reader)? != 0 {
if reader.position() >= len {
return Err(deserialize_error_no_arg());
}
}
}

0 // don't need to skip anymore since we already read the whole value
}
};
let skip_len: u64 = skip_len.try_into().map_err(deserialize_error)?;
reader.set_position(
reader
.position()
.checked_add(skip_len)
.ok_or_else(deserialize_error_no_arg)?,
);

if k == key {
// if this is the element we're looking for, extract it.
let element_end: usize = reader.position().try_into().map_err(deserialize_error)?;
let element_slice = &reader.get_ref()[element_start..element_end];
let element_length: i32 = element_slice.len().try_into().map_err(deserialize_error)?;

// create a new temporary document which just has the element we want and grab the value
let mut temp_doc = Vec::new();

// write the document length
let temp_len: i32 = element_length
.checked_add(4 + 1)
.ok_or_else(deserialize_error_no_arg)?;
temp_doc
.write_all(&temp_len.to_le_bytes())
.map_err(deserialize_error)?;

// add in the element
temp_doc.extend(element_slice);

// write the null byte
temp_doc.push(0);

let d = Document::from_reader(temp_doc.as_slice()).map_err(deserialize_error)?;
return Ok(Some(
d.get("_id").ok_or_else(deserialize_error_no_arg)?.clone(),
));
}
}

// read all bytes but didn't reach null byte
Err(deserialize_error_no_arg())
}

#[cfg(test)]
mod test {
use crate::bson_util::num_decimal_digits;
Expand Down
98 changes: 26 additions & 72 deletions src/operation/insert/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#[cfg(test)]
mod test;

use std::{collections::HashMap, io::Write};
use std::{collections::HashMap, convert::TryInto};

use bson::{oid::ObjectId, spec::ElementType, Bson};
use bson::{oid::ObjectId, Bson, RawArrayBuf, RawDocumentBuf};
use serde::Serialize;

use crate::{
Expand All @@ -14,7 +14,6 @@ use crate::{
operation::{remove_empty_write_concern, Operation, Retryability, WriteResponseBody},
options::{InsertManyOptions, WriteConcern},
results::InsertManyResult,
runtime::SyncLittleEndianWrite,
Namespace,
};

Expand Down Expand Up @@ -57,7 +56,7 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> {
const NAME: &'static str = "insert";

fn build(&mut self, description: &StreamDescription) -> Result<Command<InsertCommand>> {
let mut docs: Vec<Vec<u8>> = Vec::new();
let mut docs = RawArrayBuf::new();
let mut size = 0;

for (i, d) in self
Expand All @@ -66,31 +65,32 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> {
.take(description.max_write_batch_size as usize)
.enumerate()
{
let mut doc = bson::to_vec(d)?;
let id = match bson_util::raw_get(doc.as_slice(), "_id")? {
Some(b) => b,
let mut doc = bson::to_raw_document_buf(d)?;
let id = match doc.get("_id")? {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up using basically the same prepend _id logic here.

Some(b) => b.try_into()?,
None => {
// TODO: RUST-924 Use raw document API here instead.
let mut new_doc = RawDocumentBuf::new();
let oid = ObjectId::new();
new_doc.append("_id", oid);

// write element to temporary buffer
let mut new_id = Vec::new();
new_id.write_u8(ElementType::ObjectId as u8)?;
new_id.write_all(b"_id\0")?;
new_id.extend(oid.bytes().iter());
let mut new_bytes = new_doc.into_bytes();
new_bytes.pop(); // remove trailing null byte

// insert element to beginning of existing doc after length
doc.splice(4..4, new_id.into_iter());
let mut bytes = doc.into_bytes();
let oid_slice = &new_bytes[4..];
// insert oid at beginning of document
bytes.splice(4..4, oid_slice.iter().cloned());

// update length of doc
let new_len = doc.len() as i32;
doc.splice(0..4, new_len.to_le_bytes().iter().cloned());
// overwrite old length
let new_length = (bytes.len() as i32).to_le_bytes();
(&mut bytes[0..4]).copy_from_slice(&new_length);
doc = RawDocumentBuf::from_bytes(bytes)?;

Bson::ObjectId(oid)
}
};

let doc_size = bson_util::array_entry_size_bytes(i, doc.len());
let doc_size = bson_util::array_entry_size_bytes(i, doc.as_bytes().len());

if (size + doc_size) <= description.max_bson_object_size as u64 {
if self.inserted_ids.len() <= i {
Expand All @@ -115,56 +115,19 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> {

let body = InsertCommand {
insert: self.ns.coll.clone(),
documents: DocumentArraySpec {
documents: docs,
length: size as i32,
},
documents: docs,
options,
};

Ok(Command::new("insert".to_string(), self.ns.db.clone(), body))
}

fn serialize_command(&mut self, cmd: Command<Self::Command>) -> Result<Vec<u8>> {
// TODO: RUST-924 Use raw document API here instead.
let mut serialized = bson::to_vec(&cmd)?;

serialized.pop(); // drop null byte

// write element type
serialized.push(ElementType::Array as u8);

// write key cstring
serialized.write_all("documents".as_bytes())?;
serialized.push(0);

// write length of array
let array_length = 4 + cmd.body.documents.length + 1; // add in 4 for length of array, 1 for null byte
serialized.write_all(&array_length.to_le_bytes())?;

for (i, doc) in cmd.body.documents.documents.into_iter().enumerate() {
// write type of document
serialized.push(ElementType::EmbeddedDocument as u8);

// write array index
serialized.write_all(i.to_string().as_bytes())?;
serialized.push(0);

// write document
serialized.extend(doc);
}

// write null byte for array
serialized.push(0);

// write null byte for containing document
serialized.push(0);

// update length of original doc
let final_length = serialized.len() as i32;
serialized.splice(0..4, final_length.to_le_bytes().iter().cloned());

Ok(serialized)
let mut doc = bson::to_raw_document_buf(&cmd)?;
// need to append documents separately because #[serde(flatten)] breaks the custom
// serialization logic. See https://github.com/serde-rs/serde/issues/2106.
doc.append("documents", cmd.body.documents);
Ok(doc.into_bytes())
}

fn handle_response(
Expand Down Expand Up @@ -222,22 +185,13 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> {
}
}

/// Data used for creating a BSON array.
///
/// Holds the already-serialized documents of an insert together with their combined size, so
/// that `serialize_command` can write the `documents` array by hand.
struct DocumentArraySpec {
/// The sum of the lengths of all the documents.
///
/// NOTE(review): `serialize_command` computes the array length as `4 + length + 1`, which
/// suggests this is the sum of the full array-entry sizes (type byte, index key, and
/// document bytes) rather than of the bare document lengths -- confirm against `build`.
length: i32,

/// The serialized documents to be inserted.
documents: Vec<Vec<u8>>,
}

#[derive(Serialize)]
pub(crate) struct InsertCommand {
insert: String,

/// will be serialized in `serialize_command`
#[serde(skip)]
documents: DocumentArraySpec,
documents: RawArrayBuf,

#[serde(flatten)]
options: InsertManyOptions,
Expand Down
5 changes: 2 additions & 3 deletions src/operation/insert/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ async fn build() {
let mut cmd_docs: Vec<Document> = cmd
.body
.documents
.documents
.iter()
.map(|b| Document::from_reader(b.as_slice()).unwrap())
.into_iter()
.map(|b| Document::from_reader(b.unwrap().as_document().unwrap().as_bytes()).unwrap())
.collect();
assert_eq!(cmd_docs.len(), fixtures.documents.len());

Expand Down