diff --git a/Cargo.lock b/Cargo.lock index 98c10a7a..f94586d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -181,6 +181,7 @@ dependencies = [ name = "cornucopia-example-basic" version = "0.1.0" dependencies = [ + "cornucopia", "deadpool-postgres", "postgres-types", "serde", diff --git a/src/codegen.rs b/src/codegen.rs index 5b63f600..7538ee96 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -11,10 +11,7 @@ use crate::{ use super::prepare_queries::PreparedModule; -pub(crate) fn generate_query( - module_name: &str, - query: &PreparedQuery, -) -> Result<(String, String), Error> { +pub(crate) fn generate_query(module_name: &str, query: &PreparedQuery) -> Result { let name = &query.name; let query_struct = generate_query_struct(query)?.unwrap_or_default(); let params = generate_query_params(query)?; @@ -29,12 +26,10 @@ pub(crate) fn generate_query( let body = generate_query_body(query, ret_ty)?; let query_string = format!( "{query_struct} - pub async fn {name}(client:&Client, {params}) -> Result<{ret},Error> {{{body}}}" + pub async fn {name}(client:&T, {params}) -> Result<{ret},Error> {{{body}}}" ); - let transaction_string = format!("pub async fn {name}<'a>(client:&Transaction<'a>, {params}) -> Result<{ret}, Error> {{{body}}}"); - - Ok((query_string, transaction_string)) + Ok(query_string) } pub(crate) fn generate_custom_type(ty: &CornucopiaType) -> String { @@ -281,7 +276,7 @@ pub(crate) fn generate_query_body(query: &PreparedQuery, ret_ty: String) -> Resu }; Ok(format!( - "let stmt = client.prepare_cached({query_string}).await?; + "let stmt = client.prepare({query_string}).await?; let {res_var_name} = client.{query_method}(&stmt, &[{query_param_values}]).await?; {ret_value}" @@ -293,25 +288,20 @@ pub(crate) fn generate( modules: Vec, destination: &str, ) -> Result<(), Error> { - let query_imports = r#"use deadpool_postgres::Client; -use tokio_postgres::{error::Error};"#; - let transaction_imports = r#"use deadpool_postgres::Transaction; - use tokio_postgres::{error::Error};"#; + let query_imports = r#" +use cornucopia::GenericClient; +use tokio_postgres::Error;"#; let type_modules = generate_type_modules(type_registrar); let mut query_modules = Vec::new(); - let mut transaction_modules = Vec::new(); for module in modules { let mut query_strings = Vec::new(); - let mut transaction_strings = Vec::new(); for query in module.queries { - let (query_string, transaction_string) = generate_query(&module.name, &query)?; + let query_string = generate_query(&module.name, &query)?; query_strings.push(query_string); - transaction_strings.push(transaction_string) } let queries_string = query_strings.join("\n\n"); - let transactions_string = transaction_strings.join("\n\n"); let module_name = module.name; query_modules.push(format!( @@ -319,21 +309,13 @@ use tokio_postgres::{error::Error};"#; {query_imports} {queries_string} -}}"# - )); - transaction_modules.push(format!( - r#"pub mod {module_name} {{ -{transaction_imports} - -{transactions_string} }}"# )); } let generated_modules = format!( - "{type_modules} \n pub mod queries {{ {} }} \n pub mod transactions {{ {} }}", + "{type_modules} \n pub mod queries {{ {} }}", query_modules.join("\n\n"), - transaction_modules.join("\n\n") ); std::fs::write(destination, generated_modules)?; diff --git a/src/generic_client.rs b/src/generic_client.rs index 4e60816b..4002ce48 100644 --- a/src/generic_client.rs +++ b/src/generic_client.rs @@ -1,13 +1,12 @@ +use async_trait::async_trait; +use deadpool_postgres::{Client, Transaction, ClientWrapper}; +use tokio_postgres::{Client as PgClient, Error, Statement, Transaction as PgTransaction}; + // This trait acts as an umbrella for all four of // - `tokio_postgres::Client` // - `deadpool_postgres::Client` // - `tokio_postgres::Transaction` // - `deadpool_postgres::Transaction` - -use async_trait::async_trait; -use deadpool_postgres::{Client, Transaction}; -use tokio_postgres::{Client as PgClient, Error, Statement, Transaction as PgTransaction}; - #[async_trait] pub trait GenericClient { async fn prepare(&self, query: &str) -> Result; @@ -44,7 +43,7 @@ pub trait GenericClient { #[async_trait] impl GenericClient for Transaction<'_> { async fn prepare(&self, query: &str) -> Result { - PgTransaction::prepare(self, query).await + Transaction::prepare_cached(&self, query).await } async fn execute( @@ -146,7 +145,7 @@ impl GenericClient for PgTransaction<'_> { #[async_trait] impl GenericClient for Client { async fn prepare(&self, query: &str) -> Result { - PgClient::prepare(self, query).await + ClientWrapper::prepare_cached(self, query).await } async fn execute( diff --git a/src/integration/cornucopia_gen.rs b/src/integration/cornucopia_gen.rs index 316cc9ca..b9777a2d 100644 --- a/src/integration/cornucopia_gen.rs +++ b/src/integration/cornucopia_gen.rs @@ -24,12 +24,13 @@ pub mod types { } pub mod queries { pub mod module_1 { - use deadpool_postgres::Client; - use tokio_postgres::error::Error; - pub async fn insert_book_one(client: &Client) -> Result<(), Error> { + use crate::GenericClient; + use tokio_postgres::Error; + + pub async fn insert_book_one(client: &T) -> Result<(), Error> { let stmt = client - .prepare_cached( + .prepare( "INSERT INTO Book (title) VALUES ('bob'); ", @@ -40,9 +41,9 @@ VALUES ('bob'); Ok(()) } - pub async fn insert_book_zero_or_one(client: &Client) -> Result<(), Error> { + pub async fn insert_book_zero_or_one(client: &T) -> Result<(), Error> { let stmt = client - .prepare_cached( + .prepare( "INSERT INTO Book (title) VALUES ('alice'); ", @@ -53,9 +54,9 @@ VALUES ('alice'); Ok(()) } - pub async fn insert_book_zero_or_more(client: &Client) -> Result<(), Error> { + pub async fn insert_book_zero_or_more(client: &T) -> Result<(), Error> { let stmt = client - .prepare_cached( + .prepare( "INSERT INTO Book (title) VALUES ('carl'); ", @@ -68,12 +69,15 @@ VALUES ('carl'); } pub mod module_2 { - use deadpool_postgres::Client; - use tokio_postgres::error::Error; - pub async fn authors(client: &Client) -> Result, Error> { + use crate::GenericClient; + use tokio_postgres::Error; + + pub async fn authors( + client: &T, + ) -> Result, Error> { let stmt = client - .prepare_cached( + .prepare( "SELECT * FROM @@ -99,11 +103,11 @@ Author; pub struct Books { pub title: String, } - pub async fn books( - client: &Client, + pub async fn books( + client: &T, ) -> Result, Error> { let stmt = client - .prepare_cached( + .prepare( "SELECT Title FROM @@ -129,409 +133,11 @@ Book; pub struct BooksOptRetParam { pub title: Option, } - pub async fn books_opt_ret_param( - client: &Client, - ) -> Result, Error> { - let stmt = client - .prepare_cached( - "SELECT -Title -FROM -Book; -", - ) - .await?; - let res = client.query(&stmt, &[]).await?; - - let return_value = res - .iter() - .map(|res| { - let return_value_0: Option = res.get(0); - super::super::queries::module_2::BooksOptRetParam { - title: return_value_0, - } - }) - .collect::>(); - Ok(return_value) - } - - pub async fn books_from_author_id(client: &Client, id: &i32) -> Result, Error> { - let stmt = client - .prepare_cached( - "SELECT -Book.Title -FROM -BookAuthor -INNER JOIN Author ON Author.Id = BookAuthor.AuthorId -INNER JOIN Book ON Book.Id = BookAuthor.BookId -WHERE -Author.Id = $1; -", - ) - .await?; - let res = client.query(&stmt, &[&id]).await?; - - let return_value = res - .iter() - .map(|row| { - let value: String = row.get(0); - value - }) - .collect::>(); - Ok(return_value) - } - - pub async fn author_name_by_id_opt( - client: &Client, - id: &i32, - ) -> Result, Error> { - let stmt = client - .prepare_cached( - "SELECT -Author.Name -FROM -Author -WHERE -Author.Id = $1; -", - ) - .await?; - let res = client.query_opt(&stmt, &[&id]).await?; - - let return_value = res.map(|row| { - let value: String = row.get(0); - value - }); - Ok(return_value) - } - - pub async fn author_name_by_id(client: &Client, id: &i32) -> Result { - let stmt = client - .prepare_cached( - "SELECT -Author.Name -FROM -Author -WHERE -Author.Id = $1; -", - ) - .await?; - let res = client.query_one(&stmt, &[&id]).await?; - - let return_value: String = res.get(0); - Ok(return_value) - } - - pub async fn author_name_starting_with( - client: &Client, - s: &str, - ) -> Result, Error> { - let stmt = client - .prepare_cached( - "SELECT -BookAuthor.AuthorId, -Author.Name, -BookAuthor.BookId, -Book.Title -FROM -BookAuthor -INNER JOIN Author ON Author.id = BookAuthor.AuthorId -INNER JOIN Book ON Book.Id = BookAuthor.BookId -WHERE -Author.Name LIKE CONCAT($1::text, '%'); -", - ) - .await?; - let res = client.query(&stmt, &[&s]).await?; - - let return_value = res - .iter() - .map(|res| { - let return_value_0: i32 = res.get(0); - let return_value_1: String = res.get(1); - let return_value_2: i32 = res.get(2); - let return_value_3: String = res.get(3); - ( - return_value_0, - return_value_1, - return_value_2, - return_value_3, - ) - }) - .collect::>(); - Ok(return_value) - } - - pub async fn return_custom_type( - client: &Client, - ) -> Result { - let stmt = client - .prepare_cached( - "SELECT -col1 -FROM -CustomTable; -", - ) - .await?; - let res = client.query_one(&stmt, &[]).await?; - - let return_value: super::super::types::public::CustomComposite = res.get(0); - Ok(return_value) - } - - pub async fn select_where_custom_type( - client: &Client, - spongebob_character: &super::super::types::public::SpongebobCharacter, - ) -> Result { - let stmt = client - .prepare_cached( - "SELECT -col2 -FROM -CustomTable -WHERE (col1).nice = $1; -", - ) - .await?; - let res = client.query_one(&stmt, &[&spongebob_character]).await?; - - let return_value: super::super::types::public::SpongebobCharacter = res.get(0); - Ok(return_value) - } - - pub async fn select_everything( - client: &Client, - ) -> Result< - ( - bool, - bool, - i8, - i16, - i16, - i16, - i16, - i32, - i32, - i32, - i32, - i64, - i64, - i64, - i64, - f32, - f32, - f64, - f64, - String, - String, - Vec, - time::PrimitiveDateTime, - time::PrimitiveDateTime, - time::OffsetDateTime, - time::OffsetDateTime, - time::Date, - time::Time, - serde_json::Value, - serde_json::Value, - uuid::Uuid, - std::net::IpAddr, - eui48::MacAddress, - ), - Error, - > { - let stmt = client - .prepare_cached( - "SELECT -* -FROM -Everything; -", - ) - .await?; - let res = client.query_one(&stmt, &[]).await?; - - let return_value = { - let return_value_0: bool = res.get(0); - let return_value_1: bool = res.get(1); - let return_value_2: i8 = res.get(2); - let return_value_3: i16 = res.get(3); - let return_value_4: i16 = res.get(4); - let return_value_5: i16 = res.get(5); - let return_value_6: i16 = res.get(6); - let return_value_7: i32 = res.get(7); - let return_value_8: i32 = res.get(8); - let return_value_9: i32 = res.get(9); - let return_value_10: i32 = res.get(10); - let return_value_11: i64 = res.get(11); - let return_value_12: i64 = res.get(12); - let return_value_13: i64 = res.get(13); - let return_value_14: i64 = res.get(14); - let return_value_15: f32 = res.get(15); - let return_value_16: f32 = res.get(16); - let return_value_17: f64 = res.get(17); - let return_value_18: f64 = res.get(18); - let return_value_19: String = res.get(19); - let return_value_20: String = res.get(20); - let return_value_21: Vec = res.get(21); - let return_value_22: time::PrimitiveDateTime = res.get(22); - let return_value_23: time::PrimitiveDateTime = res.get(23); - let return_value_24: time::OffsetDateTime = res.get(24); - let return_value_25: time::OffsetDateTime = res.get(25); - let return_value_26: time::Date = res.get(26); - let return_value_27: time::Time = res.get(27); - let return_value_28: serde_json::Value = res.get(28); - let return_value_29: serde_json::Value = res.get(29); - let return_value_30: uuid::Uuid = res.get(30); - let return_value_31: std::net::IpAddr = res.get(31); - let return_value_32: eui48::MacAddress = res.get(32); - ( - return_value_0, - return_value_1, - return_value_2, - return_value_3, - return_value_4, - return_value_5, - return_value_6, - return_value_7, - return_value_8, - return_value_9, - return_value_10, - return_value_11, - return_value_12, - return_value_13, - return_value_14, - return_value_15, - return_value_16, - return_value_17, - return_value_18, - return_value_19, - return_value_20, - return_value_21, - return_value_22, - return_value_23, - return_value_24, - return_value_25, - return_value_26, - return_value_27, - return_value_28, - return_value_29, - return_value_30, - return_value_31, - return_value_32, - ) - }; - Ok(return_value) - } - } -} -pub mod transactions { - pub mod module_1 { - use deadpool_postgres::Transaction; - use tokio_postgres::error::Error; - - pub async fn insert_book_one<'a>(client: &Transaction<'a>) -> Result<(), Error> { - let stmt = client - .prepare_cached( - "INSERT INTO Book (title) -VALUES ('bob'); -", - ) - .await?; - let _ = client.execute(&stmt, &[]).await?; - - Ok(()) - } - - pub async fn insert_book_zero_or_one<'a>(client: &Transaction<'a>) -> Result<(), Error> { - let stmt = client - .prepare_cached( - "INSERT INTO Book (title) -VALUES ('alice'); -", - ) - .await?; - let _ = client.execute(&stmt, &[]).await?; - - Ok(()) - } - - pub async fn insert_book_zero_or_more<'a>(client: &Transaction<'a>) -> Result<(), Error> { - let stmt = client - .prepare_cached( - "INSERT INTO Book (title) -VALUES ('carl'); -", - ) - .await?; - let _ = client.execute(&stmt, &[]).await?; - - Ok(()) - } - } - - pub mod module_2 { - use deadpool_postgres::Transaction; - use tokio_postgres::error::Error; - - pub async fn authors<'a>( - client: &Transaction<'a>, - ) -> Result, Error> { - let stmt = client - .prepare_cached( - "SELECT -* -FROM -Author; -", - ) - .await?; - let res = client.query(&stmt, &[]).await?; - - let return_value = res - .iter() - .map(|res| { - let return_value_0: i32 = res.get(0); - let return_value_1: String = res.get(1); - let return_value_2: String = res.get(2); - (return_value_0, return_value_1, return_value_2) - }) - .collect::>(); - Ok(return_value) - } - - pub async fn books<'a>( - client: &Transaction<'a>, - ) -> Result, Error> { - let stmt = client - .prepare_cached( - "SELECT -Title -FROM -Book; -", - ) - .await?; - let res = client.query(&stmt, &[]).await?; - - let return_value = res - .iter() - .map(|res| { - let return_value_0: String = res.get(0); - super::super::queries::module_2::Books { - title: return_value_0, - } - }) - .collect::>(); - Ok(return_value) - } - - pub async fn books_opt_ret_param<'a>( - client: &Transaction<'a>, + pub async fn books_opt_ret_param( + client: &T, ) -> Result, Error> { let stmt = client - .prepare_cached( + .prepare( "SELECT Title FROM @@ -553,12 +159,12 @@ Book; Ok(return_value) } - pub async fn books_from_author_id<'a>( - client: &Transaction<'a>, + pub async fn books_from_author_id( + client: &T, id: &i32, ) -> Result, Error> { let stmt = client - .prepare_cached( + .prepare( "SELECT Book.Title FROM @@ -582,12 +188,12 @@ Author.Id = $1; Ok(return_value) } - pub async fn author_name_by_id_opt<'a>( - client: &Transaction<'a>, + pub async fn author_name_by_id_opt( + client: &T, id: &i32, ) -> Result, Error> { let stmt = client - .prepare_cached( + .prepare( "SELECT Author.Name FROM @@ -606,12 +212,12 @@ Author.Id = $1; Ok(return_value) } - pub async fn author_name_by_id<'a>( - client: &Transaction<'a>, + pub async fn author_name_by_id( + client: &T, id: &i32, ) -> Result { let stmt = client - .prepare_cached( + .prepare( "SELECT Author.Name FROM @@ -627,12 +233,12 @@ Author.Id = $1; Ok(return_value) } - pub async fn author_name_starting_with<'a>( - client: &Transaction<'a>, + pub async fn author_name_starting_with( + client: &T, s: &str, ) -> Result, Error> { let stmt = client - .prepare_cached( + .prepare( "SELECT BookAuthor.AuthorId, Author.Name, @@ -667,11 +273,11 @@ Author.Name LIKE CONCAT($1::text, '%'); Ok(return_value) } - pub async fn return_custom_type<'a>( - client: &Transaction<'a>, + pub async fn return_custom_type( + client: &T, ) -> Result { let stmt = client - .prepare_cached( + .prepare( "SELECT col1 FROM @@ -685,12 +291,12 @@ CustomTable; Ok(return_value) } - pub async fn select_where_custom_type<'a>( - client: &Transaction<'a>, + pub async fn select_where_custom_type( + client: &T, spongebob_character: &super::super::types::public::SpongebobCharacter, ) -> Result { let stmt = client - .prepare_cached( + .prepare( "SELECT col2 FROM @@ -705,8 +311,8 @@ WHERE (col1).nice = $1; Ok(return_value) } - pub async fn select_everything<'a>( - client: &Transaction<'a>, + pub async fn select_everything( + client: &T, ) -> Result< ( bool, @@ -746,7 +352,7 @@ WHERE (col1).nice = $1; Error, > { let stmt = client - .prepare_cached( + .prepare( "SELECT * FROM diff --git a/src/lib.rs b/src/lib.rs index 88d9fac0..3a73e32d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,6 @@ pub(crate) mod cli; pub(crate) mod codegen; pub(crate) mod container; - pub(crate) mod error; pub(crate) mod pg_type; pub(crate) mod pool; diff --git a/src/run_migrations.rs b/src/run_migrations.rs index 3c9e1f09..f358249a 100644 --- a/src/run_migrations.rs +++ b/src/run_migrations.rs @@ -5,7 +5,6 @@ use error::Error; pub async fn run_migrations(client: &Object, path: &str) -> Result<(), Error> { create_migration_table(client).await?; for migration in read_migrations(path)? { - println!("{}", migration.name); let migration_not_installed = !migration_is_installed(client, &migration.timestamp, &migration.name).await?; if migration_not_installed {