Skip to content

test: Port tests in predicates.rs to sqllogictest #8879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 1 addition & 201 deletions datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ use arrow::{
array::*, datatypes::*, record_batch::RecordBatch,
util::display::array_value_to_string,
};
use chrono::prelude::*;

use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result};
use datafusion::error::Result;
use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan};
use datafusion::physical_plan::metrics::MetricValue;
use datafusion::physical_plan::ExecutionPlan;
Expand All @@ -34,12 +33,10 @@ use datafusion::test_util;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion::{datasource::MemTable, physical_plan::collect};
use datafusion::{execution::context::SessionContext, physical_plan::displayable};
use datafusion_common::plan_err;
use datafusion_common::{assert_contains, assert_not_contains};
use object_store::path::Path;
use std::fs::File;
use std::io::Write;
use std::ops::Sub;
use std::path::PathBuf;
use tempfile::TempDir;

Expand Down Expand Up @@ -77,7 +74,6 @@ pub mod explain_analyze;
pub mod expr;
pub mod joins;
pub mod partitioned_csv;
pub mod predicates;
pub mod references;
pub mod repartition;
pub mod select;
Expand Down Expand Up @@ -211,202 +207,6 @@ fn create_left_semi_anti_join_context_with_null_ids(
Ok(ctx)
}

fn get_tpch_table_schema(table: &str) -> Schema {
match table {
"customer" => Schema::new(vec![
Field::new("c_custkey", DataType::Int64, false),
Field::new("c_name", DataType::Utf8, false),
Field::new("c_address", DataType::Utf8, false),
Field::new("c_nationkey", DataType::Int64, false),
Field::new("c_phone", DataType::Utf8, false),
Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
Field::new("c_mktsegment", DataType::Utf8, false),
Field::new("c_comment", DataType::Utf8, false),
]),

"orders" => Schema::new(vec![
Field::new("o_orderkey", DataType::Int64, false),
Field::new("o_custkey", DataType::Int64, false),
Field::new("o_orderstatus", DataType::Utf8, false),
Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
Field::new("o_orderdate", DataType::Date32, false),
Field::new("o_orderpriority", DataType::Utf8, false),
Field::new("o_clerk", DataType::Utf8, false),
Field::new("o_shippriority", DataType::Int32, false),
Field::new("o_comment", DataType::Utf8, false),
]),

"lineitem" => Schema::new(vec![
Field::new("l_orderkey", DataType::Int64, false),
Field::new("l_partkey", DataType::Int64, false),
Field::new("l_suppkey", DataType::Int64, false),
Field::new("l_linenumber", DataType::Int32, false),
Field::new("l_quantity", DataType::Decimal128(15, 2), false),
Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
Field::new("l_discount", DataType::Decimal128(15, 2), false),
Field::new("l_tax", DataType::Decimal128(15, 2), false),
Field::new("l_returnflag", DataType::Utf8, false),
Field::new("l_linestatus", DataType::Utf8, false),
Field::new("l_shipdate", DataType::Date32, false),
Field::new("l_commitdate", DataType::Date32, false),
Field::new("l_receiptdate", DataType::Date32, false),
Field::new("l_shipinstruct", DataType::Utf8, false),
Field::new("l_shipmode", DataType::Utf8, false),
Field::new("l_comment", DataType::Utf8, false),
]),

"nation" => Schema::new(vec![
Field::new("n_nationkey", DataType::Int64, false),
Field::new("n_name", DataType::Utf8, false),
Field::new("n_regionkey", DataType::Int64, false),
Field::new("n_comment", DataType::Utf8, false),
]),

"supplier" => Schema::new(vec![
Field::new("s_suppkey", DataType::Int64, false),
Field::new("s_name", DataType::Utf8, false),
Field::new("s_address", DataType::Utf8, false),
Field::new("s_nationkey", DataType::Int64, false),
Field::new("s_phone", DataType::Utf8, false),
Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
Field::new("s_comment", DataType::Utf8, false),
]),

"partsupp" => Schema::new(vec![
Field::new("ps_partkey", DataType::Int64, false),
Field::new("ps_suppkey", DataType::Int64, false),
Field::new("ps_availqty", DataType::Int32, false),
Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
Field::new("ps_comment", DataType::Utf8, false),
]),

"part" => Schema::new(vec![
Field::new("p_partkey", DataType::Int64, false),
Field::new("p_name", DataType::Utf8, false),
Field::new("p_mfgr", DataType::Utf8, false),
Field::new("p_brand", DataType::Utf8, false),
Field::new("p_type", DataType::Utf8, false),
Field::new("p_size", DataType::Int32, false),
Field::new("p_container", DataType::Utf8, false),
Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
Field::new("p_comment", DataType::Utf8, false),
]),

"region" => Schema::new(vec![
Field::new("r_regionkey", DataType::Int64, false),
Field::new("r_name", DataType::Utf8, false),
Field::new("r_comment", DataType::Utf8, false),
]),

_ => unimplemented!("Table: {}", table),
}
}

