Skip to content
This repository was archived by the owner on Jan 14, 2025. It is now read-only.

Commit ab5a691

Browse files
tomhoulepimeys
authored andcommitted
tokio-postgres: prepare and execute unnamed statements in one roundtrip (#6)
* tokio-postgres: prepare and execute unnamed statements in one roundtrip * cleanup --------- Co-authored-by: Julius de Bruijn <julius@nauk.io>
1 parent 1bdc8ec commit ab5a691

File tree

7 files changed

+93
-94
lines changed

7 files changed

+93
-94
lines changed

postgres/src/client.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -521,15 +521,6 @@ impl Client {
521521
CancelToken::new(self.client.cancel_token())
522522
}
523523

524-
/// Clears the client's type information cache.
525-
///
526-
/// When user-defined types are used in a query, the client loads their definitions from the database and caches
527-
/// them for the lifetime of the client. If those definitions are changed in the database, this method can be used
528-
/// to flush the local cache and allow the new, updated definitions to be loaded.
529-
pub fn clear_type_cache(&self) {
530-
self.client.clear_type_cache();
531-
}
532-
533524
/// Determines if the client's connection has already closed.
534525
///
535526
/// If this returns `true`, the client is no longer usable.

tokio-postgres/src/client.rs

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::simple_query::SimpleQueryStream;
99
#[cfg(feature = "runtime")]
1010
use crate::tls::MakeTlsConnect;
1111
use crate::tls::TlsConnect;
12-
use crate::types::{Oid, ToSql, Type};
12+
use crate::types::{ToSql, Type};
1313
#[cfg(feature = "runtime")]
1414
use crate::Socket;
1515
use crate::{
@@ -23,7 +23,6 @@ use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
2323
use parking_lot::Mutex;
2424
use postgres_protocol::message::{backend::Message, frontend};
2525
use postgres_types::BorrowToSql;
26-
use std::collections::HashMap;
2726
use std::fmt;
2827
#[cfg(feature = "runtime")]
2928
use std::net::IpAddr;
@@ -63,8 +62,6 @@ impl Responses {
6362

6463
pub struct InnerClient {
6564
sender: mpsc::UnboundedSender<Request>,
66-
cached_typeinfo: Mutex<HashMap<Oid, Type>>,
67-
6865
/// A buffer to use when writing out postgres commands.
6966
buffer: Mutex<BytesMut>,
7067
}
@@ -83,14 +80,6 @@ impl InnerClient {
8380
})
8481
}
8582

86-
pub fn type_(&self, oid: Oid) -> Option<Type> {
87-
self.cached_typeinfo.lock().get(&oid).cloned()
88-
}
89-
90-
pub fn clear_type_cache(&self) {
91-
self.cached_typeinfo.lock().clear();
92-
}
93-
9483
/// Call the given function with a buffer to be used when writing out
9584
/// postgres commands.
9685
pub fn with_buf<F, R>(&self, f: F) -> R
@@ -146,7 +135,6 @@ impl Client {
146135
Client {
147136
inner: Arc::new(InnerClient {
148137
sender,
149-
cached_typeinfo: Default::default(),
150138
buffer: Default::default(),
151139
}),
152140
#[cfg(feature = "runtime")]
@@ -322,19 +310,13 @@ impl Client {
322310

323311
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
324312
/// to save a roundtrip
325-
pub async fn query_raw_txt<'a, T, S, I>(
326-
&self,
327-
statement: &T,
328-
params: I,
329-
) -> Result<RowStream, Error>
313+
pub async fn query_raw_txt<'a, S, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
330314
where
331-
T: ?Sized + ToStatement,
332315
S: AsRef<str>,
333316
I: IntoIterator<Item = Option<S>>,
334317
I::IntoIter: ExactSizeIterator,
335318
{
336-
let statement = statement.__convert().into_statement(self).await?;
337-
query::query_txt(&self.inner, statement, params).await
319+
query::query_txt(&self.inner, query, params).await
338320
}
339321

340322
/// Executes a statement, returning the number of rows modified.
@@ -527,15 +509,6 @@ impl Client {
527509
self.cancel_token().cancel_query_raw(stream, tls).await
528510
}
529511

530-
/// Clears the client's type information cache.
531-
///
532-
/// When user-defined types are used in a query, the client loads their definitions from the database and caches
533-
/// them for the lifetime of the client. If those definitions are changed in the database, this method can be used
534-
/// to flush the local cache and allow the new, updated definitions to be loaded.
535-
pub fn clear_type_cache(&self) {
536-
self.inner().clear_type_cache();
537-
}
538-
539512
/// Determines if the connection to the server has already closed.
540513
///
541514
/// In that case, all future queries will fail.

tokio-postgres/src/generic_client.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,12 @@ pub trait GenericClient: private::Sealed {
5757
I::IntoIter: ExactSizeIterator;
5858

5959
/// Like `Client::query_raw_txt`.
60-
async fn query_raw_txt<'a, T, S, I>(
60+
async fn query_raw_txt<'a, S, I>(
6161
&self,
62-
statement: &T,
62+
statement: &str,
6363
params: I,
6464
) -> Result<RowStream, Error>
6565
where
66-
T: ?Sized + ToStatement + Sync + Send,
6766
S: AsRef<str> + Sync + Send,
6867
I: IntoIterator<Item = Option<S>> + Sync + Send,
6968
I::IntoIter: ExactSizeIterator + Sync + Send;
@@ -148,9 +147,8 @@ impl GenericClient for Client {
148147
self.query_raw(statement, params).await
149148
}
150149

151-
async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
150+
async fn query_raw_txt<'a, S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
152151
where
153-
T: ?Sized + ToStatement + Sync + Send,
154152
S: AsRef<str> + Sync + Send,
155153
I: IntoIterator<Item = Option<S>> + Sync + Send,
156154
I::IntoIter: ExactSizeIterator + Sync + Send,
@@ -244,9 +242,8 @@ impl GenericClient for Transaction<'_> {
244242
self.query_raw(statement, params).await
245243
}
246244

247-
async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
245+
async fn query_raw_txt<'a, S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
248246
where
249-
T: ?Sized + ToStatement + Sync + Send,
250247
S: AsRef<str> + Sync + Send,
251248
I: IntoIterator<Item = Option<S>> + Sync + Send,
252249
I::IntoIter: ExactSizeIterator + Sync + Send,

tokio-postgres/src/prepare.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ pub async fn prepare(
4747
let mut parameters = vec![];
4848
let mut it = parameter_description.parameters();
4949
while let Some(oid) = it.next().map_err(Error::parse)? {
50-
let type_ = get_type(client, oid).await?;
50+
let type_ = get_type(oid);
5151
parameters.push(type_);
5252
}
5353

@@ -67,13 +67,18 @@ pub async fn prepare(
6767
}
6868

6969
if unnamed {
70-
Ok(Statement::unnamed(query.to_owned(), parameters, columns))
70+
Ok(Statement::unnamed(parameters, columns))
7171
} else {
7272
Ok(Statement::named(client, name, parameters, columns))
7373
}
7474
}
7575

76-
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
76+
pub(crate) fn encode(
77+
client: &InnerClient,
78+
name: &str,
79+
query: &str,
80+
types: &[Type],
81+
) -> Result<Bytes, Error> {
7782
if types.is_empty() {
7883
debug!("preparing query {}: {}", name, query);
7984
} else {
@@ -88,14 +93,10 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
8893
})
8994
}
9095

91-
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
96+
pub fn get_type(oid: Oid) -> Type {
9297
if let Some(type_) = Type::from_oid(oid) {
93-
return Ok(type_);
94-
}
95-
96-
if let Some(type_) = client.type_(oid) {
97-
return Ok(type_);
98+
return type_;
9899
}
99100

100-
Ok(Type::TEXT)
101+
Type::TEXT
101102
}

tokio-postgres/src/query.rs

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ use crate::client::{InnerClient, Responses};
22
use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
44
use crate::types::{BorrowToSql, IsNull};
5-
use crate::{Error, Portal, Row, Statement};
5+
use crate::{Column, Error, Portal, Row, Statement};
66
use bytes::{BufMut, Bytes, BytesMut};
7+
use fallible_iterator::FallibleIterator;
78
use futures_util::{ready, Stream};
89
use log::{debug, log_enabled, Level};
910
use pin_project_lite::pin_project;
10-
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
11+
use postgres_protocol::message::backend::{
12+
CommandCompleteBody, Message, ParameterDescriptionBody, RowDescriptionBody,
13+
};
1114
use postgres_protocol::message::frontend;
1215
use postgres_types::Format;
1316
use std::fmt;
@@ -50,21 +53,22 @@ where
5053
} else {
5154
encode(client, &statement, params)?
5255
};
53-
let responses = start(client, buf).await?;
56+
let (statement, responses) = start(client, buf).await?;
5457
Ok(RowStream {
5558
statement,
5659
responses,
5760
rows_affected: None,
5861
command_tag: None,
5962
status: None,
6063
output_format: Format::Binary,
64+
parameter_description: None,
6165
_p: PhantomPinned,
6266
})
6367
}
6468

