Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sqld/src/connection/libsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ impl super::Connection for LibSqlConnection {
pgm: Program,
auth: Authenticated,
builder: B,
_replication_index: Option<FrameNo>,
) -> Result<(B, State)> {
check_program_auth(auth, &pgm)?;
let (resp, receiver) = oneshot::channel();
Expand Down Expand Up @@ -568,7 +569,12 @@ impl super::Connection for LibSqlConnection {
Ok(receiver.await??)
}

async fn describe(&self, sql: String, auth: Authenticated) -> Result<DescribeResult> {
async fn describe(
&self,
sql: String,
auth: Authenticated,
_replication_index: Option<FrameNo>,
) -> Result<DescribeResult> {
check_describe_auth(auth)?;
let (resp, receiver) = oneshot::channel();
let cb = Box::new(move |maybe_conn: Result<&mut Connection>| {
Expand Down
37 changes: 30 additions & 7 deletions sqld/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::error::Error;
use crate::query::{Params, Query};
use crate::query_analysis::{State, Statement};
use crate::query_result_builder::{IgnoreResult, QueryResultBuilder};
use crate::replication::FrameNo;
use crate::Result;

use self::program::{Cond, DescribeResult, Program, Step};
Expand All @@ -29,7 +30,8 @@ pub trait Connection: Send + Sync + 'static {
&self,
pgm: Program,
auth: Authenticated,
reponse_builder: B,
response_builder: B,
replication_index: Option<FrameNo>,
) -> Result<(B, State)>;

/// Execute all the queries in the batch sequentially.
Expand All @@ -40,6 +42,7 @@ pub trait Connection: Send + Sync + 'static {
batch: Vec<Query>,
auth: Authenticated,
result_builder: B,
replication_index: Option<FrameNo>,
) -> Result<(B, State)> {
let batch_len = batch.len();
let mut steps = make_batch_program(batch);
Expand All @@ -64,7 +67,9 @@ pub trait Connection: Send + Sync + 'static {

// ignore the rollback result
let builder = result_builder.take(batch_len);
let (builder, state) = self.execute_program(pgm, auth, builder).await?;
let (builder, state) = self
.execute_program(pgm, auth, builder, replication_index)
.await?;

Ok((builder.into_inner(), state))
}
Expand All @@ -76,10 +81,12 @@ pub trait Connection: Send + Sync + 'static {
batch: Vec<Query>,
auth: Authenticated,
result_builder: B,
replication_index: Option<FrameNo>,
) -> Result<(B, State)> {
let steps = make_batch_program(batch);
let pgm = Program::new(steps);
self.execute_program(pgm, auth, result_builder).await
self.execute_program(pgm, auth, result_builder, replication_index)
.await
}

async fn rollback(&self, auth: Authenticated) -> Result<()> {
Expand All @@ -91,14 +98,20 @@ pub trait Connection: Send + Sync + 'static {
}],
auth,
IgnoreResult,
None,
)
.await?;

Ok(())
}

/// Parse the SQL statement and return information about it.
async fn describe(&self, sql: String, auth: Authenticated) -> Result<DescribeResult>;
async fn describe(
&self,
sql: String,
auth: Authenticated,
replication_index: Option<FrameNo>,
) -> Result<DescribeResult>;

/// Check whether the connection is in autocommit mode.
async fn is_autocommit(&self) -> Result<bool>;
Expand Down Expand Up @@ -271,13 +284,21 @@ impl<DB: Connection> Connection for TrackedConnection<DB> {
pgm: Program,
auth: Authenticated,
builder: B,
replication_index: Option<FrameNo>,
) -> crate::Result<(B, State)> {
self.inner.execute_program(pgm, auth, builder).await
self.inner
.execute_program(pgm, auth, builder, replication_index)
.await
}

#[inline]
async fn describe(&self, sql: String, auth: Authenticated) -> crate::Result<DescribeResult> {
self.inner.describe(sql, auth).await
async fn describe(
&self,
sql: String,
auth: Authenticated,
replication_index: Option<FrameNo>,
) -> crate::Result<DescribeResult> {
self.inner.describe(sql, auth, replication_index).await
}

#[inline]
Expand All @@ -304,6 +325,7 @@ mod test {
_pgm: Program,
_auth: Authenticated,
_builder: B,
_replication_index: Option<FrameNo>,
) -> crate::Result<(B, State)> {
unreachable!()
}
Expand All @@ -312,6 +334,7 @@ mod test {
&self,
_sql: String,
_auth: Authenticated,
_replication_index: Option<FrameNo>,
) -> crate::Result<DescribeResult> {
unreachable!()
}
Expand Down
25 changes: 16 additions & 9 deletions sqld/src/connection/write_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,10 @@ impl WriteProxyConnection {
}
}