/// Registers the CSV file `tests/tpch-csv/{table}.csv` with `ctx` under the
/// name `table`, reading it with the schema returned by
/// [`get_tpch_table_schema`].
///
/// # Errors
///
/// Returns whatever error `SessionContext::register_csv` reports.
///
/// # Panics
///
/// Panics when `table` has no schema in [`get_tpch_table_schema`].
async fn register_tpch_csv(ctx: &SessionContext, table: &str) -> Result<()> {
    let table_schema = get_tpch_table_schema(table);
    let csv_path = format!("tests/tpch-csv/{table}.csv");
    let read_options = CsvReadOptions::new().schema(&table_schema);

    ctx.register_csv(table, csv_path.as_str(), read_options)
        .await?;

    Ok(())
}

/// Parses a decimal literal such as `"12.5"` into the unscaled integer that
/// represents it at `scale` fractional digits (`1250` for scale 2).
///
/// Returns `None` when `scale` is negative, when the text contains no digit
/// or is not a plain decimal number, or when it carries more fractional
/// digits than `scale` can hold (which would silently change its magnitude).
fn parse_decimal_with_scale(text: &str, scale: i8) -> Option<i128> {
    if scale < 0 {
        return None;
    }
    let scale = scale as usize;

    let text = text.trim();
    // Guard against inputs like "", "." or "-" that would otherwise turn
    // into an all-zero digit string once the padding is appended.
    if !text.bytes().any(|b| b.is_ascii_digit()) {
        return None;
    }

    let (whole, fraction) = text.split_once('.').unwrap_or((text, ""));
    if fraction.len() > scale {
        return None;
    }

    // Concatenate the integral digits with the fraction, right-padded with
    // zeros to exactly `scale` digits, and parse the result as one integer.
    let mut digits = String::with_capacity(whole.len() + scale);
    digits.push_str(whole);
    digits.push_str(fraction);
    digits.extend(std::iter::repeat('0').take(scale - fraction.len()));
    digits.parse().ok()
}

