Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 66 additions & 52 deletions crates/physical-plan/src/compile.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,39 @@
//! Lowering from the logical plan to the physical plan.

use crate::plan;
use crate::plan::{CrossJoin, Filter, PhysicalCtx, PhysicalExpr, PhysicalPlan};
use crate::plan::{PhysicalCtx, PhysicalExpr, PhysicalPlan};
use spacetimedb_expr::expr::{Expr, Let, LetCtx, Project, RelExpr, Select};
use spacetimedb_expr::statement::Statement;
use spacetimedb_expr::ty::{TyCtx, TyId};
use spacetimedb_expr::ty::{TyCtx, Type};
use spacetimedb_expr::StatementCtx;
use spacetimedb_sql_parser::ast::BinOp;

fn compile_expr(_ctx: &TyCtx, vars: &LetCtx, expr: Expr) -> PhysicalExpr {
fn compile_expr(ctx: &TyCtx, vars: &LetCtx, expr: Expr) -> PhysicalExpr {
match expr {
Expr::Bin(op, lhs, rhs) => {
let lhs = compile_expr(_ctx, vars, *lhs);
let rhs = compile_expr(_ctx, vars, *rhs);
let lhs = compile_expr(ctx, vars, *lhs);
let rhs = compile_expr(ctx, vars, *rhs);
PhysicalExpr::BinOp(op, Box::new(lhs), Box::new(rhs))
}
Expr::Var(sym, _ty) => {
let var = vars.get_var(sym).cloned().unwrap();
compile_expr(_ctx, vars, var)
compile_expr(ctx, vars, var)
}
Expr::Row(row, ty) => {
Expr::Row(row, _) => {
PhysicalExpr::Tuple(
row.into_vec()
.into_iter()
// The `sym` is inline in `expr`
.map(|(_sym, expr)| compile_expr(_ctx, vars, expr))
.map(|(_sym, expr)| compile_expr(ctx, vars, expr))
.collect(),
ty,
)
}
Expr::Lit(value, ty) => PhysicalExpr::Value(value, ty),
Expr::Field(expr, pos, ty) => {
let expr = compile_expr(_ctx, vars, *expr);
PhysicalExpr::Field(Box::new(expr), pos, ty)
Expr::Lit(value, _) => PhysicalExpr::Value(value),
Expr::Field(expr, pos, _) => {
let expr = compile_expr(ctx, vars, *expr);
PhysicalExpr::Field(Box::new(expr), pos)
}
Expr::Input(ty) => PhysicalExpr::Input(ty),
Expr::Input(ty) if matches!(*ctx.try_resolve(ty).unwrap(), Type::Var(..)) => PhysicalExpr::Ptr,
Expr::Input(_) => PhysicalExpr::Tup,
}
}

Expand All @@ -54,35 +53,24 @@ fn compile_let(ctx: &TyCtx, Let { vars, exprs }: Let) -> Vec<PhysicalExpr> {
fn compile_filter(ctx: &TyCtx, select: Select) -> PhysicalPlan {
let input = compile_rel_expr(ctx, select.input);
if let Some(op) = join_exprs(compile_let(ctx, select.expr)) {
PhysicalPlan::Filter(Filter {
input: Box::new(input),
op,
})
PhysicalPlan::Filter(Box::new(input), op)
} else {
input
}
}

fn compile_project(ctx: &TyCtx, expr: Project) -> PhysicalPlan {
let proj = plan::Project {
input: Box::new(compile_rel_expr(ctx, expr.input)),
op: join_exprs(compile_let(ctx, expr.expr)).unwrap(),
};
let input = Box::new(compile_rel_expr(ctx, expr.input));
let op = join_exprs(compile_let(ctx, expr.expr)).unwrap();

PhysicalPlan::Project(proj)
PhysicalPlan::Project(input, op)
}

fn compile_cross_joins(ctx: &TyCtx, joins: Vec<RelExpr>, ty: TyId) -> PhysicalPlan {
fn compile_cross_joins(ctx: &TyCtx, joins: Vec<RelExpr>) -> PhysicalPlan {
joins
.into_iter()
.map(|expr| compile_rel_expr(ctx, expr))
.reduce(|lhs, rhs| {
PhysicalPlan::CrossJoin(CrossJoin {
lhs: Box::new(lhs),
rhs: Box::new(rhs),
ty,
})
})
.reduce(|lhs, rhs| PhysicalPlan::NLJoin(Box::new(lhs), Box::new(rhs)))
.unwrap()
}

Expand All @@ -91,7 +79,7 @@ fn compile_rel_expr(ctx: &TyCtx, ast: RelExpr) -> PhysicalPlan {
RelExpr::RelVar(table, _ty) => PhysicalPlan::TableScan(table),
RelExpr::Select(select) => compile_filter(ctx, *select),
RelExpr::Proj(proj) => compile_project(ctx, *proj),
RelExpr::Join(joins, ty) => compile_cross_joins(ctx, joins.into_vec(), ty),
RelExpr::Join(joins, _) => compile_cross_joins(ctx, joins.into_vec()),
RelExpr::Union(_, _) | RelExpr::Minus(_, _) | RelExpr::Dedup(_) => {
unreachable!("DISTINCT is not implemented")
}
Expand Down Expand Up @@ -165,24 +153,50 @@ mod tests {
Ok(statement)
}

impl PhysicalPlan {
pub fn as_project(&self) -> Option<(&PhysicalPlan, &PhysicalExpr)> {
if let PhysicalPlan::Project(input, expr) = self {
Some((input, expr))
} else {
None
}
}

pub fn as_filter(&self) -> Option<(&PhysicalPlan, &PhysicalExpr)> {
if let PhysicalPlan::Filter(input, expr) = self {
Some((input, expr))
} else {
None
}
}

pub fn as_nljoin(&self) -> Option<(&PhysicalPlan, &PhysicalPlan)> {
if let PhysicalPlan::NLJoin(lhs, rhs) = self {
Some((lhs, rhs))
} else {
None
}
}
}

#[test]
fn test_project() -> ResultTest<()> {
let (ast, ctx) = compile_sql_sub_test("SELECT * FROM t")?;
assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::TableScan(_)));

let ast = compile_sql_stmt_test("SELECT u32 FROM t")?;
assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Project(_)));
assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Project(..)));

Ok(())
}