6569
pub async fn query_txt<S, I>(
6670
client: &Arc<InnerClient>,
67-
statement: Statement,
71+
query: &str,
6872
params: I,
6973
) -> Result<RowStream, Error>
7074
where
@@ -75,10 +79,15 @@ where
7579
let params = params.into_iter();
7680

7781
let buf = client.with_buf(|buf| {
82+
// prepare
83+
frontend::parse("", query, std::iter::empty(), buf).map_err(Error::encode)?;
84+
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
85+
frontend::flush(buf);
86+
7887
// Bind, pass params as text, retrieve as binary
7988
match frontend::bind(
8089
"", // empty string selects the unnamed portal
81-
statement.name(), // named prepared statement
90+
"", // unnamed prepared statement
8291
std::iter::empty(), // all parameters use the default format (text)
8392
params,
8493
|param, buf| match param {
@@ -105,9 +114,10 @@ where
105114
})?;
106115

107116
// now read the responses
108-
let responses = start(client, buf).await?;
117+
let (statement, responses) = start(client, buf).await?;
109118

110119
Ok(RowStream {
120+
parameter_description: None,
111121
statement,
112122
responses,
113123
command_tag: None,
@@ -132,7 +142,8 @@ pub async fn query_portal(
132142
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
133143

134144
Ok(RowStream {
135-
statement: portal.statement().clone(),
145+
parameter_description: None,
146+
statement: Some(portal.statement().clone()),
136147
responses,
137148
rows_affected: None,
138149
command_tag: None,
@@ -176,7 +187,7 @@ where
176187
} else {
177188
encode(client, &statement, params)?
178189
};
179-
let mut responses = start(client, buf).await?;
190+
let (_statement, mut responses) = start(client, buf).await?;
180191

181192
let mut rows = 0;
182193
loop {
@@ -192,19 +203,57 @@ where
192203
}
193204
}
194205

195-
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
206+
async fn start(client: &InnerClient, buf: Bytes) -> Result<(Option<Statement>, Responses), Error> {
207+
let mut parameter_description: Option<ParameterDescriptionBody> = None;
208+
let mut statement = None;
196209
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
197210

198-
match responses.next().await? {
199-
Message::ParseComplete => match responses.next().await? {
200-
Message::BindComplete => {}
211+
loop {
212+
match responses.next().await? {
213+
Message::ParseComplete => {}
214+
Message::BindComplete => return Ok((statement, responses)),
215+
Message::ParameterDescription(body) => {
216+
parameter_description = Some(body); // tooo-o-ooo-o loooove
217+
}
218+
Message::NoData => {
219+
statement = Some(make_statement(parameter_description.take().unwrap(), None)?);
220+
}
221+
Message::RowDescription(body) => {
222+
statement = Some(make_statement(
223+
parameter_description.take().unwrap(),
224+
Some(body),
225+
)?);
226+
}
201227
m => return Err(Error::unexpected_message(m)),
202-
},
203-
Message::BindComplete => {}
204-
m => return Err(Error::unexpected_message(m)),
228+
}
229+
}
230+
}
231+
232+
fn make_statement(
233+
parameter_description: ParameterDescriptionBody,
234+
row_description: Option<RowDescriptionBody>,
235+
) -> Result<Statement, Error> {
236+
let mut parameters = vec![];
237+
let mut it = parameter_description.parameters();
238+
239+
while let Some(oid) = it.next().map_err(Error::parse).unwrap() {
240+
let type_ = crate::prepare::get_type(oid);
241+
parameters.push(type_);
205242
}
206243

207-
Ok(responses)
244+
let mut columns = Vec::new();
245+
246+
if let Some(row_description) = row_description {
247+
let mut it = row_description.fields();
248+
249+
while let Some(field) = it.next().map_err(Error::parse)? {
250+
let type_ = crate::prepare::get_type(field.type_oid());
251+
let column = Column::new(field.name().to_string(), type_, field);
252+
columns.push(column);
253+
}
254+
}
255+
256+
Ok(Statement::unnamed(parameters, columns))
208257
}
209258

210259
pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
@@ -214,12 +263,10 @@ where
214263
I::IntoIter: ExactSizeIterator,
215264
{
216265
client.with_buf(|buf| {
217-
if let Some(query) = statement.query() {
218-
frontend::parse("", query, [], buf).unwrap();
219-
}
220266
encode_bind(statement, params, "", buf)?;
221267
frontend::execute("", 0, buf).map_err(Error::encode)?;
222268
frontend::sync(buf);
269+
223270
Ok(buf.split().freeze())
224271
})
225272
}
@@ -276,12 +323,14 @@ where
276323
pin_project! {
277324
/// A stream of table rows.
278325
pub struct RowStream {
279-
statement: Statement,
326+
statement: Option<Statement>,
280327
responses: Responses,
281328
rows_affected: Option<u64>,
282329
command_tag: Option<String>,
283330
output_format: Format,
284331
status: Option<u8>,
332+
parameter_description: Option<ParameterDescriptionBody>,
333+
285334
#[pin]
286335
_p: PhantomPinned,
287336
}
@@ -292,11 +341,12 @@ impl Stream for RowStream {
292341

293342
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
294343
let this = self.project();
344+
295345
loop {
296346
match ready!(this.responses.poll_next(cx)?) {
297347
Message::DataRow(body) => {
298348
return Poll::Ready(Some(Ok(Row::new(
299-
this.statement.clone(),
349+
this.statement.as_ref().unwrap().clone(),
300350
body,
301351
*this.output_format,
302352
)?)))

0 commit comments

Comments
 (0)