/// Parses header-less CSV text in `data` into a single `RecordBatch` that
/// follows the schema of [`get_tpch_table_schema`] for `table_name`, and
/// registers that batch with `ctx` under `table_name`.
///
/// Values are converted per column type:
/// * `Utf8` values are stored as-is (not trimmed);
/// * `Date32` values are parsed as `%Y-%m-%d` and stored as the number of
///   days since 1970-01-01;
/// * `Int32` / `Int64` values are trimmed and parsed;
/// * `Decimal128` values are trimmed and converted to the unscaled integer
///   for the column's scale, so `"12.5"` in a scale-2 column is stored as
///   `1250`.
///
/// # Errors
///
/// Returns a plan error when the schema contains a column type that is not
/// handled here or when a decimal value cannot be represented at the column's
/// scale, and forwards errors from `RecordBatch::try_new` and
/// `SessionContext::register_batch`.
///
/// # Panics
///
/// Panics when the CSV text is malformed, when a row has more values than the
/// schema has columns, or when a date or integer value fails to parse.
async fn register_tpch_csv_data(
    ctx: &SessionContext,
    table_name: &str,
    data: &str,
) -> Result<()> {
    let schema = Arc::new(get_tpch_table_schema(table_name));

    let mut reader = ::csv::ReaderBuilder::new()
        .has_headers(false)
        .from_reader(data.as_bytes());
    let records: Vec<_> = reader.records().map(|it| it.unwrap()).collect();

    // One array builder per schema column, in schema order.
    let mut cols: Vec<Box<dyn ArrayBuilder>> = vec![];
    for field in schema.fields().iter() {
        match field.data_type() {
            DataType::Utf8 => cols.push(Box::new(StringBuilder::new())),
            DataType::Date32 => {
                cols.push(Box::new(Date32Builder::with_capacity(records.len())))
            }
            DataType::Int32 => {
                cols.push(Box::new(Int32Builder::with_capacity(records.len())))
            }
            DataType::Int64 => {
                cols.push(Box::new(Int64Builder::with_capacity(records.len())))
            }
            DataType::Decimal128(_, _) => {
                cols.push(Box::new(Decimal128Builder::with_capacity(records.len())))
            }
            _ => plan_err!("Not implemented: {}", field.data_type())?,
        }
    }

    // Append every value to the builder of the column at the same position.
    for record in records.iter() {
        for (idx, val) in record.iter().enumerate() {
            let col = cols.get_mut(idx).unwrap();
            let field = schema.field(idx);
            match field.data_type() {
                DataType::Utf8 => {
                    let sb = col.as_any_mut().downcast_mut::<StringBuilder>().unwrap();
                    sb.append_value(val);
                }
                DataType::Date32 => {
                    let sb = col.as_any_mut().downcast_mut::<Date32Builder>().unwrap();
                    let dt = NaiveDate::parse_from_str(val.trim(), "%Y-%m-%d").unwrap();
                    // Date32 stores the day offset from the Unix epoch.
                    let dt = dt
                        .sub(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap())
                        .num_days() as i32;
                    sb.append_value(dt);
                }
                DataType::Int32 => {
                    let sb = col.as_any_mut().downcast_mut::<Int32Builder>().unwrap();
                    sb.append_value(val.trim().parse().unwrap());
                }
                DataType::Int64 => {
                    let sb = col.as_any_mut().downcast_mut::<Int64Builder>().unwrap();
                    sb.append_value(val.trim().parse().unwrap());
                }
                DataType::Decimal128(_, scale) => {
                    let sb = col
                        .as_any_mut()
                        .downcast_mut::<Decimal128Builder>()
                        .unwrap();
                    // Scale-aware conversion: simply deleting the '.' is only
                    // correct when the literal has exactly `scale` fractional
                    // digits.
                    let value_i128 = match parse_decimal_with_scale(val, *scale) {
                        Some(unscaled) => unscaled,
                        None => {
                            return plan_err!(
                                "Invalid decimal value '{}' for column '{}'",
                                val,
                                field.name()
                            )
                        }
                    };
                    sb.append_value(value_i128);
                }
                _ => plan_err!("Not implemented: {}", field.data_type())?,
            }
        }
    }

    // Finish the builders; decimal arrays additionally get the precision and
    // scale declared in the schema so that they match the field type.
    let cols: Vec<ArrayRef> = cols
        .iter_mut()
        .zip(schema.fields())
        .map(|(it, field)| match field.data_type() {
            DataType::Decimal128(p, s) => Arc::new(
                it.as_any_mut()
                    .downcast_mut::<Decimal128Builder>()
                    .unwrap()
                    .finish()
                    .with_precision_and_scale(*p, *s)
                    .unwrap(),
            ),
            _ => it.finish(),
        })
        .collect();

    let batch = RecordBatch::try_new(Arc::clone(&schema), cols)?;

    // Propagate registration failures instead of panicking; the previously
    // registered provider (if any) returned on success is not needed.
    ctx.register_batch(table_name, batch)?;

    Ok(())
}

async fn register_aggregate_csv_by_sql(ctx: &SessionContext) {
let testdata = datafusion::test_util::arrow_test_data();

Expand Down
Loading