#[test]
fn test_select() -> ResultTest<()> {
let (ast, ctx) = compile_sql_sub_test("SELECT * FROM t WHERE u32 = 1")?;
assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Filter(_)));
assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Filter(..)));

let (ast, ctx) = compile_sql_sub_test("SELECT * FROM t WHERE u32 = 1 AND f32 = f32")?;
assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Filter(_)));
assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Filter(..)));
Ok(())
}

Expand All @@ -191,35 +205,35 @@ mod tests {
// Check we can do a cross join
let (ast, ctx) = compile_sql_sub_test("SELECT t.* FROM t JOIN u")?;
let ast = compile(&ctx, ast).plan;
let plan::Project { input, op } = ast.as_project().unwrap();
let CrossJoin { lhs, rhs, ty: _ } = input.as_cross().unwrap();
let (input, op) = ast.as_project().unwrap();
let (lhs, rhs) = input.as_nljoin().unwrap();

assert!(matches!(op, PhysicalExpr::Field(_, _, _)));
assert!(matches!(&**lhs, PhysicalPlan::TableScan(_)));
assert!(matches!(&**rhs, PhysicalPlan::TableScan(_)));
assert!(matches!(op, PhysicalExpr::Field(..)));
assert!(matches!(lhs, PhysicalPlan::TableScan(_)));
assert!(matches!(rhs, PhysicalPlan::TableScan(_)));

// Check we can do multiple joins
let (ast, ctx) = compile_sql_sub_test("SELECT t.* FROM t JOIN u JOIN x")?;
let ast = compile(&ctx, ast).plan;
let plan::Project { input, op: _ } = ast.as_project().unwrap();
let CrossJoin { lhs, rhs, ty: _ } = input.as_cross().unwrap();
assert!(matches!(&**rhs, PhysicalPlan::TableScan(_)));
let (input, _) = ast.as_project().unwrap();
let (lhs, rhs) = input.as_nljoin().unwrap();
assert!(matches!(rhs, PhysicalPlan::TableScan(_)));

let CrossJoin { lhs, rhs, ty: _ } = lhs.as_cross().unwrap();
assert!(matches!(&**lhs, PhysicalPlan::TableScan(_)));
assert!(matches!(&**rhs, PhysicalPlan::TableScan(_)));
let (lhs, rhs) = lhs.as_nljoin().unwrap();
assert!(matches!(lhs, PhysicalPlan::TableScan(_)));
assert!(matches!(rhs, PhysicalPlan::TableScan(_)));

// Check we can do a join with a filter
let (ast, ctx) = compile_sql_sub_test("SELECT t.* FROM t JOIN u ON t.u32 = u.u32")?;
let ast = compile(&ctx, ast).plan;

let plan::Project { input, op: _ } = ast.as_project().unwrap();
let Filter { input, op } = input.as_filter().unwrap();
let (input, _) = ast.as_project().unwrap();
let (input, op) = input.as_filter().unwrap();
assert!(matches!(op, PhysicalExpr::BinOp(_, _, _)));

let CrossJoin { lhs, rhs, ty: _ } = input.as_cross().unwrap();
assert!(matches!(&**lhs, PhysicalPlan::TableScan(_)));
assert!(matches!(&**rhs, PhysicalPlan::TableScan(_)));
let (lhs, rhs) = input.as_nljoin().unwrap();
assert!(matches!(lhs, PhysicalPlan::TableScan(_)));
assert!(matches!(rhs, PhysicalPlan::TableScan(_)));

Ok(())
}
Expand Down
Loading