/// wait for the replicator to have caught up with our current write frame_no
async fn wait_replication_sync(&self) -> Result<()> {
let current_fno = *self.last_write_frame_no.lock();
/// wait for the replicator to have caught up with the replication_index if `Some` or our
/// current write frame_no
async fn wait_replication_sync(&self, replication_index: Option<FrameNo>) -> Result<()> {
let current_fno = replication_index.or_else(|| *self.last_write_frame_no.lock());
match current_fno {
Some(current_frame_no) => {
let mut receiver = self.applied_frame_no_receiver.clone();
Expand All @@ -273,16 +274,17 @@ impl Connection for WriteProxyConnection {
pgm: Program,
auth: Authenticated,
builder: B,
replication_index: Option<FrameNo>,
) -> Result<(B, State)> {
let mut state = self.state.lock().await;
if *state == State::Init && pgm.is_read_only() {
self.wait_replication_sync().await?;
self.wait_replication_sync(replication_index).await?;
// We know that this program won't perform any writes. We attempt to run it on the
// replica. If it leaves an open transaction, then this program is an interactive
// transaction, so we rollback the replica, and execute again on the primary.
let (builder, new_state) = self
.read_conn
.execute_program(pgm.clone(), auth.clone(), builder)
.execute_program(pgm.clone(), auth.clone(), builder, replication_index)
.await?;
if new_state != State::Init {
self.read_conn.rollback(auth.clone()).await?;
Expand All @@ -295,9 +297,14 @@ impl Connection for WriteProxyConnection {
}
}

async fn describe(&self, sql: String, auth: Authenticated) -> Result<DescribeResult> {
self.wait_replication_sync().await?;
self.read_conn.describe(sql, auth).await
async fn describe(
&self,
sql: String,
auth: Authenticated,
replication_index: Option<FrameNo>,
) -> Result<DescribeResult> {
self.wait_replication_sync(replication_index).await?;
self.read_conn.describe(sql, auth, replication_index).await
}

async fn is_autocommit(&self) -> Result<bool> {
Expand All @@ -309,7 +316,7 @@ impl Connection for WriteProxyConnection {
}

async fn checkpoint(&self) -> Result<()> {
self.wait_replication_sync().await?;
self.wait_replication_sync(None).await?;
self.read_conn.checkpoint().await
}
}
Expand Down
7 changes: 5 additions & 2 deletions sqld/src/hrana/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::query_analysis::Statement;
use crate::query_result_builder::{
QueryResultBuilder, QueryResultBuilderError, StepResult, StepResultsBuilder,
};
use crate::replication::FrameNo;

use super::result_builder::HranaBatchProtoBuilder;
use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error};
Expand Down Expand Up @@ -106,10 +107,11 @@ pub async fn execute_batch(
db: &impl Connection,
auth: Authenticated,
pgm: Program,
replication_index: Option<u64>,
) -> Result<proto::BatchResult> {
let batch_builder = HranaBatchProtoBuilder::default();
let (builder, _state) = db
.execute_program(pgm, auth, batch_builder)
.execute_program(pgm, auth, batch_builder, replication_index)
.await
.map_err(catch_batch_error)?;

Expand Down Expand Up @@ -146,10 +148,11 @@ pub async fn execute_sequence(
db: &impl Connection,
auth: Authenticated,
pgm: Program,
replication_index: Option<FrameNo>,
) -> Result<()> {
let builder = StepResultsBuilder::default();
let (builder, _state) = db
.execute_program(pgm, auth, builder)
.execute_program(pgm, auth, builder, replication_index)
.await
.map_err(catch_batch_error)?;
builder
Expand Down
23 changes: 20 additions & 3 deletions sqld/src/hrana/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct OpenReq<C> {
db: Arc<C>,
auth: Authenticated,
pgm: Program,
replication_index: Option<FrameNo>,
}

impl<C> CursorHandle<C> {
Expand All @@ -48,9 +49,20 @@ impl<C> CursorHandle<C> {
}
}

pub fn open(&mut self, db: Arc<C>, auth: Authenticated, pgm: Program) {
pub fn open(
&mut self,
db: Arc<C>,
auth: Authenticated,
pgm: Program,
replication_index: Option<FrameNo>,
) {
let open_tx = self.open_tx.take().unwrap();
let _: Result<_, _> = open_tx.send(OpenReq { db, auth, pgm });
let _: Result<_, _> = open_tx.send(OpenReq {
db,
auth,
pgm,
replication_index,
});
}

pub async fn fetch(&mut self) -> Result<Option<SizedEntry>> {
Expand Down Expand Up @@ -78,7 +90,12 @@ async fn run_cursor<C: Connection>(

if let Err(err) = open_req
.db
.execute_program(open_req.pgm, open_req.auth, result_builder)
.execute_program(
open_req.pgm,
open_req.auth,
result_builder,
open_req.replication_index,
)
.await
{
let entry = match batch::batch_error_from_sqld_error(err) {
Expand Down
2 changes: 1 addition & 1 deletion sqld/src/hrana/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ async fn handle_cursor<C: Connection>(
let db = stream_guard.get_db_owned()?;
let sqls = stream_guard.sqls();
let pgm = batch::proto_batch_to_program(&req_body.batch, sqls, version)?;
cursor_hnd.open(db, auth, pgm);
cursor_hnd.open(db, auth, pgm, req_body.batch.replication_index);

let resp_body = proto::CursorRespBody {
baton: stream_guard.release(),
Expand Down
6 changes: 6 additions & 0 deletions sqld/src/hrana/http/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ pub struct SequenceStreamReq {
#[serde(default)]
#[prost(int32, optional, tag = "2")]
pub sql_id: Option<i32>,
#[serde(default)]
#[prost(uint64, optional, tag = "3")]
pub replication_index: Option<u64>,
}

#[derive(Serialize, prost::Message)]
Expand All @@ -130,6 +133,9 @@ pub struct DescribeStreamReq {
#[serde(default)]
#[prost(int32, optional, tag = "2")]
pub sql_id: Option<i32>,
#[serde(default)]
#[prost(uint64, optional, tag = "3")]
pub replication_index: Option<u64>,
}

#[derive(Serialize, prost::Message)]
Expand Down
8 changes: 4 additions & 4 deletions sqld/src/hrana/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async fn try_handle<D: Connection>(
let sqls = stream_guard.sqls();
let query =
stmt::proto_stmt_to_query(&req.stmt, sqls, version).map_err(catch_stmt_error)?;
let result = stmt::execute_stmt(db, auth, query)
let result = stmt::execute_stmt(db, auth, query, req.stmt.replication_index)
.await
.map_err(catch_stmt_error)?;
proto::StreamResponse::Execute(proto::ExecuteStreamResp { result })
Expand All @@ -73,7 +73,7 @@ async fn try_handle<D: Connection>(
let db = stream_guard.get_db()?;
let sqls = stream_guard.sqls();
let pgm = batch::proto_batch_to_program(&req.batch, sqls, version)?;
let result = batch::execute_batch(db, auth, pgm)
let result = batch::execute_batch(db, auth, pgm, req.batch.replication_index)
.await
.map_err(catch_batch_error)?;
proto::StreamResponse::Batch(proto::BatchStreamResp { result })
Expand All @@ -83,7 +83,7 @@ async fn try_handle<D: Connection>(
let sqls = stream_guard.sqls();
let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, version)?;
let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?;
batch::execute_sequence(db, auth, pgm)
batch::execute_sequence(db, auth, pgm, req.replication_index)
.await
.map_err(catch_stmt_error)
.map_err(catch_batch_error)?;
Expand All @@ -93,7 +93,7 @@ async fn try_handle<D: Connection>(
let db = stream_guard.get_db()?;
let sqls = stream_guard.sqls();
let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, version)?;
let result = stmt::describe_stmt(db, auth, sql.into())
let result = stmt::describe_stmt(db, auth, sql.into(), req.replication_index)
.await
.map_err(catch_stmt_error)?;
proto::StreamResponse::Describe(proto::DescribeStreamResp { result })
Expand Down
6 changes: 6 additions & 0 deletions sqld/src/hrana/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub struct Stmt {
#[serde(default)]
#[prost(bool, optional, tag = "5")]
pub want_rows: Option<bool>,
#[serde(default)]
#[prost(uint64, optional, tag = "6")]
pub replication_index: Option<u64>,
}

#[derive(Deserialize, prost::Message)]
Expand Down Expand Up @@ -73,6 +76,9 @@ pub struct Row {
pub struct Batch {
#[prost(message, repeated, tag = "1")]
pub steps: Vec<BatchStep>,
#[prost(uint64, optional, tag = "2")]
#[serde(default)]
pub replication_index: Option<u64>,
}

#[derive(Deserialize, prost::Message)]
Expand Down
7 changes: 5 additions & 2 deletions sqld/src/hrana/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::hrana;
use crate::query::{Params, Query, Value};
use crate::query_analysis::Statement;
use crate::query_result_builder::{QueryResultBuilder, QueryResultBuilderError};
use crate::replication::FrameNo;

/// An error during execution of an SQL statement.
#[derive(thiserror::Error, Debug)]
Expand Down Expand Up @@ -54,10 +55,11 @@ pub async fn execute_stmt(
db: &impl Connection,
auth: Authenticated,
query: Query,
replication_index: Option<FrameNo>,
) -> Result<proto::StmtResult> {
let builder = SingleStatementBuilder::default();
let (stmt_res, _) = db
.execute_batch(vec![query], auth, builder)
.execute_batch(vec![query], auth, builder, replication_index)
.await
.map_err(catch_stmt_error)?;
stmt_res.into_ret().map_err(catch_stmt_error)
Expand All @@ -67,8 +69,9 @@ pub async fn describe_stmt(
db: &impl Connection,
auth: Authenticated,
sql: String,
replication_index: Option<FrameNo>,
) -> Result<proto::DescribeResult> {
match db.describe(sql, auth).await? {
match db.describe(sql, auth, replication_index).await? {
Ok(describe_response) => Ok(proto_describe_result_from_describe_response(
describe_response,
)),
Expand Down
Loading