Skip to content

Commit ba3b3f9

Browse files
committed
Teams-based authorization for SQL / subscriptions
1 parent e1cab37 commit ba3b3f9

File tree

24 files changed

+378
-191
lines changed

24 files changed

+378
-191
lines changed

crates/client-api/src/lib.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,20 @@ pub trait Authorization {
490490
database: Identity,
491491
action: Action,
492492
) -> impl Future<Output = Result<(), Unauthorized>> + Send;
493+
494+
/// Obtain an attenuated [AuthCtx] for `subject` to evaluate SQL against
495+
/// `database`.
496+
///
497+
/// "SQL" includes the sql endpoint, pg wire connections, as well as
498+
/// subscription queries.
499+
///
500+
/// If any SQL should be rejected outright, or the authorization database
501+
/// is not available, return `Err(Unauthorized)`.
502+
fn authorize_sql(
503+
&self,
504+
subject: Identity,
505+
database: Identity,
506+
) -> impl Future<Output = Result<AuthCtx, Unauthorized>> + Send;
493507
}
494508

495509
impl<T: Authorization> Authorization for Arc<T> {
@@ -501,6 +515,14 @@ impl<T: Authorization> Authorization for Arc<T> {
501515
) -> impl Future<Output = Result<(), Unauthorized>> + Send {
502516
(**self).authorize_action(subject, database, action)
503517
}
518+
519+
fn authorize_sql(
520+
&self,
521+
subject: Identity,
522+
database: Identity,
523+
) -> impl Future<Output = Result<AuthCtx, Unauthorized>> + Send {
524+
(**self).authorize_sql(subject, database)
525+
}
504526
}
505527

