use simdutf8 from_utf8 and still do SIMD accelerated utf8 validation
instead of doing unsafe from_utf8_unchecked, which relied on just the first 8k of utf8 screening
jqnatividad committed Feb 17, 2023
1 parent f75e50c commit f3afcdd
Showing 8 changed files with 31 additions and 39 deletions.
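
For orientation, a minimal, self-contained sketch (not part of the commit) of the pattern adopted throughout these files, assuming the simdutf8 crate is available as a dependency: simdutf8::basic::from_utf8 performs SIMD-accelerated validation and returns a Result, so malformed input surfaces as an error instead of the undefined behavior risked by from_utf8_unchecked.

```rust
use simdutf8::basic::from_utf8;

fn main() {
    let valid: &[u8] = "naïve".as_bytes();
    let invalid: &[u8] = &[0xE2, 0x28, 0xA1]; // not well-formed UTF-8

    // Before this commit: unsafe { std::str::from_utf8_unchecked(bytes) },
    // which is undefined behavior if the bytes are not valid UTF-8.
    // After: SIMD-accelerated validation that reports the failure instead.
    assert_eq!(from_utf8(valid).unwrap(), "naïve");
    assert!(from_utf8(invalid).is_err());
}
```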
1 change: 0 additions & 1 deletion src/cmd/apply.rs
@@ -487,7 +487,6 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
let rconfig = Config::new(&args.arg_input)
.delimiter(args.flag_delimiter)
.no_headers(args.flag_no_headers)
// .checkutf8(false)
.select(args.arg_column);

let mut rdr = rconfig.reader()?;
3 changes: 2 additions & 1 deletion src/cmd/dedup.rs
@@ -53,6 +53,7 @@ use std::cmp;
use csv::ByteRecord;
use rayon::prelude::*;
use serde::Deserialize;
use simdutf8::basic::from_utf8;

use crate::{
cmd::sort::iter_cmp,
@@ -220,6 +221,6 @@
X: Iterator<Item = &'a [u8]>,
{
xs.next()
.map(|bytes| unsafe { std::str::from_utf8_unchecked(bytes) })
.and_then(|bytes| from_utf8(bytes).ok())
.map(str::to_lowercase)
}
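
A standalone sketch of the key-extraction pattern above; the helper name next_lowercase is hypothetical, standing in for the function in this hunk. With .and_then(|bytes| from_utf8(bytes).ok()), a field that is not valid UTF-8 now yields None instead of triggering undefined behavior.

```rust
use simdutf8::basic::from_utf8;

// Hypothetical stand-in for the dedup/sort key helpers: validate, then lowercase.
fn next_lowercase<'a, X>(xs: &mut X) -> Option<String>
where
    X: Iterator<Item = &'a [u8]>,
{
    xs.next()
        .and_then(|bytes| from_utf8(bytes).ok())
        .map(str::to_lowercase)
}

fn main() {
    let valid: &[u8] = b"ALPHA";
    let invalid: &[u8] = &[0xFF, 0xFE];
    let mut fields = [valid, invalid].into_iter();
    assert_eq!(next_lowercase(&mut fields), Some("alpha".to_string()));
    assert_eq!(next_lowercase(&mut fields), None); // invalid UTF-8 -> None
}
```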
3 changes: 2 additions & 1 deletion src/cmd/exclude.rs
@@ -44,6 +44,7 @@ use std::{collections::hash_map::Entry, fmt, fs, io, str};
use ahash::AHashMap;
use byteorder::{BigEndian, WriteBytesExt};
use serde::Deserialize;
use simdutf8::basic::from_utf8;

use crate::{
config::{Config, Delimiter},
@@ -252,7 +253,7 @@ fn get_row_key(sel: &Selection, row: &csv::ByteRecord, casei: bool) -> Vec<ByteS

#[inline]
fn transform(bs: &[u8], casei: bool) -> ByteString {
let s = unsafe { str::from_utf8_unchecked(bs) };
let s = from_utf8(bs).unwrap_or_default();
if casei {
let norm: String = s
.trim()
5 changes: 3 additions & 2 deletions src/cmd/fetch.rs
@@ -204,6 +204,7 @@ use reqwest::{
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use simdutf8::basic::from_utf8;
use url::Url;

use crate::{
@@ -618,14 +619,14 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
// let's dynamically construct the URL with it
record_vec.clear();
for field in &record {
record_vec.push(unsafe { std::str::from_utf8_unchecked(field).to_owned() });
record_vec.push(from_utf8(field).unwrap_or_default().to_owned());
}
if let Ok(formatted) =
dynfmt::SimpleCurlyFormat.format(&dynfmt_url_template, &*record_vec)
{
url = formatted.into_owned();
}
} else if let Ok(s) = std::str::from_utf8(&record[column_index]) {
} else if let Ok(s) = from_utf8(&record[column_index]) {
// we're not using a URL template,
// just use the field as-is as the URL
s.clone_into(&mut url);
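
Both fetch (above) and fetchpost (below) fall back with unwrap_or_default(), so a field that fails validation becomes an empty string rather than undefined behavior. A small illustrative snippet with invented inputs:

```rust
use simdutf8::basic::from_utf8;

fn main() {
    let good: &[u8] = b"https://example.com/{id}";
    let bad: &[u8] = &[0xC3, 0x28]; // invalid two-byte sequence

    // &str implements Default, so unwrap_or_default() yields "" on error.
    assert_eq!(from_utf8(good).unwrap_or_default(), "https://example.com/{id}");
    assert_eq!(from_utf8(bad).unwrap_or_default(), "");
}
```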
6 changes: 3 additions & 3 deletions src/cmd/fetchpost.rs
@@ -196,6 +196,7 @@ use reqwest::{
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use serde_urlencoded;
use simdutf8::basic::from_utf8;
use url::Url;

use crate::{
@@ -609,8 +610,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
form_body_jsonmap.clear();
for col_idx in col_list.iter() {
let header_key = String::from_utf8_lossy(headers.get(*col_idx).unwrap());
let value_string =
unsafe { std::str::from_utf8_unchecked(&record[*col_idx]).to_string() };
let value_string = from_utf8(&record[*col_idx]).unwrap_or_default().to_string();
form_body_jsonmap.insert(
header_key.to_string(),
serde_json::Value::String(value_string),
@@ -622,7 +622,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {

if literal_url_used {
url = literal_url.clone();
} else if let Ok(s) = std::str::from_utf8(&record[column_index]) {
} else if let Ok(s) = from_utf8(&record[column_index]) {
s.clone_into(&mut url);
} else {
url = String::new();
3 changes: 2 additions & 1 deletion src/cmd/sort.rs
@@ -45,6 +45,7 @@ use std::cmp;
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
use rayon::prelude::*;
use serde::Deserialize;
use simdutf8::basic::from_utf8;

use self::Number::{Float, Int};
use crate::{
@@ -236,7 +237,7 @@
X: Iterator<Item = &'a [u8]>,
{
xs.next()
.map(|bytes| unsafe { std::str::from_utf8_unchecked(bytes) })
.map(|bytes| from_utf8(bytes).unwrap())
.and_then(|s| {
if let Ok(i) = s.parse::<i64>() {
Some(Number::Int(i))
39 changes: 14 additions & 25 deletions src/cmd/stats.rs
@@ -150,6 +150,7 @@ use itertools::Itertools;
use once_cell::sync::OnceCell;
use qsv_dateparser::parse_with_preference;
use serde::Deserialize;
use simdutf8::basic::from_utf8;
use stats::{merge_all, Commute, MinMax, OnlineStats, Unsorted};
use threadpool::ThreadPool;

@@ -477,7 +478,7 @@ fn init_date_inference(

let mut infer_date_flags: Vec<bool> = Vec::with_capacity(headers.len());
for header in headers {
let header_str = from_bytes::<String>(header).to_lowercase();
let header_str = from_bytes::<String>(header).unwrap().to_lowercase();
let mut date_found = false;
for whitelist_item in &whitelist {
if header_str.contains(whitelist_item) {
@@ -635,7 +636,7 @@ impl Stats {
};
}
} else {
let n = from_bytes::<f64>(sample);
let n = from_bytes::<f64>(sample).unwrap();
if let Some(v) = self.median.as_mut() {
v.add(n);
}
@@ -1029,8 +1030,7 @@ impl FieldType {
return (FieldType::TString, None);
}

// we skip utf8 validation since we say we only work with utf8
let string = unsafe { str::from_utf8_unchecked(sample) };
let string = from_utf8(sample).unwrap();

if current_type == FieldType::TFloat
|| current_type == FieldType::TInteger
@@ -1137,7 +1137,7 @@ impl TypedSum {
#[allow(clippy::cast_precision_loss)]
match typ {
TFloat => {
let float: f64 = from_bytes::<f64>(sample);
let float: f64 = from_bytes::<f64>(sample).unwrap();
match self.float {
None => {
self.float = Some((self.integer as f64) + float);
@@ -1149,10 +1149,12 @@
}
TInteger => {
if let Some(ref mut float) = self.float {
*float += from_bytes::<f64>(sample);
*float += from_bytes::<f64>(sample).unwrap();
} else {
// so we don't panic on overflow/underflow, use saturating_add
self.integer = self.integer.saturating_add(from_bytes::<i64>(sample));
self.integer = self
.integer
.saturating_add(from_bytes::<i64>(sample).unwrap());
}
}
_ => {}
@@ -1219,31 +1221,19 @@ impl TypedMinMax {
match typ {
TString | TNull => {}
TFloat => {
let n = unsafe {
str::from_utf8_unchecked(sample)
.parse::<f64>()
.unwrap_unchecked()
};
let n = from_utf8(sample).unwrap().parse::<f64>().unwrap();

self.floats.add(n);
self.integers.add(n as i64);
}
TInteger => {
let n = unsafe {
str::from_utf8_unchecked(sample)
.parse::<i64>()
.unwrap_unchecked()
};
let n = from_utf8(sample).unwrap().parse::<i64>().unwrap();
self.integers.add(n);
#[allow(clippy::cast_precision_loss)]
self.floats.add(n as f64);
}
TDate | TDateTime => {
let n = unsafe {
str::from_utf8_unchecked(sample)
.parse::<i64>()
.unwrap_unchecked()
};
let n = from_utf8(sample).unwrap().parse::<i64>().unwrap();
self.dates.add(n);
}
}
@@ -1330,7 +1320,6 @@ impl Commute for TypedMinMax {

#[allow(clippy::inline_always)]
#[inline(always)]
fn from_bytes<T: FromStr>(bytes: &[u8]) -> T {
// we don't need to do UTF-8 validation as qsv requires UTF-8 encoding
unsafe { str::from_utf8_unchecked(bytes).parse().unwrap_unchecked() }
fn from_bytes<T: FromStr>(bytes: &[u8]) -> Option<T> {
from_utf8(bytes).ok().and_then(|s| s.parse().ok())
}
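
To see the reworked from_bytes helper in isolation: both UTF-8 validation and parsing can fail, so it now returns an Option and its callers unwrap (or default) as shown in the hunks above. A runnable sketch; the calls in main are illustrative, not from the repo.

```rust
use std::str::FromStr;

use simdutf8::basic::from_utf8;

// Mirrors the new helper: validate UTF-8 with simdutf8, then parse.
fn from_bytes<T: FromStr>(bytes: &[u8]) -> Option<T> {
    from_utf8(bytes).ok().and_then(|s| s.parse().ok())
}

fn main() {
    assert_eq!(from_bytes::<i64>(b"42"), Some(42));
    assert_eq!(from_bytes::<f64>(b"3.5"), Some(3.5));
    assert_eq!(from_bytes::<i64>(b"not a number"), None);
    assert_eq!(from_bytes::<i64>(&[0xFF]), None); // invalid UTF-8
}
```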
10 changes: 5 additions & 5 deletions src/cmd/validate.rs
@@ -61,6 +61,7 @@ use once_cell::sync::OnceCell;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::{json, value::Number, Map, Value};
use simdutf8::basic::from_utf8;
use thousands::Separable;

use crate::{
@@ -503,9 +504,8 @@ fn do_json_validation(
schema_json: &Value,
schema_compiled: &JSONSchema,
) -> Option<String> {
// row number was added as last column. We use unsafe from_utf8_unchecked to
// skip UTF8 validation since we know its safe as we added it earlier
let row_number_string = unsafe { str::from_utf8_unchecked(record.get(headers_len).unwrap()) };
// row number was added as the last column; we can safely unwrap since we know it's there
let row_number_string = from_utf8(record.get(headers_len).unwrap()).unwrap();

// debug!("instance[{row_number}]: {instance:?}");
validate_json_instance(
@@ -553,9 +553,9 @@ fn to_json_instance(
// iterate over each CSV field and convert to JSON type
for (i, header) in headers.iter().enumerate() {
// convert csv header to string
let header_string = unsafe { std::str::from_utf8_unchecked(header).to_string() };
let header_string = from_utf8(header).unwrap().to_string();
// convert csv value to string; no trimming reqd as it's done on the record level beforehand
let value_string = unsafe { std::str::from_utf8_unchecked(&record[i]).to_string() };
let value_string = from_utf8(&record[i]).unwrap().to_string();

// if value_string is empty, then just put an empty JSON String
if value_string.is_empty() {
