Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Brent Gardner committed Nov 3, 2022
1 parent e1ebf1d commit ee992ad
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,11 @@ impl SessionConfig {
self.set(key, ScalarValue::UInt64(Some(value)))
}

/// Set a generic `str` configuration option
pub fn set_str(self, key: &str, value: &str) -> Self {
self.set(key, ScalarValue::Utf8(Some(value.to_string())))
}

/// Customize batch size
pub fn with_batch_size(self, n: usize) -> Self {
// batch size must be greater than zero
Expand Down Expand Up @@ -1935,10 +1940,12 @@ impl FunctionRegistry for TaskContext {
mod tests {
use super::*;
use crate::assert_batches_eq;
use crate::datasource::datasource::TableProviderFactory;
use crate::execution::context::QueryPlanner;
use crate::execution::runtime_env::RuntimeConfig;
use crate::physical_plan::expressions::AvgAccumulator;
use crate::test;
use crate::test_util::parquet_test_data;
use crate::test_util::{parquet_test_data, TestTableFactory};
use crate::variable::VarType;
use arrow::array::ArrayRef;
use arrow::datatypes::*;
Expand All @@ -1947,9 +1954,10 @@ mod tests {
use datafusion_expr::{create_udaf, create_udf, Expr, Volatility};
use datafusion_physical_expr::functions::make_scalar_function;
use std::fs::File;
use std::path::PathBuf;
use std::sync::Weak;
use std::thread::{self, JoinHandle};
use std::{io::prelude::*, sync::Mutex};
use std::{env, io::prelude::*, sync::Mutex};
use tempfile::TempDir;

#[tokio::test]
Expand Down Expand Up @@ -2187,6 +2195,40 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn with_listing_schema_provider() -> Result<()> {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let url = format!("file://{}/tests/tpch-csv", path.display());

let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("test".to_string(), Arc::new(TestTableFactory {}));
let rt_cfg = RuntimeConfig::new().with_table_factories(table_factories);
let runtime = Arc::new(RuntimeEnv::new(rt_cfg).unwrap());
let cfg = SessionConfig::new()
.set_str("datafusion.catalog.location", url.as_str())
.set_str("datafusion.catalog.type", "test");
let session_state = SessionState::with_config_rt(cfg, runtime);
let ctx = SessionContext::with_state(session_state);

let mut table_count = 0;
for cat_name in ctx.catalog_names().iter() {
let cat = ctx.catalog(cat_name).unwrap();
for s_name in cat.schema_names().iter() {
let schema = cat.schema(s_name).unwrap();
if let Some(listing) =
schema.as_any().downcast_ref::<ListingSchemaProvider>()
{
listing.refresh().await.unwrap();
}
table_count = schema.table_names().len();
}
}

assert_eq!(table_count, 8);
Ok(())
}

#[tokio::test]
async fn custom_query_planner() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
Expand Down

0 comments on commit ee992ad

Please sign in to comment.