diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index 1286c9e14a03..848e708c7dbb 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,10 +1,10 @@ mod conversion; mod error; -use super::{IsolationLevel, TransactionOptions}; +use super::{IsolationLevel, Transaction, TransactionOptions}; use crate::{ ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet, Transaction}, + connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, error::{Error, ErrorKind}, visitor::{self, Visitor}, }; @@ -96,7 +96,10 @@ static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommit #[async_trait] impl TransactionCapable for Mssql { - async fn start_transaction(&self, isolation: Option) -> crate::Result> { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> crate::Result> { // Isolation levels in SQL Server are set on the connection and live until they're changed. // Always explicitly setting the isolation level each time a tx is started (either to the given value // or by using the default/connection string value) prevents transactions started on connections from @@ -107,7 +110,9 @@ impl TransactionCapable for Mssql { let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); - Transaction::new(self, self.begin_statement(), opts).await + Ok(Box::new( + DefaultTransaction::new(self, self.begin_statement(), opts).await?, + )) } } diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index ce116dd2df6f..e4be7b47c404 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -472,7 +472,7 @@ impl Mysql { } } -impl TransactionCapable for Mysql {} +impl_default_TransactionCapable!(Mysql); #[async_trait] impl Queryable for Mysql { diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index efa414dc2461..d4dc008bd5f9 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -3,7 +3,7 @@ mod error; use crate::{ ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet, Transaction}, + connector::{metrics, queryable::*, ResultSet}, error::{Error, ErrorKind}, visitor::{self, Visitor}, }; @@ -34,7 +34,7 @@ pub(crate) const DEFAULT_SCHEMA: &str = "public"; #[cfg(feature = "expose-drivers")] pub use tokio_postgres; -use super::IsolationLevel; +use super::{IsolationLevel, Transaction}; #[derive(Clone)] struct Hidden(T); @@ -765,7 +765,7 @@ impl Display for SetSearchPath<'_> { } } -impl TransactionCapable for PostgreSql {} +impl_default_TransactionCapable!(PostgreSql); #[async_trait] impl Queryable for PostgreSql { @@ -912,7 +912,7 @@ impl Queryable for PostgreSql { self.is_healthy.load(Ordering::SeqCst) } - async fn server_reset_query(&self, tx: &Transaction<'_>) -> crate::Result<()> { + async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { if self.pg_bouncer { tx.raw_cmd("DEALLOCATE ALL").await } else { diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 24735890dfc4..09dbc7abba4c 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -1,4 +1,4 @@ -use super::{IsolationLevel, ResultSet, Transaction, TransactionOptions}; +use super::{IsolationLevel, ResultSet, Transaction}; use crate::ast::*; use async_trait::async_trait; @@ -82,7 +82,7 @@ pub trait Queryable: Send + Sync { } /// Execute an arbitrary function in the beginning of each transaction. - async fn server_reset_query(&self, _: &Transaction<'_>) -> crate::Result<()> { + async fn server_reset_query(&self, _: &dyn Transaction) -> crate::Result<()> { Ok(()) } @@ -101,13 +101,30 @@ pub trait Queryable: Send + Sync { /// A thing that can start a new transaction. #[async_trait] -pub trait TransactionCapable: Queryable -where - Self: Sized, -{ +pub trait TransactionCapable: Queryable { /// Starts a new transaction - async fn start_transaction(&self, isolation: Option) -> crate::Result> { - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); - Transaction::new(self, self.begin_statement(), opts).await - } + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> crate::Result>; +} + +macro_rules! impl_default_TransactionCapable { + ($t:ty) => { + #[async_trait] + impl TransactionCapable for $t { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> crate::Result> { + let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); + + Ok(Box::new( + crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + )) + } + } + }; } + +pub(crate) use impl_default_TransactionCapable; diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index ea81c51453e4..da85697a5936 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -167,7 +167,7 @@ impl Sqlite { } } -impl TransactionCapable for Sqlite {} +impl_default_TransactionCapable!(Sqlite); #[async_trait] impl Queryable for Sqlite { diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index 94302f98d472..b7e91e97f6a8 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -9,6 +9,18 @@ use std::{fmt, str::FromStr}; extern crate metrics as metrics; +#[async_trait] +pub trait Transaction: Queryable { + /// Commit the changes to the database and consume the transaction. + async fn commit(&self) -> crate::Result<()>; + + /// Rolls back the changes to the database. + async fn rollback(&self) -> crate::Result<()>; + + /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 + fn as_queryable(&self) -> &dyn Queryable; +} + pub(crate) struct TransactionOptions { /// The isolation level to use. pub(crate) isolation_level: Option, @@ -17,21 +29,21 @@ pub(crate) struct TransactionOptions { pub(crate) isolation_first: bool, } -/// A representation of an SQL database transaction. If not commited, a +/// A default representation of an SQL database transaction. If not commited, a /// transaction will be rolled back by default when dropped. /// /// Currently does not support nesting, so starting a new transaction using the /// transaction object will panic. -pub struct Transaction<'a> { - pub(crate) inner: &'a dyn Queryable, +pub struct DefaultTransaction<'a> { + pub inner: &'a dyn Queryable, } -impl<'a> Transaction<'a> { +impl<'a> DefaultTransaction<'a> { pub(crate) async fn new( inner: &'a dyn Queryable, begin_stmt: &str, tx_opts: TransactionOptions, - ) -> crate::Result> { + ) -> crate::Result> { let this = Self { inner }; if tx_opts.isolation_first { @@ -53,9 +65,12 @@ impl<'a> Transaction<'a> { increment_gauge!("prisma_client_queries_active", 1.0); Ok(this) } +} +#[async_trait] +impl<'a> Transaction for DefaultTransaction<'a> { /// Commit the changes to the database and consume the transaction. - pub async fn commit(&self) -> crate::Result<()> { + async fn commit(&self) -> crate::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); self.inner.raw_cmd("COMMIT").await?; @@ -63,16 +78,20 @@ impl<'a> Transaction<'a> { } /// Rolls back the changes to the database. - pub async fn rollback(&self) -> crate::Result<()> { + async fn rollback(&self) -> crate::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); self.inner.raw_cmd("ROLLBACK").await?; Ok(()) } + + fn as_queryable(&self) -> &dyn Queryable { + self + } } #[async_trait] -impl<'a> Queryable for Transaction<'a> { +impl<'a> Queryable for DefaultTransaction<'a> { async fn query(&self, q: Query<'_>) -> crate::Result { self.inner.query(q).await } @@ -171,7 +190,7 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub(crate) fn new(isolation_level: Option, isolation_first: bool) -> Self { + pub fn new(isolation_level: Option, isolation_first: bool) -> Self { Self { isolation_level, isolation_first, diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 751bad564455..c0aa8c93b75d 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -6,7 +6,7 @@ use crate::connector::MysqlUrl; use crate::connector::PostgresUrl; use crate::{ ast, - connector::{self, IsolationLevel, Queryable, Transaction, TransactionCapable}, + connector::{self, impl_default_TransactionCapable, IsolationLevel, Queryable, Transaction, TransactionCapable}, error::Error, }; use async_trait::async_trait; @@ -18,7 +18,7 @@ pub struct PooledConnection { pub(crate) inner: MobcPooled, } -impl TransactionCapable for PooledConnection {} +impl_default_TransactionCapable!(PooledConnection); #[async_trait] impl Queryable for PooledConnection { @@ -58,7 +58,7 @@ impl Queryable for PooledConnection { self.inner.is_healthy() } - async fn server_reset_query(&self, tx: &Transaction<'_>) -> crate::Result<()> { + async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { self.inner.server_reset_query(tx).await } diff --git a/quaint/src/prelude.rs b/quaint/src/prelude.rs index bd5a180c3804..58e87ad8c1d5 100644 --- a/quaint/src/prelude.rs +++ b/quaint/src/prelude.rs @@ -1,6 +1,6 @@ //! A "prelude" for users of the `quaint` crate. pub use crate::ast::*; pub use crate::connector::{ - ConnectionInfo, Queryable, ResultRow, ResultSet, SqlFamily, Transaction, TransactionCapable, + ConnectionInfo, DefaultTransaction, Queryable, ResultRow, ResultSet, SqlFamily, TransactionCapable, }; pub use crate::{col, val, values}; diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 3cf96de3e1a9..3dcb6eb86a33 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -4,7 +4,7 @@ use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; use crate::{ ast, - connector::{self, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, + connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; use std::{fmt, sync::Arc}; @@ -25,7 +25,7 @@ impl fmt::Debug for Quaint { } } -impl TransactionCapable for Quaint {} +impl_default_TransactionCapable!(Quaint); impl Quaint { /// Create a new connection to the database. The connection string diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index 7d0417d64c40..0247e8c4b601 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -62,7 +62,7 @@ where let fut_tx = self.inner.start_transaction(isolation_level); catch(self.connection_info.clone(), async move { - let tx: quaint::connector::Transaction = fut_tx.await.map_err(SqlError::from)?; + let tx = fut_tx.await.map_err(SqlError::from)?; Ok(Box::new(SqlConnectorTransaction::new(tx, connection_info, features)) as Box) }) diff --git a/query-engine/connectors/sql-query-connector/src/database/js.rs b/query-engine/connectors/sql-query-connector/src/database/js.rs index 284d4a48cd5e..ed4415323cc6 100644 --- a/query-engine/connectors/sql-query-connector/src/database/js.rs +++ b/query-engine/connectors/sql-query-connector/src/database/js.rs @@ -8,7 +8,7 @@ use connector_interface::{ }; use once_cell::sync::Lazy; use quaint::{ - connector::IsolationLevel, + connector::{IsolationLevel, Transaction}, prelude::{Queryable as QuaintQueryable, *}, }; use std::{ @@ -32,7 +32,7 @@ fn registered_js_connector(provider: &str) -> connector::Result { .map(|conn_ref| conn_ref.to_owned()) } -pub fn register_js_connector(provider: &str, connector: Arc) -> Result<(), String> { +pub fn register_js_connector(provider: &str, connector: Arc) -> Result<(), String> { let mut lock = REGISTRY.lock().unwrap(); let entry = lock.entry(provider.to_string()); match entry { @@ -128,7 +128,7 @@ impl Connector for Js { // in this object, and implementing TransactionCapable (and quaint::Queryable) explicitly for it. #[derive(Clone)] struct JsConnector { - connector: Arc, + connector: Arc, } #[async_trait] @@ -183,4 +183,12 @@ impl QuaintQueryable for JsConnector { } } -impl TransactionCapable for JsConnector {} +#[async_trait] +impl TransactionCapable for JsConnector { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> quaint::Result> { + self.connector.start_transaction(isolation).await + } +} diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index 75eac9beca56..517c293457f1 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -12,14 +12,14 @@ use quaint::prelude::ConnectionInfo; use std::collections::HashMap; pub struct SqlConnectorTransaction<'tx> { - inner: quaint::connector::Transaction<'tx>, + inner: Box, connection_info: ConnectionInfo, features: psl::PreviewFeatures, } impl<'tx> SqlConnectorTransaction<'tx> { pub fn new( - tx: quaint::connector::Transaction<'tx>, + tx: Box, connection_info: &ConnectionInfo, features: psl::PreviewFeatures, ) -> Self { @@ -74,7 +74,7 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); read::get_single_record( - &self.inner, + self.inner.as_queryable(), model, filter, &selected_fields.into(), @@ -97,7 +97,7 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); read::get_many_records( - &self.inner, + self.inner.as_queryable(), model, query_arguments, &selected_fields.into(), @@ -117,7 +117,7 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { ) -> connector::Result> { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - read::get_related_m2m_record_ids(&self.inner, from_field, from_record_ids, &ctx).await + read::get_related_m2m_record_ids(self.inner.as_queryable(), from_field, from_record_ids, &ctx).await }) .await } @@ -133,7 +133,16 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { ) -> connector::Result> { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - read::aggregate(&self.inner, model, query_arguments, selections, group_by, having, &ctx).await + read::aggregate( + self.inner.as_queryable(), + model, + query_arguments, + selections, + group_by, + having, + &ctx, + ) + .await }) .await } @@ -151,7 +160,7 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); write::create_record( - &self.inner, + self.inner.as_queryable(), &self.connection_info.sql_family(), model, args, @@ -172,7 +181,7 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { ) -> connector::Result { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - write::create_records(&self.inner, model, args, skip_duplicates, &ctx).await + write::create_records(self.inner.as_queryable(), model, args, skip_duplicates, &ctx).await }) .await } @@ -186,7 +195,7 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { ) -> connector::Result { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - write::update_records(&self.inner, model, record_filter, args, &ctx).await + write::update_records(self.inner.as_queryable(), model, record_filter, args, &ctx).await }) .await } @@ -202,7 +211,15 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - write::update_record(&self.inner, model, record_filter, args, selected_fields, &ctx).await + write::update_record( + self.inner.as_queryable(), + model, + record_filter, + args, + selected_fields, + &ctx, + ) + .await }) .await } @@ -215,7 +232,7 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { ) -> connector::Result { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - write::delete_records(&self.inner, model, record_filter, &ctx).await + write::delete_records(self.inner.as_queryable(), model, record_filter, &ctx).await }) .await } @@ -227,7 +244,7 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { ) -> connector::Result { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - upsert::native_upsert(&self.inner, upsert, &ctx).await + upsert::native_upsert(self.inner.as_queryable(), upsert, &ctx).await }) .await } @@ -241,7 +258,7 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { ) -> connector::Result<()> { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - write::m2m_connect(&self.inner, field, parent_id, child_ids, &ctx).await + write::m2m_connect(self.inner.as_queryable(), field, parent_id, child_ids, &ctx).await }) .await } @@ -255,14 +272,14 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { ) -> connector::Result<()> { catch(self.connection_info.clone(), async move { let ctx = Context::new(&self.connection_info, trace_id.as_deref()); - write::m2m_disconnect(&self.inner, field, parent_id, child_ids, &ctx).await + write::m2m_disconnect(self.inner.as_queryable(), field, parent_id, child_ids, &ctx).await }) .await } async fn execute_raw(&mut self, inputs: HashMap) -> connector::Result { catch(self.connection_info.clone(), async move { - write::execute_raw(&self.inner, self.features, inputs).await + write::execute_raw(self.inner.as_queryable(), self.features, inputs).await }) .await } @@ -274,7 +291,7 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { _query_type: Option, ) -> connector::Result { catch(self.connection_info.clone(), async move { - write::query_raw(&self.inner, inputs).await + write::query_raw(self.inner.as_queryable(), inputs).await }) .await } diff --git a/query-engine/js-connectors/js/js-connector-utils/src/binder.ts b/query-engine/js-connectors/js/js-connector-utils/src/binder.ts index 7a91cb9430db..1f64f3e98700 100644 --- a/query-engine/js-connectors/js/js-connector-utils/src/binder.ts +++ b/query-engine/js-connectors/js/js-connector-utils/src/binder.ts @@ -1,13 +1,23 @@ -import type { Connector } from './types'; +import type { Connector, Transaction } from './types'; -// *.bind(db) is required to preserve the `this` context. -// There are surely other ways than this to use class methods defined in JS within a -// driver context, but this is the most straightforward. -export const binder = (queryable: Connector): Connector => ({ - queryRaw: queryable.queryRaw.bind(queryable), - executeRaw: queryable.executeRaw.bind(queryable), - version: queryable.version.bind(queryable), - isHealthy: queryable.isHealthy.bind(queryable), - close: queryable.close.bind(queryable), - flavour: queryable.flavour, +// *.bind(connector) is required to preserve the `this` context of functions whose +// execution is delegated to napi.rs. +export const bindConnector = (connector: Connector): Connector => ({ + queryRaw: connector.queryRaw.bind(connector), + executeRaw: connector.executeRaw.bind(connector), + flavour: connector.flavour, + startTransaction: connector.startTransaction.bind(connector), + close: connector.close.bind(connector) }) + +// *.bind(transaction) is required to preserve the `this` context of functions whose +// execution is delegated to napi.rs. +export const bindTransaction = (transaction: Transaction): Transaction => { + return ({ + flavour: transaction.flavour, + queryRaw: transaction.queryRaw.bind(transaction), + executeRaw: transaction.executeRaw.bind(transaction), + commit: transaction.commit.bind(transaction), + rollback: transaction.rollback.bind(transaction) + }); +} \ No newline at end of file diff --git a/query-engine/js-connectors/js/js-connector-utils/src/const.ts b/query-engine/js-connectors/js/js-connector-utils/src/const.ts index f7950c5a760e..48fef04539a8 100644 --- a/query-engine/js-connectors/js/js-connector-utils/src/const.ts +++ b/query-engine/js-connectors/js/js-connector-utils/src/const.ts @@ -21,20 +21,3 @@ export const ColumnTypeEnum = { // 'Array': 15, // ... } as const - -export const connectionHealthErrorCodes = [ - // Unable to resolve the domain name to an IP address. - 'ENOTFOUND', - - // Failed to get a response from the DNS server. - 'EAI_AGAIN', - - // The connection was refused by the database server. - 'ECONNREFUSED', - - // The connection attempt timed out. - 'ETIMEDOUT', - - // The connection was unexpectedly closed by the database server. - 'ECONNRESET', -] as const diff --git a/query-engine/js-connectors/js/js-connector-utils/src/index.ts b/query-engine/js-connectors/js/js-connector-utils/src/index.ts index 69a062dc3122..921411d50987 100644 --- a/query-engine/js-connectors/js/js-connector-utils/src/index.ts +++ b/query-engine/js-connectors/js/js-connector-utils/src/index.ts @@ -1,5 +1,4 @@ -export { binder } from './binder' +export { bindConnector, bindTransaction } from './binder' export { ColumnTypeEnum } from './const' export { Debug } from './debug' export type * from './types' -export { isConnectionUnhealthy } from './util' diff --git a/query-engine/js-connectors/js/js-connector-utils/src/types.ts b/query-engine/js-connectors/js/js-connector-utils/src/types.ts index 0d3f021d8684..196ef2f7dbec 100644 --- a/query-engine/js-connectors/js/js-connector-utils/src/types.ts +++ b/query-engine/js-connectors/js/js-connector-utils/src/types.ts @@ -32,8 +32,8 @@ export type Query = { args: Array } -export type Connector = { - readonly flavour: 'mysql' | 'postgres', +export interface Queryable { + readonly flavour: 'mysql' | 'postgres' /** * Execute a query given as SQL, interpolating the given parameters, @@ -41,7 +41,7 @@ export type Connector = { * * This is the preferred way of executing `SELECT` queries. */ - queryRaw: (params: Query) => Promise + queryRaw(params: Query): Promise /** * Execute a query given as SQL, interpolating the given parameters, @@ -50,23 +50,31 @@ export type Connector = { * This is the preferred way of executing `INSERT`, `UPDATE`, `DELETE` queries, * as well as transactional queries. */ - executeRaw: (params: Query) => Promise + executeRaw(params: Query): Promise +} +export interface Connector extends Queryable { /** - * Return the version of the underlying database, queried directly from the - * source. + * Starts new transation with the specified isolation level + * @param isolationLevel */ - version: () => Promise + startTransaction(isolationLevel?: string): Promise /** - * Returns true, if connection is considered to be in a working state. + * Closes the connection to the database, if any. */ - isHealthy: () => boolean + close: () => Promise +} +export interface Transaction extends Queryable { /** - * Closes the connection to the database, if any. + * Commit the transaction */ - close: () => Promise + commit(): Promise + /** + * Rolls back the transaction. + */ + rollback(): Promise } /** diff --git a/query-engine/js-connectors/js/js-connector-utils/src/util.ts b/query-engine/js-connectors/js/js-connector-utils/src/util.ts deleted file mode 100644 index bdd7005aaa27..000000000000 --- a/query-engine/js-connectors/js/js-connector-utils/src/util.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { connectionHealthErrorCodes } from './const' - -type ConnectionHealthErrorCode = typeof connectionHealthErrorCodes[number] - -export function isConnectionUnhealthy(errorCode: E | ConnectionHealthErrorCode): errorCode is ConnectionHealthErrorCode { - // Note: `Array.includes` is too narrow, see https://github.com/microsoft/TypeScript/issues/26255. - return (connectionHealthErrorCodes as readonly string[]).includes(errorCode) -} diff --git a/query-engine/js-connectors/js/neon-js-connector/package.json b/query-engine/js-connectors/js/neon-js-connector/package.json index 573ef787c466..6fb1aaa5c030 100644 --- a/query-engine/js-connectors/js/neon-js-connector/package.json +++ b/query-engine/js-connectors/js/neon-js-connector/package.json @@ -19,7 +19,7 @@ "sideEffects": false, "dependencies": { "@jkomyno/prisma-js-connector-utils": "workspace:*", - "@neondatabase/serverless": "^0.5.6", + "@neondatabase/serverless": "^0.6.0", "ws": "^8.13.0" } } diff --git a/query-engine/js-connectors/js/neon-js-connector/src/neon.ts b/query-engine/js-connectors/js/neon-js-connector/src/neon.ts index 1a4e4519c1f0..1582e25167e8 100644 --- a/query-engine/js-connectors/js/neon-js-connector/src/neon.ts +++ b/query-engine/js-connectors/js/neon-js-connector/src/neon.ts @@ -1,8 +1,8 @@ -import { Client, neon, neonConfig } from '@neondatabase/serverless' -import type { NeonConfig, NeonQueryFunction } from '@neondatabase/serverless' +import { FullQueryResults, PoolClient, neon, neonConfig } from '@neondatabase/serverless' +import { NeonConfig, NeonQueryFunction, Pool, QueryResult } from '@neondatabase/serverless' import ws from 'ws' -import { binder, isConnectionUnhealthy, Debug } from '@jkomyno/prisma-js-connector-utils' -import type { Connector, ResultSet, Query, ConnectorConfig } from '@jkomyno/prisma-js-connector-utils' +import { bindConnector, bindTransaction, Debug } from '@jkomyno/prisma-js-connector-utils' +import type { Connector, ResultSet, Query, ConnectorConfig, Queryable, Transaction } from '@jkomyno/prisma-js-connector-utils' import { fieldToColumnType } from './conversion' neonConfig.webSocketConstructor = ws @@ -11,82 +11,17 @@ const debug = Debug('prisma:js-connector:neon') export type PrismaNeonConfig = ConnectorConfig & Partial> & { httpMode?: boolean } -const TRANSACTION_BEGIN = 'BEGIN' -const TRANSACTION_COMMIT = 'COMMIT' -const TRANSACTION_ROLLBACK = 'ROLLBACK' - type ARRAY_MODE_DISABLED = false type FULL_RESULTS_ENABLED = true -type ModeSpecificDriver - = { - /** - * Indicates that we're using the HTTP mode. - */ - mode: 'http' - - /** - * The Neon HTTP client, without transaction support. - */ - client: NeonQueryFunction - } - | { - /** - * Indicates that we're using the WebSocket mode. - */ - mode: 'ws' - - /** - * The standard Neon client, with transaction support. - */ - client: Client - } - -class PrismaNeon implements Connector { - readonly flavour = 'postgres' - - private driver: ModeSpecificDriver - private isRunning: boolean = true - private inTransaction: boolean = false - private _isHealthy: boolean = true - private _version: string | undefined = undefined - - constructor(config: PrismaNeonConfig) { - const { url: connectionString, httpMode, ...rest } = config - if (!httpMode) { - this.driver = { - mode: 'ws', - client: new Client({ connectionString, ...rest }) - } - // connect the client in the background, all requests will be queued until connection established - this.driver.client.connect() - } else { - this.driver = { - mode: 'http', - client: neon(connectionString, { fullResults: true, ...rest }) - } - } - } +type PerformIOResult = QueryResult | FullQueryResults - async close(): Promise { - if (this.isRunning) { - if (this.driver.mode === 'ws') { - await this.driver.client.end() - } - this.isRunning = false - } - } +/** + * Base class for http client, ws client and ws transaction + */ +abstract class NeonQueryable implements Queryable { + flavour = 'postgres' as const - /** - * Returns true, if connection is considered to be in a working state. - */ - isHealthy(): boolean { - return this.isRunning && this._isHealthy - } - - /** - * Execute a query given as SQL, interpolating the given parameters. - */ async queryRaw(query: Query): Promise { const tag = '[js::query_raw]' debug(`${tag} %O`, query) @@ -103,91 +38,99 @@ class PrismaNeon implements Connector { return resultSet } - /** - * Execute a query given as SQL, interpolating the given parameters and - * returning the number of affected rows. - * Note: Queryable expects a u64, but napi.rs only supports u32. - */ async executeRaw(query: Query): Promise { const tag = '[js::execute_raw]' debug(`${tag} %O`, query) - switch (query.sql) { - case TRANSACTION_BEGIN: { - if (this.driver.mode === 'http') { - throw new Error('Transactions are not supported in HTTP mode') - } - - // check if a transaction is already in progress - if (this.inTransaction) { - throw new Error('A transaction is already in progress') - } - - this.inTransaction = true - debug(`${tag} transaction began`) - - return Promise.resolve(-1) - } - case TRANSACTION_COMMIT: { - this.inTransaction = false - debug(`${tag} transaction ended successfully`) - return Promise.resolve(-1) - } - case TRANSACTION_ROLLBACK: { - this.inTransaction = false - debug(`${tag} transaction ended with error`) - return Promise.reject(query.sql) - } - default: { - const { rowCount: rowsAffected } = await this.performIO(query) - return rowsAffected - } - } + const { rowCount: rowsAffected } = await this.performIO(query) + return rowsAffected } - /** - * Return the version of the underlying database, queried directly from the - * source. This corresponds to the `version()` function on PostgreSQL for - * example. The version string is returned directly without any form of - * parsing or normalization. - */ - async version(): Promise { - if (this._version) { - return Promise.resolve(this._version) - } + abstract performIO(query: Query): Promise +} - const { rows } = await this.performIO({ sql: 'SELECT VERSION()', args: [] }) - this._version = rows[0]['version'] as string - return this._version +/** + * Base class for WS-based queryables: top-level client and transaction + */ +class NeonWsQueryable extends NeonQueryable { + constructor(protected client: ClientT) { + super() } - /** - * Run a query against the database, returning the result set. - * Should the query fail due to a connection error, the connection is - * marked as unhealthy. - */ - private async performIO(query: Query) { + override performIO(query: Query): Promise { const { sql, args: values } = query + return this.client.query(sql, values) + } +} +class NeonTransaction extends NeonWsQueryable implements Transaction { + async commit(): Promise { try { - if (this.driver.mode === 'ws') { - return await this.driver.client.query(sql, values) - } else { - return await this.driver.client(sql, values) - } - } catch (e) { - const error = e as Error & { code: string } - - if (isConnectionUnhealthy(error.code)) { - this._isHealthy = false - } - - throw e + await this.client.query('COMMIT'); + } finally { + this.client.release() } } + + async rollback(): Promise { + try { + await this.client.query('ROLLBACK'); + } finally { + this.client.release() + } + } + +} + +class NeonWsConnector extends NeonWsQueryable implements Connector { + private isRunning = true + constructor(config: PrismaNeonConfig) { + const { url: connectionString, httpMode, ...rest } = config + super(new Pool({ connectionString, ...rest })) + } + + async startTransaction(isolationLevel?: string | undefined): Promise { + const connection = await this.client.connect() + await connection.query('BEGIN') + if (isolationLevel) { + await connection.query(`SET TRANSACTION ISOLATION LEVEL ${isolationLevel}`) + } + + return bindTransaction(new NeonTransaction(connection)) + } + + async close() { + this.client.on('error', e => console.log(e)) + if (this.isRunning) { + await this.client.end() + this.isRunning = false + } + } +} + +class NeonHttpConnector extends NeonQueryable implements Connector { + private client: NeonQueryFunction + + constructor(config: PrismaNeonConfig) { + super() + const { url: connectionString, httpMode, ...rest } = config + this.client = neon(connectionString, { fullResults: true, ...rest}) + } + + override async performIO(query: Query): Promise { + const { sql, args: values } = query + return await this.client(sql, values) + } + + startTransaction(): Promise { + return Promise.reject(new Error('Transactions are not supported in HTTP mode')) + } + + async close() {} + } export const createNeonConnector = (config: PrismaNeonConfig): Connector => { - const db = new PrismaNeon(config) - return binder(db) + const db = config.httpMode ? new NeonHttpConnector(config) : new NeonWsConnector(config) + return bindConnector(db) } diff --git a/query-engine/js-connectors/js/planetscale-js-connector/src/deferred.ts b/query-engine/js-connectors/js/planetscale-js-connector/src/deferred.ts new file mode 100644 index 000000000000..013409c8424f --- /dev/null +++ b/query-engine/js-connectors/js/planetscale-js-connector/src/deferred.ts @@ -0,0 +1,13 @@ +export type Deferred = { + resolve(value: T | PromiseLike): void; + reject(reason: unknown): void; +} + + +export function createDeferred(): [Deferred, Promise] { + const deferred = {} as Deferred + return [deferred, new Promise((resolve, reject) => { + deferred.resolve = resolve + deferred.reject = reject + })] +} \ No newline at end of file diff --git a/query-engine/js-connectors/js/planetscale-js-connector/src/planetscale.ts b/query-engine/js-connectors/js/planetscale-js-connector/src/planetscale.ts index 105797b9805d..8bb160cc6583 100644 --- a/query-engine/js-connectors/js/planetscale-js-connector/src/planetscale.ts +++ b/query-engine/js-connectors/js/planetscale-js-connector/src/planetscale.ts @@ -1,73 +1,29 @@ import * as planetScale from '@planetscale/database' import type { Config as PlanetScaleConfig } from '@planetscale/database' -import { EventEmitter } from 'node:events' -import { setImmediate } from 'node:timers/promises' -import { binder, isConnectionUnhealthy, Debug } from '@jkomyno/prisma-js-connector-utils' -import type { Connector, ResultSet, Query, ConnectorConfig } from '@jkomyno/prisma-js-connector-utils' +import { bindConnector, bindTransaction, Debug } from '@jkomyno/prisma-js-connector-utils' +import type { Connector, ResultSet, Query, ConnectorConfig, Queryable, Transaction } from '@jkomyno/prisma-js-connector-utils' import { type PlanetScaleColumnType, fieldToColumnType } from './conversion' +import { createDeferred, Deferred } from './deferred' const debug = Debug('prisma:js-connector:planetscale') export type PrismaPlanetScaleConfig = ConnectorConfig & Partial -type TransactionCapableDriver - = { - /** - * Indicates a transaction is in progress in this connector's instance. - */ - inTransaction: true - - /** - * The PlanetScale client, scoped in transaction mode. - */ - client: planetScale.Transaction - } - | { - /** - * Indicates that no transactions are in progress in this connector's instance. - */ - inTransaction: false - - /** - * The standard PlanetScale client. - */ - client: planetScale.Connection - } - -const TRANSACTION_BEGIN = 'BEGIN' -const TRANSACTION_COMMIT = 'COMMIT' -const TRANSACTION_ROLLBACK = 'ROLLBACK' - -class PrismaPlanetScale implements Connector { - readonly flavour = 'mysql' - - private driver: TransactionCapableDriver - private isRunning: boolean = true - private _isHealthy: boolean = true - private _version: string | undefined = undefined - private txEmitter = new EventEmitter() - - constructor(config: PrismaPlanetScaleConfig) { - const client = planetScale.connect(config) +class RollbackError extends Error { + constructor() { + super('ROLLBACK') + this.name = 'RollbackError' - // initialize the driver as a non-transactional client - this.driver = { - client, - inTransaction: false, + if (Error.captureStackTrace) { + Error.captureStackTrace(this, RollbackError); } } +} - async close(): Promise { - if (this.isRunning) { - this.isRunning = false - } - } - /** - * Returns true, if connection is considered to be in a working state. - */ - isHealthy(): boolean { - return this.isRunning && this._isHealthy +class PlanetScaleQueryable implements Queryable { + readonly flavour = 'mysql' + constructor(protected client: ClientT) { } /** @@ -99,78 +55,8 @@ class PrismaPlanetScale implements Connector { const tag = '[js::execute_raw]' debug(`${tag} %O`, query) - const connection = this.driver.client - const { sql } = query - - switch (sql) { - case TRANSACTION_BEGIN: { - // check if a transaction is already in progress - if (this.driver.inTransaction) { - throw new Error('A transaction is already in progress') - } - - (this.driver.client as planetScale.Connection).transaction(async (tx) => { - // tx holds the scope for executing queries in transaction mode - this.driver.client = tx - - // signal the transaction began - this.driver.inTransaction = true - debug(`${tag} transaction began`) - - await new Promise((resolve, reject) => { - this.txEmitter.once(TRANSACTION_COMMIT, () => { - this.driver.inTransaction = false - debug(`${tag} transaction ended successfully`) - this.driver.client = connection - resolve(undefined) - }) - - this.txEmitter.once(TRANSACTION_ROLLBACK, () => { - this.driver.inTransaction = false - debug(`${tag} transaction ended with error`) - this.driver.client = connection - reject('ROLLBACK') - }) - }) - }) - - // ensure that this.driver.client is set to `planetScale.Transaction` - await setImmediate(0, { - // we do not require the event loop to remain active - ref: false, - }) - - return Promise.resolve(-1) - } - case TRANSACTION_COMMIT: { - this.txEmitter.emit(sql) - return Promise.resolve(-1) - } - case TRANSACTION_ROLLBACK: { - this.txEmitter.emit(sql) - return Promise.resolve(-2) - } - default: { - const { rowsAffected } = await this.performIO(query) - return rowsAffected - } - } - } - - /** - * Return the version of the underlying database, queried directly from the - * source. This corresponds to the `version()` function on PostgreSQL for - * example. The version string is returned directly without any form of - * parsing or normalization. - */ - async version(): Promise { - if (this._version) { - return Promise.resolve(this._version) - } - - const { rows } = await this.performIO({ sql: 'SELECT @@version', args: [] }) - this._version = rows[0]['@@version'] as string - return this._version + const { rowsAffected } = await this.performIO(query) + return rowsAffected } /** @@ -181,21 +67,67 @@ class PrismaPlanetScale implements Connector { private async performIO(query: Query) { const { sql, args: values } = query - try { - return await this.driver.client.execute(sql, values) - } catch (e) { - const error = e as Error & { code: string } - - if (isConnectionUnhealthy(error.code)) { - this._isHealthy = false - } + return await this.client.execute(sql, values) + } +} + +class PlanetScaleTransaction extends PlanetScaleQueryable implements Transaction { + constructor(tx: planetScale.Transaction, private txDeferred: Deferred, private txResultPromise: Promise) { + super(tx) + } - throw e - } + commit(): Promise { + const tag = '[js::commit]' + debug(`${tag} committing transaction`) + this.txDeferred.resolve() + return this.txResultPromise; + } + + rollback(): Promise { + const tag = '[js::rollback]' + debug(`${tag} rolling back the transaction`) + this.txDeferred.reject(new RollbackError()) + return this.txResultPromise; + } + +} + +class PrismaPlanetScale extends PlanetScaleQueryable implements Connector { + constructor(config: PrismaPlanetScaleConfig) { + const client = planetScale.connect(config) + + super(client) + } + + async startTransaction(isolationLevel?: string) { + return new Promise((resolve) => { + const txResultPromise = this.client.transaction(async tx => { + if (isolationLevel) { + await tx.execute(`SET TRANSACTION ISOLATION LEVEL ${isolationLevel}`) + } + const [txDeferred, deferredPromise] = createDeferred() + const txWrapper = new PlanetScaleTransaction(tx, txDeferred, txResultPromise) + + resolve(bindTransaction(txWrapper)); + + return deferredPromise + }).catch(error => { + // Rollback error is ignored (so that tx.rollback() won't crash) + // any other error is legit and is re-thrown + if (!(error instanceof RollbackError)) { + return Promise.reject(error) + } + + return undefined + }); + }) + } + + async close() {} } export const createPlanetScaleConnector = (config: PrismaPlanetScaleConfig): Connector => { const db = new PrismaPlanetScale(config) - return binder(db) + return bindConnector(db) } diff --git a/query-engine/js-connectors/js/pnpm-lock.yaml b/query-engine/js-connectors/js/pnpm-lock.yaml index ecc0cb2b12d9..2d9b2bd06152 100644 --- a/query-engine/js-connectors/js/pnpm-lock.yaml +++ b/query-engine/js-connectors/js/pnpm-lock.yaml @@ -34,8 +34,8 @@ importers: specifier: workspace:* version: link:../js-connector-utils '@neondatabase/serverless': - specifier: ^0.5.6 - version: 0.5.6 + specifier: ^0.6.0 + version: 0.6.0 ws: specifier: ^8.13.0 version: 8.13.0 @@ -524,8 +524,8 @@ packages: '@jridgewell/sourcemap-codec': 1.4.14 dev: true - /@neondatabase/serverless@0.5.6: - resolution: {integrity: sha512-Ru0lG6W/nQtHRkDFVQFF+1PJYx8wd3jereln0Ep0YkiHey50hjTLVUycQoE4X977605pXMuFWORweuktzph+Xg==} + /@neondatabase/serverless@0.6.0: + resolution: {integrity: sha512-qXxBRYN0m2v8kVQBfMxbzNGn2xFAhTXFibzQlE++NfJ56Shz3m7+MyBBtXDlEH+3Wfa6lToDXf1MElocY4sJ3w==} dependencies: '@types/pg': 8.6.6 dev: false @@ -590,11 +590,16 @@ packages: /@types/node@20.4.5: resolution: {integrity: sha512-rt40Nk13II9JwQBdeYqmbn2Q6IVTA5uPhvSO+JVqdXw/6/4glI6oR9ezty/A9Hg5u7JH4OmYmuQ+XvjKm0Datg==} + dev: true + + /@types/node@20.5.0: + resolution: {integrity: sha512-Mgq7eCtoTjT89FqNoTzzXg2XvCi5VMhRV6+I2aYanc6kQCBImeNaAYRs/DyoVqk1YEUJK5gN9VO7HRIdz4Wo3Q==} + dev: false /@types/pg@8.6.6: resolution: {integrity: sha512-O2xNmXebtwVekJDD+02udOncjVcMZQuTEQEMpKJ0ZRf5E7/9JJX3izhKUcUifBkyKpljyUM6BTgy2trmviKlpw==} dependencies: - '@types/node': 20.4.5 + '@types/node': 20.5.0 pg-protocol: 1.6.0 pg-types: 2.2.0 dev: false diff --git a/query-engine/js-connectors/js/smoke-test-js/src/test.ts b/query-engine/js-connectors/js/smoke-test-js/src/test.ts index 505b91fe32dc..628b1e0bdd5e 100644 --- a/query-engine/js-connectors/js/smoke-test-js/src/test.ts +++ b/query-engine/js-connectors/js/smoke-test-js/src/test.ts @@ -22,6 +22,7 @@ export async function smokeTest(db: Connector, prismaSchemaRelativePath: string) await test.testFindManyTypeTest() await test.createAutoIncrement() await test.testCreateAndDeleteChildParent() + await test.testTransaction() // Note: calling `engine.disconnect` won't actually close the database connection. console.log('[nodejs] disconnecting...') @@ -38,12 +39,10 @@ export async function smokeTest(db: Connector, prismaSchemaRelativePath: string) await engine.disconnect('trace') console.log('[nodejs] re-disconnected') - // Close the database connection. + // Close the database connection. This is required to prevent the process from hanging. console.log('[nodejs] closing database connection...') await db.close() console.log('[nodejs] closed database connection') - - // process.exit(0) } class SmokeTest { @@ -280,4 +279,24 @@ class SmokeTest { `, 'trace', undefined) console.log('[nodejs] resultDeleteMany', JSON.stringify(JSON.parse(resultDeleteMany), null, 2)) } + + async testTransaction() { + const startResponse = await this.engine.startTransaction(JSON.stringify({ isolation_level: 'Serializable', max_wait: 5000, timeout: 15000 }), 'trace') + + const tx_id = JSON.parse(startResponse).id + + console.log('[nodejs] transaction id', tx_id) + await this.engine.query(` + { + "action": "findMany", + "modelName": "Author", + "query": { + "selection": { "$scalars": true } + } + } + `, 'trace', tx_id) + + const commitResponse = await this.engine.commitTransaction(tx_id, 'trace') + console.log('[nodejs] commited', commitResponse) + } } diff --git a/query-engine/js-connectors/src/error.rs b/query-engine/js-connectors/src/error.rs index c48ef4f046c7..f2fbb7dd9caf 100644 --- a/query-engine/js-connectors/src/error.rs +++ b/query-engine/js-connectors/src/error.rs @@ -24,18 +24,6 @@ where .unwrap_or_else(panic_to_napi_err) } -/// catches a panic thrown during the execution of a closure and transforms it into a the Error -/// variant of a napi::Result. -pub(crate) fn unwinding_panic(f: F) -> napi::Result -where - F: Fn() -> napi::Result, -{ - match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) { - Ok(result) => result, - Err(panic_payload) => panic_to_napi_err(panic_payload), - } -} - fn panic_to_napi_err(panic_payload: Box) -> napi::Result { panic_payload .downcast_ref::<&str>() diff --git a/query-engine/js-connectors/src/lib.rs b/query-engine/js-connectors/src/lib.rs index d41219fd61e1..9e2664621e8d 100644 --- a/query-engine/js-connectors/src/lib.rs +++ b/query-engine/js-connectors/src/lib.rs @@ -10,4 +10,5 @@ mod error; mod proxy; mod queryable; +mod transaction; pub use queryable::{from_napi, JsQueryable}; diff --git a/query-engine/js-connectors/src/proxy.rs b/query-engine/js-connectors/src/proxy.rs index 70471c0eb7f5..4d81f0436aef 100644 --- a/query-engine/js-connectors/src/proxy.rs +++ b/query-engine/js-connectors/src/proxy.rs @@ -1,13 +1,13 @@ use core::panic; use std::str::FromStr; -use std::sync::{Arc, Condvar, Mutex}; use crate::error::*; +use crate::transaction::JsTransaction; use napi::bindgen_prelude::{FromNapiValue, Promise as JsPromise, ToNapiValue}; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction}; use napi::{Env, JsObject, JsString}; use napi_derive::napi; -use quaint::connector::ResultSet as QuaintResultSet; +use quaint::connector::{IsolationLevel, ResultSet as QuaintResultSet}; use quaint::Value as QuaintValue; // TODO(jkomyno): import these 3rd-party crates from the `quaint-core` crate. @@ -18,7 +18,7 @@ use chrono::{NaiveDate, NaiveTime}; /// Proxy is a struct wrapping a javascript object that exhibits basic primitives for /// querying and executing SQL (i.e. a client connector). The Proxy uses NAPI ThreadSafeFunction to /// invoke the code within the node runtime that implements the client connector. -pub struct Proxy { +pub struct CommonProxy { /// Execute a query given as SQL, interpolating the given parameters. query_raw: ThreadsafeFunction, @@ -26,56 +26,23 @@ pub struct Proxy { /// returning the number of affected rows. execute_raw: ThreadsafeFunction, - /// Return the version of the underlying database, queried directly from the - /// source. - version: ThreadsafeFunction<(), ErrorStrategy::Fatal>, - - /// Closes the underlying database connection. - #[allow(dead_code)] - close: ThreadsafeFunction<(), ErrorStrategy::Fatal>, - - /// Return true iff the underlying database connection is healthy. - /// Note: we already attempted turning `is_healthy` into just a `JsFunction` - /// (which would result in a simpler `call` API), but any call to it panics, - /// and `unsafe impl Send/Sync` for `Proxy` become necessary. - /// Moreover, `JsFunction` is not `Clone`. - is_healthy: ThreadsafeFunction<(), ErrorStrategy::Fatal>, - /// Return the flavour for this driver. - #[allow(dead_code)] pub(crate) flavour: String, } -/// Reify creates a Rust proxy to access the JS driver passed in as a parameter. -pub fn reify(napi_env: &Env, js_connector: JsObject) -> napi::Result { - let mut query_raw = - js_connector.get_named_property::>("queryRaw")?; - let mut execute_raw = - js_connector.get_named_property::>("executeRaw")?; - let mut version = js_connector.get_named_property::>("version")?; - let mut close = js_connector.get_named_property::>("close")?; - let mut is_healthy = - js_connector.get_named_property::>("isHealthy")?; - - // Note: calling `unref` on every ThreadsafeFunction is necessary to avoid hanging the JS event loop. - query_raw.unref(napi_env)?; - execute_raw.unref(napi_env)?; - version.unref(napi_env)?; - close.unref(napi_env)?; - is_healthy.unref(napi_env)?; - - let flavour: JsString = js_connector.get_named_property("flavour")?; - let flavour: String = flavour.into_utf8()?.as_str()?.to_owned(); - - let driver = Proxy { - query_raw, - execute_raw, - version, - close, - is_healthy, - flavour, - }; - Ok(driver) +/// This is a JS proxy for accessing the methods specific to top level +/// JS driver objects +pub struct DriverProxy { + start_transaction: ThreadsafeFunction, ErrorStrategy::Fatal>, +} +/// This a JS proxy for accessing the methods, specific +/// to JS transaction objects +pub struct TransactionProxy { + /// commit transaction + commit: ThreadsafeFunction<(), ErrorStrategy::Fatal>, + + /// rollback transcation + rollback: ThreadsafeFunction<(), ErrorStrategy::Fatal>, } /// This result set is more convenient to be manipulated from both Rust and NodeJS. @@ -328,7 +295,24 @@ impl From for QuaintResultSet { } } -impl Proxy { +impl CommonProxy { + pub fn new(object: &JsObject, env: &Env) -> napi::Result { + let query_raw = object.get_named_property("queryRaw")?; + let execute_raw = object.get_named_property("executeRaw")?; + let flavour: JsString = object.get_named_property("flavour")?; + + let mut result = Self { + query_raw, + execute_raw, + flavour: flavour.into_utf8()?.as_str()?.to_owned(), + }; + + result.query_raw.unref(env)?; + result.execute_raw.unref(env)?; + + Ok(result) + } + pub async fn query_raw(&self, params: Query) -> napi::Result { async_unwinding_panic(async { let promise = self.query_raw.call_async::>(params).await?; @@ -346,50 +330,57 @@ impl Proxy { }) .await } +} - pub async fn version(&self) -> napi::Result> { - async_unwinding_panic(async { - let version = self.version.call_async::>(()).await?; - Ok(version) +impl DriverProxy { + pub fn new(js_connector: &JsObject, env: &Env) -> napi::Result { + let start_transaction = js_connector.get_named_property("startTransaction")?; + let mut result = Self { start_transaction }; + result.start_transaction.unref(env)?; + + Ok(result) + } + + pub async fn start_transaction(&self, isolation_level: Option) -> napi::Result> { + async_unwinding_panic(async move { + let promise = self + .start_transaction + .call_async::>(isolation_level.map(|l| l.to_string())) + .await?; + + let tx = promise.await?; + Ok(Box::new(tx)) }) .await } +} - pub async fn close(&self) -> napi::Result<()> { - async_unwinding_panic(async { self.close.call_async::<()>(()).await }).await - } +impl TransactionProxy { + pub fn new(js_transaction: &JsObject, env: &Env) -> napi::Result { + let commit = js_transaction.get_named_property("commit")?; + let rollback = js_transaction.get_named_property("rollback")?; - pub fn is_healthy(&self) -> napi::Result { - unwinding_panic(|| { - let result_arc = Arc::new((Mutex::new(None), Condvar::new())); - let result_arc_clone: Arc<(Mutex>, Condvar)> = result_arc.clone(); - - let set_value_callback = move |value: bool| { - let (lock, cvar) = &*result_arc_clone; - let mut result_guard = lock.lock().unwrap(); - *result_guard = Some(value); - cvar.notify_one(); - - Ok(()) - }; - - // Should anyone find a less mind-boggling way to retrieve the result of a synchronous JS - // function, please do so. - self.is_healthy.call_with_return_value( - (), - napi::threadsafe_function::ThreadsafeFunctionCallMode::Blocking, - set_value_callback, - ); - - // wait for `set_value_callback` to be called and to set the result - let (lock, cvar) = &*result_arc; - let mut result_guard = lock.lock().unwrap(); - while result_guard.is_none() { - result_guard = cvar.wait(result_guard).unwrap(); - } + let mut result = Self { commit, rollback }; + + result.commit.unref(env)?; + result.rollback.unref(env)?; - Ok(result_guard.unwrap_or_default()) + Ok(result) + } + + pub async fn commit(&self) -> napi::Result<()> { + async_unwinding_panic(async move { + let promise = self.commit.call_async::>(()).await?; + promise.await + }) + .await + } + pub async fn rollback(&self) -> napi::Result<()> { + async_unwinding_panic(async move { + let promise = self.rollback.call_async::>(()).await?; + promise.await }) + .await } } diff --git a/query-engine/js-connectors/src/queryable.rs b/query-engine/js-connectors/src/queryable.rs index e40def4841b9..d75f3c376988 100644 --- a/query-engine/js-connectors/src/queryable.rs +++ b/query-engine/js-connectors/src/queryable.rs @@ -1,12 +1,12 @@ use crate::{ error::into_quaint_error, - proxy::{self, Proxy, Query}, + proxy::{CommonProxy, DriverProxy, Query}, }; use async_trait::async_trait; use napi::{Env, JsObject}; use psl::datamodel_connector::Flavour; use quaint::{ - connector::IsolationLevel, + connector::{IsolationLevel, Transaction}, error::{Error, ErrorKind}, prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, visitor::{self, Visitor}, @@ -27,13 +27,14 @@ use tracing::{info_span, Instrument}; /// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector /// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. /// -pub struct JsQueryable { - pub(crate) proxy: Proxy, +pub struct JsBaseQueryable { + pub(crate) proxy: CommonProxy, pub(crate) flavour: Flavour, } -impl JsQueryable { - pub fn new(proxy: Proxy, flavour: Flavour) -> Self { +impl JsBaseQueryable { + pub fn new(proxy: CommonProxy) -> Self { + let flavour: Flavour = proxy.flavour.to_owned().parse().unwrap(); Self { proxy, flavour } } @@ -47,86 +48,49 @@ impl JsQueryable { } } -impl std::fmt::Display for JsQueryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSQueryable(driver)") - } -} - -impl std::fmt::Debug for JsQueryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSQueryable(driver)") - } -} - #[async_trait] -impl QuaintQueryable for JsQueryable { - /// Execute the given query. +impl QuaintQueryable for JsBaseQueryable { async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { let (sql, params) = self.visit_query(q)?; self.query_raw(&sql, ¶ms).await } - /// Execute a query given as SQL, interpolating the given parameters. async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { let span = info_span!("js:query", user_facing = true); self.do_query_raw(sql, params).instrument(span).await } - /// Execute a query given as SQL, interpolating the given parameters. - /// - /// On Postgres, query parameters types will be inferred from the values - /// instead of letting Postgres infer them based on their usage in the SQL query. - /// - /// NOTE: This method will eventually be removed & merged into Queryable::query_raw(). async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { self.query_raw(sql, params).await } - /// Execute the given query, returning the number of affected rows. async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { let (sql, params) = self.visit_query(q)?; self.execute_raw(&sql, ¶ms).await } - /// Execute a query given as SQL, interpolating the given parameters and - /// returning the number of affected rows. async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { let span = info_span!("js:query", user_facing = true); self.do_execute_raw(sql, params).instrument(span).await } - /// Execute a query given as SQL, interpolating the given parameters and - /// returning the number of affected rows. - /// - /// On Postgres, query parameters types will be inferred from the values - /// instead of letting Postgres infer them based on their usage in the SQL query. - /// - /// NOTE: This method will eventually be removed & merged into Queryable::query_raw(). async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { self.execute_raw(sql, params).await } - /// Run a command in the database, for queries that can't be run using - /// prepared statements. async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { self.execute_raw(cmd, &[]).await?; Ok(()) } - /// Return the version of the underlying database, queried directly from the - /// source. This corresponds to the `version()` function on PostgreSQL for - /// example. The version string is returned directly without any form of - /// parsing or normalization. async fn version(&self) -> quaint::Result> { - // Todo: convert napi::Error to quaint::error::Error. - let version = self.proxy.version().await.unwrap(); - Ok(version) + // Note: JS Connectors don't use this method. + Ok(None) } - /// Returns false, if connection is considered to not be in a working state. fn is_healthy(&self) -> bool { - self.proxy.is_healthy().unwrap_or(false) + // Note: JS Connectors don't use this method. + true } /// Sets the transaction isolation level to given value. @@ -137,12 +101,9 @@ impl QuaintQueryable for JsQueryable { } self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) + .await } - /// Signals if the isolation level SET needs to happen before or after the tx BEGIN. fn requires_isolation_first(&self) -> bool { match self.flavour { Flavour::Mysql => true, @@ -152,7 +113,7 @@ impl QuaintQueryable for JsQueryable { } } -impl JsQueryable { +impl JsBaseQueryable { async fn build_query(sql: &str, values: &[quaint::Value<'_>]) -> Query { let sql: String = sql.to_string(); let args = values.iter().map(|v| v.clone().into()).collect(); @@ -190,10 +151,105 @@ impl JsQueryable { } } -impl TransactionCapable for JsQueryable {} +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. +/// +pub struct JsQueryable { + inner: JsBaseQueryable, + driver_proxy: DriverProxy, +} + +impl std::fmt::Display for JsQueryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSQueryable(driver)") + } +} + +impl std::fmt::Debug for JsQueryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSQueryable(driver)") + } +} + +#[async_trait] +impl QuaintQueryable for JsQueryable { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + +#[async_trait] +impl TransactionCapable for JsQueryable { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> quaint::Result> { + let tx = self + .driver_proxy + .start_transaction(isolation) + .await + .map_err(into_quaint_error)?; + + Ok(tx) + } +} pub fn from_napi(napi_env: &Env, driver: JsObject) -> JsQueryable { - let driver = proxy::reify(napi_env, driver).unwrap(); - let flavour = driver.flavour.parse().unwrap(); - JsQueryable::new(driver, flavour) + let common = CommonProxy::new(&driver, napi_env).unwrap(); + let driver_proxy = DriverProxy::new(&driver, napi_env).unwrap(); + + JsQueryable { + inner: JsBaseQueryable::new(common), + driver_proxy, + } } diff --git a/query-engine/js-connectors/src/transaction.rs b/query-engine/js-connectors/src/transaction.rs new file mode 100644 index 000000000000..b7c456467cc5 --- /dev/null +++ b/query-engine/js-connectors/src/transaction.rs @@ -0,0 +1,104 @@ +use async_trait::async_trait; +use napi::{bindgen_prelude::FromNapiValue, Env, JsObject}; +use quaint::{ + connector::{IsolationLevel, Transaction as QuaintTransaction}, + prelude::{Query as QuaintQuery, Queryable, ResultSet}, + Value, +}; + +use crate::{ + error::into_quaint_error, + proxy::{CommonProxy, TransactionProxy}, + queryable::JsBaseQueryable, +}; + +// Wrapper around JS transaction objects that implements Queryable +// and quaint::Transaction. Can be used in place of quaint transaction, +// but delegates most operations to JS +pub struct JsTransaction { + tx_proxy: TransactionProxy, + inner: JsBaseQueryable, +} + +impl JsTransaction { + pub fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { + Self { inner, tx_proxy } + } +} + +#[async_trait] +impl QuaintTransaction for JsTransaction { + async fn commit(&self) -> quaint::Result<()> { + self.tx_proxy.commit().await.map_err(into_quaint_error) + } + + async fn rollback(&self) -> quaint::Result<()> { + self.tx_proxy.rollback().await.map_err(into_quaint_error) + } + + fn as_queryable(&self) -> &dyn Queryable { + self + } +} + +#[async_trait] +impl Queryable for JsTransaction { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + +/// Implementing unsafe `from_napi_value` is only way I managed to get threadsafe +/// JsTransaction value in `DriverProxy`. Going through any intermediate safe napi.rs value, +/// like `JsObject` or `JsUnknown` wrapped inside `JsPromise` makes it impossible to extract the value +/// out of promise while keeping the future `Send`. +impl FromNapiValue for JsTransaction { + unsafe fn from_napi_value(env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { + let object = JsObject::from_napi_value(env, napi_val)?; + let env_safe = Env::from_raw(env); + let common_proxy = CommonProxy::new(&object, &env_safe)?; + let tx_proxy = TransactionProxy::new(&object, &env_safe)?; + + Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_proxy)) + } +}