506528
pub fn log_and_500(e: impl std::fmt::Display) -> ErrorResponse {

crates/client-api/src/routes/database.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use spacetimedb::host::module_host::ClientConnectedError;
2727
use spacetimedb::host::ReducerOutcome;
2828
use spacetimedb::host::{FunctionArgs, MigratePlanResult};
2929
use spacetimedb::host::{ReducerCallError, UpdateDatabaseResult};
30-
use spacetimedb::identity::{AuthCtx, Identity};
30+
use spacetimedb::identity::Identity;
3131
use spacetimedb::messages::control_db::{Database, HostType};
3232
use spacetimedb_client_api_messages::http::SqlStmtResult;
3333
use spacetimedb_client_api_messages::name::{
@@ -420,7 +420,7 @@ pub async fn sql_direct<S>(
420420
sql: String,
421421
) -> axum::response::Result<Vec<SqlStmtResult<ProductValue>>>
422422
where
423-
S: NodeDelegate + ControlStateDelegate,
423+
S: NodeDelegate + ControlStateDelegate + Authorization,
424424
{
425425
// Anyone is authorized to execute SQL queries. The SQL engine will determine
426426
// which queries this identity is allowed to execute against the database.
@@ -430,8 +430,9 @@ where
430430
.await?
431431
.ok_or(NO_SUCH_DATABASE)?;
432432

433-
let auth = AuthCtx::new(database.owner_identity, caller_identity);
434-
log::debug!("auth: {auth:?}");
433+
let auth = worker_ctx
434+
.authorize_sql(caller_identity, database.database_identity)
435+
.await?;
435436

436437
let host = worker_ctx
437438
.leader(database.id)
@@ -450,7 +451,7 @@ pub async fn sql<S>(
450451
body: String,
451452
) -> axum::response::Result<impl IntoResponse>
452453
where
453-
S: NodeDelegate + ControlStateDelegate,
454+
S: NodeDelegate + ControlStateDelegate + Authorization,
454455
{
455456
let json = sql_direct(worker_ctx, name_or_identity, params, auth.claims.identity, body).await?;
456457

crates/client-api/src/routes/subscribe.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ use crate::util::websocket::{
4949
CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade, WsError,
5050
};
5151
use crate::util::{NameOrIdentity, XForwardedFor};
52-
use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
52+
use crate::{log_and_500, Authorization, ControlStateDelegate, NodeDelegate};
5353

5454
#[allow(clippy::declare_interior_mutable_const)]
5555
pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::TEXT_PROTOCOL);
@@ -106,7 +106,7 @@ pub async fn handle_websocket<S>(
106106
ws: WebSocketUpgrade,
107107
) -> axum::response::Result<impl IntoResponse>
108108
where
109-
S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions,
109+
S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions + Authorization,
110110
{
111111
if connection_id.is_some() {
112112
// TODO: Bump this up to `log::warn!` after removing the client SDKs' uses of that parameter.
@@ -125,6 +125,7 @@ where
125125
}
126126

127127
let db_identity = name_or_identity.resolve(&ctx).await?;
128+
let sql_auth = ctx.authorize_sql(auth.claims.identity, db_identity).await?;
128129

129130
let (res, ws_upgrade, protocol) =
130131
ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]);
@@ -218,6 +219,7 @@ where
218219
let client = ClientConnection::spawn(
219220
client_id,
220221
auth.into(),
222+
sql_auth,
221223
client_config,
222224
leader.replica_id,
223225
module_rx,

crates/core/src/client/client_connection.rs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use spacetimedb_client_api_messages::websocket::{
2828
UnsubscribeMulti,
2929
};
3030
use spacetimedb_durability::{DurableOffset, TxOffset};
31-
use spacetimedb_lib::identity::RequestId;
31+
use spacetimedb_lib::identity::{AuthCtx, RequestId};
3232
use spacetimedb_lib::metrics::ExecutionMetrics;
3333
use spacetimedb_lib::Identity;
3434
use tokio::sync::mpsc::error::{SendError, TrySendError};
@@ -423,6 +423,7 @@ pub struct ClientConnection {
423423
sender: Arc<ClientConnectionSender>,
424424
pub replica_id: u64,
425425
module_rx: watch::Receiver<ModuleHost>,
426+
auth: AuthCtx,
426427
}
427428

428429
impl Deref for ClientConnection {
@@ -674,9 +675,11 @@ impl ClientConnection {
674675
/// to verify that the database at `module_rx` approves of this connection,
675676
/// and should not invoke this method if that call returns an error,
676677
/// and pass the returned [`Connected`] as `_proof_of_client_connected_call`.
678+
#[allow(clippy::too_many_arguments)]
677679
pub async fn spawn<Fut>(
678680
id: ClientActorId,
679681
auth: ConnectionAuthCtx,
682+
sql_auth: AuthCtx,
680683
config: ClientConfig,
681684
replica_id: u64,
682685
mut module_rx: watch::Receiver<ModuleHost>,
@@ -734,6 +737,7 @@ impl ClientConnection {
734737
sender,
735738
replica_id,
736739
module_rx,
740+
auth: sql_auth,
737741
};
738742

739743
let actor_fut = actor(this.clone(), receiver);
@@ -749,10 +753,12 @@ impl ClientConnection {
749753
replica_id: u64,
750754
module_rx: watch::Receiver<ModuleHost>,
751755
) -> Self {
756+
let auth = AuthCtx::new(module_rx.borrow().database_info().database_identity, id.identity);
752757
Self {
753758
sender: Arc::new(ClientConnectionSender::dummy(id, config, module_rx.clone())),
754759
replica_id,
755760
module_rx,
761+
auth,
756762
}
757763
}
758764

@@ -842,9 +848,13 @@ impl ClientConnection {
842848
let me = self.clone();
843849
self.module()
844850
.on_module_thread("subscribe_single", move || {
845-
me.module()
846-
.subscriptions()
847-
.add_single_subscription(me.sender, subscription, timer, None)
851+
me.module().subscriptions().add_single_subscription(
852+
me.sender,
853+
me.auth.clone(),
854+
subscription,
855+
timer,
856+
None,
857+
)
848858
})
849859
.await?
850860
}
@@ -854,7 +864,7 @@ impl ClientConnection {
854864
asyncify(move || {
855865
me.module()
856866
.subscriptions()
857-
.remove_single_subscription(me.sender, request, timer)
867+
.remove_single_subscription(me.sender, me.auth.clone(), request, timer)
858868
})
859869
.await
860870
}
@@ -869,7 +879,7 @@ impl ClientConnection {
869879
.on_module_thread("subscribe_multi", move || {
870880
me.module()
871881
.subscriptions()
872-
.add_multi_subscription(me.sender, request, timer, None)
882+
.add_multi_subscription(me.sender, me.auth.clone(), request, timer, None)
873883
})
874884
.await?
875885
}
@@ -884,7 +894,7 @@ impl ClientConnection {
884894
.on_module_thread("unsubscribe_multi", move || {
885895
me.module()
886896
.subscriptions()
887-
.remove_multi_subscription(me.sender, request, timer)
897+
.remove_multi_subscription(me.sender, me.auth.clone(), request, timer)
888898
})
889899
.await?
890900
}
@@ -894,7 +904,7 @@ impl ClientConnection {
894904
asyncify(move || {
895905
me.module()
896906
.subscriptions()
897-
.add_legacy_subscriber(me.sender, subscription, timer, None)
907+
.add_legacy_subscriber(me.sender, me.auth.clone(), subscription, timer, None)
898908
})
899909
.await
900910
}
@@ -907,7 +917,7 @@ impl ClientConnection {
907917
) -> Result<(), anyhow::Error> {
908918
self.module()
909919
.one_off_query::<JsonFormat>(
910-
self.id.identity,
920+
self.auth.clone(),
911921
query.to_owned(),
912922
self.sender.clone(),
913923
message_id.to_owned(),
@@ -925,7 +935,7 @@ impl ClientConnection {
925935
) -> Result<(), anyhow::Error> {
926936
self.module()
927937
.one_off_query::<BsatnFormat>(
928-
self.id.identity,
938+
self.auth.clone(),
929939
query.to_owned(),
930940
self.sender.clone(),
931941
message_id.to_owned(),

crates/core/src/db/relational_db.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,10 @@ impl RelationalDB {
624624
self.database_identity
625625
}
626626

627+
pub fn owner_identity(&self) -> Identity {
628+
self.owner_identity
629+
}
630+
627631
/// The number of bytes on disk occupied by the durability layer.
628632
///
629633
/// If this is an in-memory instance, `Ok(0)` is returned.

crates/core/src/host/host_controller.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,7 @@ async fn make_replica_ctx(
545545
send_worker_queue.clone(),
546546
)));
547547
let downgraded = Arc::downgrade(&subscriptions);
548-
let subscriptions = ModuleSubscriptions::new(
549-
relational_db.clone(),
550-
subscriptions,
551-
send_worker_queue,
552-
database.owner_identity,
553-
);
548+
let subscriptions = ModuleSubscriptions::new(relational_db.clone(), subscriptions, send_worker_queue);
554549

555550
// If an error occurs when evaluating a subscription,
556551
// we mark each client that was affected,

crates/core/src/host/module_host.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,7 +1329,7 @@ impl ModuleHost {
13291329
#[tracing::instrument(level = "trace", skip_all)]
13301330
pub async fn one_off_query<F: BuildableWebsocketFormat>(
13311331
&self,
1332-
caller_identity: Identity,
1332+
auth: AuthCtx,
13331333
query: String,
13341334
client: Arc<ClientConnectionSender>,
13351335
message_id: Vec<u8>,
@@ -1340,7 +1340,6 @@ impl ModuleHost {
13401340
let replica_ctx = self.replica_ctx();
13411341
let db = replica_ctx.relational_db.clone();
13421342
let subscriptions = replica_ctx.subscriptions.clone();
1343-
let auth = AuthCtx::new(replica_ctx.owner_identity, caller_identity);
13441343
log::debug!("One-off query: {query}");
13451344
let metrics = self
13461345
.on_module_thread("one_off_query", move || {

crates/core/src/sql/ast.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
66
use spacetimedb_datastore::system_tables::{StRowLevelSecurityFields, ST_ROW_LEVEL_SECURITY_ID};
77
use spacetimedb_expr::check::SchemaView;
88
use spacetimedb_expr::statement::compile_sql_stmt;
9-
use spacetimedb_lib::db::auth::StAccess;
109
use spacetimedb_lib::identity::AuthCtx;
1110
use spacetimedb_primitives::{ColId, TableId};
1211
use spacetimedb_sats::{AlgebraicType, AlgebraicValue};
@@ -492,22 +491,20 @@ impl<T> Deref for SchemaViewer<'_, T> {
492491

493492
impl<T: StateView> SchemaView for SchemaViewer<'_, T> {
494493
fn table_id(&self, name: &str) -> Option<TableId> {
495-
let AuthCtx { owner, caller } = self.auth;
496494
// Get the schema from the in-memory state instead of fetching from the database for speed
497495
self.tx
498496
.table_id_from_name(name)
499497
.ok()
500498
.flatten()
501499
.and_then(|table_id| self.schema_for_table(table_id))
502-
.filter(|schema| schema.table_access == StAccess::Public || caller == owner)
500+
.filter(|schema| self.auth.has_read_access(schema.table_access))
503501
.map(|schema| schema.table_id)
504502
}
505503

506504
fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
507-
let AuthCtx { owner, caller } = self.auth;
508505
self.tx
509506
.get_schema(table_id)
510-
.filter(|schema| schema.table_access == StAccess::Public || caller == owner)
507+
.filter(|schema| self.auth.has_read_access(schema.table_access))
511508
.cloned()
512509
}
513510

crates/core/src/sql/execute.rs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,15 @@ pub fn execute_sql(
122122
let mut tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql);
123123
let mut updates = Vec::with_capacity(ast.len());
124124
let res = execute(
125-
&mut DbProgram::new(db, &mut (&mut tx).into(), auth),
125+
&mut DbProgram::new(db, &mut (&mut tx).into(), auth.clone()),
126126
ast,
127127
sql,
128128
&mut updates,
129129
);
130130
if res.is_ok() && !updates.is_empty() {
131131
let event = ModuleEvent {
132132
timestamp: Timestamp::now(),
133-
caller_identity: auth.caller,
133+
caller_identity: auth.caller(),
134134
caller_connection_id: None,
135135
function_call: ModuleFunctionCall {
136136
reducer: String::new(),
@@ -249,7 +249,7 @@ pub fn run(
249249
}
250250
Statement::DML(stmt) => {
251251
// An extra layer of auth is required for DML
252-
if auth.caller != auth.owner {
252+
if !auth.has_write_access() {
253253
return Err(anyhow!("Only owners are authorized to run SQL DML statements").into());
254254
}
255255

@@ -287,7 +287,7 @@ pub fn run(
287287
None,
288288
ModuleEvent {
289289
timestamp: Timestamp::now(),
290-
caller_identity: auth.caller,
290+
caller_identity: auth.caller(),
291291
caller_connection_id: None,
292292
function_call: ModuleFunctionCall {
293293
reducer: String::new(),
@@ -510,7 +510,7 @@ pub(crate) mod tests {
510510
expected: impl IntoIterator<Item = ProductValue>,
511511
) {
512512
assert_eq!(
513-
run(db, sql, *auth, None, &mut vec![])
513+
run(db, sql, auth.clone(), None, &mut vec![])
514514
.unwrap()
515515
.rows
516516
.into_iter()
@@ -1270,19 +1270,25 @@ pub(crate) mod tests {
12701270
let run = |db, sql, auth, subs| run(db, sql, auth, subs, &mut vec![]);
12711271

12721272
// No row limit, both queries pass.
1273-
assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
1274-
assert!(run(&db, "SELECT * FROM T", external_auth, None).is_ok());
1273+
assert!(run(&db, "SELECT * FROM T", internal_auth.clone(), None).is_ok());
1274+
assert!(run(&db, "SELECT * FROM T", external_auth.clone(), None).is_ok());
12751275

12761276
// Set row limit.
1277-
assert!(run(&db, "SET row_limit = 4", internal_auth, None).is_ok());
1277+
assert!(run(&db, "SET row_limit = 4", internal_auth.clone(), None).is_ok());
12781278

12791279
// External query fails.
1280-
assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
1281-
assert!(run(&db, "SELECT * FROM T", external_auth, None).is_err());
1280+
assert!(run(&db, "SELECT * FROM T", internal_auth.clone(), None).is_ok());
1281+
assert!(run(&db, "SELECT * FROM T", external_auth.clone(), None).is_err());
12821282

12831283
// Increase row limit.
1284-
assert!(run(&db, "DELETE FROM st_var WHERE name = 'row_limit'", internal_auth, None).is_ok());
1285-
assert!(run(&db, "SET row_limit = 5", internal_auth, None).is_ok());
1284+
assert!(run(
1285+
&db,
1286+
"DELETE FROM st_var WHERE name = 'row_limit'",
1287+
internal_auth.clone(),
1288+
None
1289+
)
1290+
.is_ok());
1291+
assert!(run(&db, "SET row_limit = 5", internal_auth.clone(), None).is_ok());
12861292

12871293
// Both queries pass.
12881294
assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
@@ -1333,10 +1339,10 @@ pub(crate) mod tests {
13331339
..ExecutionMetrics::default()
13341340
};
13351341

1336-
check(&db, "INSERT INTO T (a) VALUES (5)", internal_auth, ins)?;
1337-
check(&db, "UPDATE T SET a = 2", internal_auth, upd)?;
1342+
check(&db, "INSERT INTO T (a) VALUES (5)", internal_auth.clone(), ins)?;
1343+
check(&db, "UPDATE T SET a = 2", internal_auth.clone(), upd)?;
13381344
assert_eq!(
1339-
run(&db, "SELECT * FROM T", internal_auth, None)?.rows,
1345+
run(&db, "SELECT * FROM T", internal_auth.clone(), None)?.rows,
13401346
vec![product!(2u8)]
13411347
);
13421348
check(&db, "DELETE FROM T", internal_auth, del)?;

0 commit comments

Comments
 (0)