Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 122 additions & 61 deletions crates/expr/src/rls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::rc::Rc;
use std::sync::Arc;

use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::identity::{AuthCtx, SqlPermission};
use spacetimedb_primitives::TableId;
use spacetimedb_sql_parser::ast::BinOp;

Expand Down Expand Up @@ -211,6 +212,10 @@ fn resolve_views_for_expr(
suffix: &mut usize,
auth: &AuthCtx,
) -> anyhow::Result<Vec<RelExpr>> {
if auth.bypass_rls() {
return Ok([view].into());
}

let is_return_table = |relvar: &Relvar| return_table_id.is_some_and(|id| relvar.schema.table_id == id);

// Collect the table ids queried by this view.
Expand Down Expand Up @@ -265,7 +270,17 @@ fn resolve_views_for_expr(
ResolveList::new(table_id, resolving.clone()),
has_param,
suffix,
auth,
// Bypass RLS when evaluating the RLS query (but not original query) joins
&AuthCtx::with_permissions(
auth.caller(),
{
let auth = auth.clone();
Arc::new(move |p| match p {
SqlPermission::BypassRLS => true,
_ => auth.has_permission(p),
})
},
),
)?;

// Run alpha conversion on each view definition
Expand Down Expand Up @@ -478,7 +493,7 @@ mod tests {
def::ModuleDef,
schema::{Schema, TableOrViewSchema, TableSchema},
};
use spacetimedb_sql_parser::ast::BinOp;
use spacetimedb_sql_parser::ast::{BinOp, LogOp};

use crate::{
check::{parse_and_type_sub, test_utils::build_module_def, SchemaView},
Expand All @@ -495,6 +510,8 @@ mod tests {
"users" => Some(TableId(0)),
"admins" => Some(TableId(1)),
"player" => Some(TableId(2)),
"secret" => Some(TableId(3)),
"access" => Some(TableId(4)),
_ => None,
}
}
Expand All @@ -504,6 +521,8 @@ mod tests {
0 => Some((TableId(0), "users")),
1 => Some((TableId(1), "admins")),
2 => Some((TableId(2), "player")),
3 => Some((TableId(3), "secret")),
4 => Some((TableId(4), "access")),
_ => None,
}
.and_then(|(table_id, name)| {
Expand All @@ -523,6 +542,8 @@ mod tests {
"select player.* from player join users u on player.id = u.id".into(),
"select player.* from player join admins".into(),
]),
TableId(3) => Ok(vec!["select secret.* from secret join access where access.identity = :sender and access.allowed = true".into()]),
TableId(4) => Ok(vec!["select * from access where false".into()]),
_ => Ok(vec![]),
}
}
Expand All @@ -542,6 +563,14 @@ mod tests {
"player",
ProductType::from([("id", AlgebraicType::U64), ("level_num", AlgebraicType::U64)]),
),
(
"secret",
ProductType::from([("id", AlgebraicType::U64)]),
),
(
"access",
ProductType::from([("identity", AlgebraicType::identity()), ("allowed", AlgebraicType::Bool)]),
),
])
}

Expand Down Expand Up @@ -623,43 +652,29 @@ mod tests {
vec![
ProjectName::Some(
RelExpr::Select(
Box::new(RelExpr::Select(
Box::new(RelExpr::Select(
Box::new(RelExpr::LeftDeepJoin(LeftDeepJoin {
lhs: Box::new(RelExpr::RelVar(Relvar {
schema: player_schema.clone(),
alias: "player".into(),
delta: None,
})),
rhs: Relvar {
schema: users_schema.clone(),
alias: "u_2".into(),
delta: None,
},
})),
Expr::BinOp(
BinOp::Eq,
Box::new(Expr::Field(FieldProject {
table: "u_2".into(),
field: 0,
ty: AlgebraicType::identity(),
})),
Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())),
),
)),
Expr::BinOp(
BinOp::Eq,
Box::new(Expr::Field(FieldProject {
table: "player".into(),
field: 0,
ty: AlgebraicType::U64,
})),
Box::new(Expr::Field(FieldProject {
table: "u_2".into(),
field: 1,
ty: AlgebraicType::U64,
Box::new(RelExpr::EqJoin(
LeftDeepJoin {
lhs: Box::new(RelExpr::RelVar(Relvar {
schema: player_schema.clone(),
alias: "player".into(),
delta: None,
})),
),
rhs: Relvar {
schema: users_schema.clone(),
alias: "u_1".into(),
delta: None,
},
},
FieldProject {
table: "player".into(),
field: 0,
ty: AlgebraicType::U64,
},
FieldProject {
table: "u_1".into(),
field: 1,
ty: AlgebraicType::U64,
},
)),
Expr::BinOp(
BinOp::Eq,
Expand All @@ -675,29 +690,18 @@ mod tests {
),
ProjectName::Some(
RelExpr::Select(
Box::new(RelExpr::Select(
Box::new(RelExpr::LeftDeepJoin(LeftDeepJoin {
lhs: Box::new(RelExpr::RelVar(Relvar {
schema: player_schema.clone(),
alias: "player".into(),
delta: None,
})),
rhs: Relvar {
schema: admins_schema.clone(),
alias: "admins_4".into(),
delta: None,
},
Box::new(RelExpr::LeftDeepJoin(LeftDeepJoin {
lhs: Box::new(RelExpr::RelVar(Relvar {
schema: player_schema.clone(),
alias: "player".into(),
delta: None,
})),
Expr::BinOp(
BinOp::Eq,
Box::new(Expr::Field(FieldProject {
table: "admins_4".into(),
field: 0,
ty: AlgebraicType::identity(),
})),
Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())),
),
)),
rhs: Relvar {
schema: admins_schema.clone(),
alias: "admins_2".into(),
delta: None,
},
})),
Expr::BinOp(
BinOp::Eq,
Box::new(Expr::Field(FieldProject {
Expand All @@ -715,4 +719,61 @@ mod tests {

Ok(())
}

#[test]
fn test_multiple_rls_rules_for_multiple_tables() -> anyhow::Result<()> {
let tx = SchemaViewer(module_def());
let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
let sql = "select * from secret";
let resolved = resolve(sql, &tx, &auth)?;

let secret_schema = tx.schema("secret").unwrap();
let access_schema = tx.schema("access").unwrap();

pretty::assert_eq!(
resolved,
vec![
ProjectName::Some(
RelExpr::Select(
Box::new(RelExpr::LeftDeepJoin(LeftDeepJoin {
lhs: Box::new(RelExpr::RelVar(Relvar {
schema: secret_schema.clone(),
alias: "secret".into(),
delta: None,
})),
rhs: Relvar {
schema: access_schema.clone(),
alias: "access_1".into(),
delta: None,
},
})),
Expr::LogOp(
LogOp::And,
Box::new(Expr::BinOp(
BinOp::Eq,
Box::new(Expr::Field(FieldProject {
table: "access_1".into(),
field: 0,
ty: AlgebraicType::identity(),
})),
Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())),
)),
Box::new(Expr::BinOp(
BinOp::Eq,
Box::new(Expr::Field(FieldProject {
table: "access_1".into(),
field: 1,
ty: AlgebraicType::Bool,
})),
Box::new(Expr::Value(AlgebraicValue::Bool(true), AlgebraicType::Bool)),
)),
),
),
"secret".into(),
),
]
);

Ok(())
}
}