Skip to content

Commit

Permalink
Move some tests from core to expr (#2700)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Jun 6, 2022
1 parent 9306534 commit bbb674a
Show file tree
Hide file tree
Showing 2 changed files with 334 additions and 334 deletions.
334 changes: 0 additions & 334 deletions datafusion/core/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,337 +97,3 @@ pub fn source_as_provider(
)),
}
}

#[cfg(test)]
mod tests {
use super::super::{col, lit};
use super::*;
use crate::test_util::scan_empty;
use arrow::datatypes::{DataType, Field, Schema};

fn employee_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::Int32, false),
])
}

fn display_plan() -> LogicalPlan {
scan_empty(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))
.unwrap()
.filter(col("state").eq(lit("CO")))
.unwrap()
.project(vec![col("id")])
.unwrap()
.build()
.unwrap()
}

#[test]
fn test_display_indent() {
let plan = display_plan();

let expected = "Projection: #employee_csv.id\
\n Filter: #employee_csv.state = Utf8(\"CO\")\
\n TableScan: employee_csv projection=Some([id, state])";

assert_eq!(expected, format!("{}", plan.display_indent()));
}

#[test]
fn test_display_indent_schema() {
let plan = display_plan();

let expected = "Projection: #employee_csv.id [id:Int32]\
\n Filter: #employee_csv.state = Utf8(\"CO\") [id:Int32, state:Utf8]\
\n TableScan: employee_csv projection=Some([id, state]) [id:Int32, state:Utf8]";

assert_eq!(expected, format!("{}", plan.display_indent_schema()));
}

#[test]
fn test_display_graphviz() {
let plan = display_plan();

// just test for a few key lines in the output rather than the
// whole thing to make test mainteance easier.
let graphviz = format!("{}", plan.display_graphviz());

assert!(
graphviz.contains(
r#"// Begin DataFusion GraphViz Plan (see https://graphviz.org)"#
),
"\n{}",
plan.display_graphviz()
);
assert!(
graphviz.contains(
r#"[shape=box label="TableScan: employee_csv projection=Some([id, state])"]"#
),
"\n{}",
plan.display_graphviz()
);
assert!(graphviz.contains(r#"[shape=box label="TableScan: employee_csv projection=Some([id, state])\nSchema: [id:Int32, state:Utf8]"]"#),
"\n{}", plan.display_graphviz());
assert!(
graphviz.contains(r#"// End DataFusion GraphViz Plan"#),
"\n{}",
plan.display_graphviz()
);
}

/// Tests for the Visitor trait and walking logical plan nodes
#[derive(Debug, Default)]
struct OkVisitor {
strings: Vec<String>,
}

impl PlanVisitor for OkVisitor {
type Error = String;

fn pre_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
let s = match plan {
LogicalPlan::Projection { .. } => "pre_visit Projection",
LogicalPlan::Filter { .. } => "pre_visit Filter",
LogicalPlan::TableScan { .. } => "pre_visit TableScan",
_ => unimplemented!("unknown plan type"),
};

self.strings.push(s.into());
Ok(true)
}

fn post_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
let s = match plan {
LogicalPlan::Projection { .. } => "post_visit Projection",
LogicalPlan::Filter { .. } => "post_visit Filter",
LogicalPlan::TableScan { .. } => "post_visit TableScan",
_ => unimplemented!("unknown plan type"),
};

self.strings.push(s.into());
Ok(true)
}
}

#[test]
fn visit_order() {
let mut visitor = OkVisitor::default();
let plan = test_plan();
let res = plan.accept(&mut visitor);
assert!(res.is_ok());

assert_eq!(
visitor.strings,
vec![
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
"post_visit Filter",
"post_visit Projection",
]
);
}

#[derive(Debug, Default)]
/// Counter than counts to zero and returns true when it gets there
struct OptionalCounter {
val: Option<usize>,
}

impl OptionalCounter {
fn new(val: usize) -> Self {
Self { val: Some(val) }
}
// Decrements the counter by 1, if any, returning true if it hits zero
fn dec(&mut self) -> bool {
if Some(0) == self.val {
true
} else {
self.val = self.val.take().map(|i| i - 1);
false
}
}
}

#[derive(Debug, Default)]
/// Visitor that returns false after some number of visits
struct StoppingVisitor {
inner: OkVisitor,
/// When Some(0) returns false from pre_visit
return_false_from_pre_in: OptionalCounter,
/// When Some(0) returns false from post_visit
return_false_from_post_in: OptionalCounter,
}

impl PlanVisitor for StoppingVisitor {
type Error = String;

fn pre_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
if self.return_false_from_pre_in.dec() {
return Ok(false);
}
self.inner.pre_visit(plan)
}

fn post_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
if self.return_false_from_post_in.dec() {
return Ok(false);
}

self.inner.post_visit(plan)
}
}

/// test early stopping in pre-visit
#[test]
fn early_stopping_pre_visit() {
let mut visitor = StoppingVisitor {
return_false_from_pre_in: OptionalCounter::new(2),
..Default::default()
};
let plan = test_plan();
let res = plan.accept(&mut visitor);
assert!(res.is_ok());

assert_eq!(
visitor.inner.strings,
vec!["pre_visit Projection", "pre_visit Filter"]
);
}

#[test]
fn early_stopping_post_visit() {
let mut visitor = StoppingVisitor {
return_false_from_post_in: OptionalCounter::new(1),
..Default::default()
};
let plan = test_plan();
let res = plan.accept(&mut visitor);
assert!(res.is_ok());

assert_eq!(
visitor.inner.strings,
vec![
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
]
);
}

#[derive(Debug, Default)]
/// Visitor that returns an error after some number of visits
struct ErrorVisitor {
inner: OkVisitor,
/// When Some(0) returns false from pre_visit
return_error_from_pre_in: OptionalCounter,
/// When Some(0) returns false from post_visit
return_error_from_post_in: OptionalCounter,
}

impl PlanVisitor for ErrorVisitor {
type Error = String;

fn pre_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
if self.return_error_from_pre_in.dec() {
return Err("Error in pre_visit".into());
}

self.inner.pre_visit(plan)
}

fn post_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
if self.return_error_from_post_in.dec() {
return Err("Error in post_visit".into());
}

self.inner.post_visit(plan)
}
}

#[test]
fn error_pre_visit() {
let mut visitor = ErrorVisitor {
return_error_from_pre_in: OptionalCounter::new(2),
..Default::default()
};
let plan = test_plan();
let res = plan.accept(&mut visitor);

if let Err(e) = res {
assert_eq!("Error in pre_visit", e);
} else {
panic!("Expected an error");
}

assert_eq!(
visitor.inner.strings,
vec!["pre_visit Projection", "pre_visit Filter"]
);
}

#[test]
fn error_post_visit() {
let mut visitor = ErrorVisitor {
return_error_from_post_in: OptionalCounter::new(1),
..Default::default()
};
let plan = test_plan();
let res = plan.accept(&mut visitor);
if let Err(e) = res {
assert_eq!("Error in post_visit", e);
} else {
panic!("Expected an error");
}

assert_eq!(
visitor.inner.strings,
vec![
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
]
);
}

fn test_plan() -> LogicalPlan {
let schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("state", DataType::Utf8, false),
]);

scan_empty(None, &schema, Some(vec![0, 1]))
.unwrap()
.filter(col("state").eq(lit("CO")))
.unwrap()
.project(vec![col("id")])
.unwrap()
.build()
.unwrap()
}
}
Loading

0 comments on commit bbb674a

Please sign in to comment.