Skip to content

Commit 2517100

Browse files
committed
Overhaul query
This is the template that we'll use for all other methods taking parameters. The `foo_raw` variant is the most flexible (but annoying to use), while `foo` covers the expected common case.
1 parent 1473c09 commit 2517100

File tree

12 files changed

+103
-112
lines changed

12 files changed

+103
-112
lines changed

postgres-native-tls/src/test.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures::{FutureExt, TryStreamExt};
1+
use futures::{FutureExt};
22
use native_tls::{self, Certificate};
33
use tokio::net::TcpStream;
44
use tokio_postgres::tls::TlsConnect;
@@ -23,7 +23,6 @@ where
2323
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
2424
let rows = client
2525
.query(&stmt, &[&1i32])
26-
.try_collect::<Vec<_>>()
2726
.await
2827
.unwrap();
2928

@@ -99,7 +98,6 @@ async fn runtime() {
9998
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
10099
let rows = client
101100
.query(&stmt, &[&1i32])
102-
.try_collect::<Vec<_>>()
103101
.await
104102
.unwrap();
105103

postgres-openssl/src/test.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures::{FutureExt, TryStreamExt};
1+
use futures::{FutureExt};
22
use openssl::ssl::{SslConnector, SslMethod};
33
use tokio::net::TcpStream;
44
use tokio_postgres::tls::TlsConnect;
@@ -21,7 +21,6 @@ where
2121
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
2222
let rows = client
2323
.query(&stmt, &[&1i32])
24-
.try_collect::<Vec<_>>()
2524
.await
2625
.unwrap();
2726

@@ -110,7 +109,6 @@ async fn runtime() {
110109
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
111110
let rows = client
112111
.query(&stmt, &[&1i32])
113-
.try_collect::<Vec<_>>()
114112
.await
115113
.unwrap();
116114

postgres/src/client.rs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,13 @@ impl Client {
122122
where
123123
T: ?Sized + ToStatement,
124124
{
125-
self.query_iter(query, params)?.collect()
125+
executor::block_on(self.0.query(query, params))
126126
}
127127

128-
/// Like `query`, except that it returns a fallible iterator over the resulting rows rather than buffering the
129-
/// response in memory.
128+
/// A maximally-flexible version of `query`.
129+
///
130+
/// It takes an iterator of parameters rather than a slice, and returns an iterator of rows rather than collecting
131+
/// them into an array.
130132
///
131133
/// # Panics
132134
///
@@ -137,12 +139,13 @@ impl Client {
137139
/// ```no_run
138140
/// use postgres::{Client, NoTls};
139141
/// use fallible_iterator::FallibleIterator;
142+
/// use std::iter;
140143
///
141144
/// # fn main() -> Result<(), postgres::Error> {
142145
/// let mut client = Client::connect("host=localhost user=postgres", NoTls)?;
143146
///
144147
/// let baz = true;
145-
/// let mut it = client.query_iter("SELECT foo FROM bar WHERE baz = $1", &[&baz])?;
148+
/// let mut it = client.query_raw("SELECT foo FROM bar WHERE baz = $1", iter::once(&baz as _))?;
146149
///
147150
/// while let Some(row) = it.next()? {
148151
/// let foo: i32 = row.get("foo");
@@ -151,15 +154,18 @@ impl Client {
151154
/// # Ok(())
152155
/// # }
153156
/// ```
154-
pub fn query_iter<'a, T>(
155-
&'a mut self,
156-
query: &'a T,
157-
params: &'a [&(dyn ToSql + Sync)],
158-
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'a, Error>
157+
pub fn query_raw<'a, T, I>(
158+
&mut self,
159+
query: &T,
160+
params: I,
161+
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error>
159162
where
160163
T: ?Sized + ToStatement,
164+
I: IntoIterator<Item = &'a dyn ToSql>,
165+
I::IntoIter: ExactSizeIterator,
161166
{
162-
Ok(Iter::new(self.0.query(query, params)))
167+
let stream = executor::block_on(self.0.query_raw(query, params))?;
168+
Ok(Iter::new(stream))
163169
}
164170

165171
/// Creates a new prepared statement.

postgres/src/transaction.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,22 @@ impl<'a> Transaction<'a> {
5555
where
5656
T: ?Sized + ToStatement,
5757
{
58-
self.query_iter(query, params)?.collect()
58+
executor::block_on(self.0.query(query, params))
5959
}
6060

61-
/// Like `Client::query_iter`.
62-
pub fn query_iter<'b, T>(
63-
&'b mut self,
64-
query: &'b T,
65-
params: &'b [&(dyn ToSql + Sync)],
66-
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error>
61+
/// Like `Client::query_raw`.
62+
pub fn query_raw<'b, T, I>(
63+
&mut self,
64+
query: &T,
65+
params: I,
66+
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error>
6767
where
6868
T: ?Sized + ToStatement,
69+
I: IntoIterator<Item = &'b dyn ToSql>,
70+
I::IntoIter: ExactSizeIterator,
6971
{
70-
Ok(Iter::new(self.0.query(query, params)))
72+
let stream = executor::block_on(self.0.query_raw(query, params))?;
73+
Ok(Iter::new(stream))
7174
}
7275

7376
/// Binds parameters to a statement, creating a "portal".

tokio-postgres/src/client.rs

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::cancel_query;
33
use crate::codec::BackendMessages;
44
use crate::config::{Host, SslMode};
55
use crate::connection::{Request, RequestMessages};
6+
use crate::query::RowStream;
67
use crate::slice_iter;
78
#[cfg(feature = "runtime")]
89
use crate::tls::MakeTlsConnect;
@@ -18,7 +19,7 @@ use crate::{Error, Statement};
1819
use bytes::{Bytes, IntoBuf};
1920
use fallible_iterator::FallibleIterator;
2021
use futures::channel::mpsc;
21-
use futures::{future, Stream, TryFutureExt, TryStream};
22+
use futures::{future, Stream, TryFutureExt, TryStream, TryStreamExt};
2223
use futures::{ready, StreamExt};
2324
use parking_lot::Mutex;
2425
use postgres_protocol::message::backend::Message;
@@ -190,40 +191,40 @@ impl Client {
190191
prepare::prepare(&self.inner, query, parameter_types).await
191192
}
192193

193-
/// Executes a statement, returning a stream of the resulting rows.
194+
/// Executes a statement, returning a vector of the resulting rows.
194195
///
195196
/// # Panics
196197
///
197198
/// Panics if the number of parameters provided does not match the number expected.
198-
pub fn query<'a, T>(
199-
&'a self,
200-
statement: &'a T,
201-
params: &'a [&(dyn ToSql + Sync)],
202-
) -> impl Stream<Item = Result<Row, Error>> + 'a
199+
pub async fn query<T>(
200+
&self,
201+
statement: &T,
202+
params: &[&(dyn ToSql + Sync)],
203+
) -> Result<Vec<Row>, Error>
203204
where
204205
T: ?Sized + ToStatement,
205206
{
206-
self.query_iter(statement, slice_iter(params))
207+
self.query_raw(statement, slice_iter(params))
208+
.await?
209+
.try_collect()
210+
.await
207211
}
208212

209-
/// Like [`query`], but takes an iterator of parameters rather than a slice.
213+
/// The maximally flexible version of [`query`].
214+
///
215+
/// # Panics
216+
///
217+
/// Panics if the number of parameters provided does not match the number expected.
210218
///
211219
/// [`query`]: #method.query
212-
pub fn query_iter<'a, T, I>(
213-
&'a self,
214-
statement: &'a T,
215-
params: I,
216-
) -> impl Stream<Item = Result<Row, Error>> + 'a
220+
pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
217221
where
218222
T: ?Sized + ToStatement,
219-
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
223+
I: IntoIterator<Item = &'a dyn ToSql>,
220224
I::IntoIter: ExactSizeIterator,
221225
{
222-
let f = async move {
223-
let statement = statement.__convert().into_statement(self).await?;
224-
Ok(query::query(&self.inner, statement, params))
225-
};
226-
f.try_flatten_stream()
226+
let statement = statement.__convert().into_statement(self).await?;
227+
query::query(&self.inner, statement, params).await
227228
}
228229

229230
/// Executes a statement, returning the number of rows modified.
@@ -241,13 +242,17 @@ impl Client {
241242
where
242243
T: ?Sized + ToStatement,
243244
{
244-
self.execute_iter(statement, slice_iter(params)).await
245+
self.execute_raw(statement, slice_iter(params)).await
245246
}
246247

247-
/// Like [`execute`], but takes an iterator of parameters rather than a slice.
248+
/// The maximally flexible version of [`execute`].
249+
///
250+
/// # Panics
251+
///
252+
/// Panics if the number of parameters provided does not match the number expected.
248253
///
249254
/// [`execute`]: #method.execute
250-
pub async fn execute_iter<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
255+
pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
251256
where
252257
T: ?Sized + ToStatement,
253258
I: IntoIterator<Item = &'a dyn ToSql>,

tokio-postgres/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//! # Example
44
//!
55
//! ```no_run
6-
//! use futures::{FutureExt, TryStreamExt};
6+
//! use futures::FutureExt;
77
//! use tokio_postgres::{NoTls, Error, Row};
88
//!
99
//! # #[cfg(not(feature = "runtime"))] fn main() {}
@@ -29,7 +29,6 @@
2929
//! // And then execute it, returning a Stream of Rows which we collect into a Vec.
3030
//! let rows: Vec<Row> = client
3131
//! .query(&stmt, &[&"hello world"])
32-
//! .try_collect()
3332
//! .await?;
3433
//!
3534
//! // Now we can check that we got back the same string we sent over.

tokio-postgres/src/prepare.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::types::{Field, Kind, Oid, Type};
66
use crate::{query, slice_iter};
77
use crate::{Column, Error, Statement};
88
use fallible_iterator::FallibleIterator;
9-
use futures::{future, TryStreamExt};
9+
use futures::TryStreamExt;
1010
use pin_utils::pin_mut;
1111
use postgres_protocol::message::backend::Message;
1212
use postgres_protocol::message::frontend;
@@ -132,8 +132,7 @@ async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
132132

133133
let stmt = typeinfo_statement(client).await?;
134134

135-
let params = &[&oid as _];
136-
let rows = query::query(client, stmt, slice_iter(params));
135+
let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
137136
pin_mut!(rows);
138137

139138
let row = match rows.try_next().await? {
@@ -204,7 +203,8 @@ async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<St
204203
let stmt = typeinfo_enum_statement(client).await?;
205204

206205
query::query(client, stmt, slice_iter(&[&oid]))
207-
.and_then(|row| future::ready(row.try_get(0)))
206+
.await?
207+
.and_then(|row| async move { row.try_get(0) })
208208
.try_collect()
209209
.await
210210
}
@@ -230,6 +230,7 @@ async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec
230230
let stmt = typeinfo_composite_statement(client).await?;
231231

232232
let rows = query::query(client, stmt, slice_iter(&[&oid]))
233+
.await?
233234
.try_collect::<Vec<_>>()
234235
.await?;
235236

tokio-postgres/src/query.rs

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,21 @@ use postgres_protocol::message::frontend;
99
use std::pin::Pin;
1010
use std::task::{Context, Poll};
1111

12-
pub fn query<'a, I>(
13-
client: &'a InnerClient,
12+
pub async fn query<'a, I>(
13+
client: &InnerClient,
1414
statement: Statement,
1515
params: I,
16-
) -> impl Stream<Item = Result<Row, Error>> + 'a
16+
) -> Result<RowStream, Error>
1717
where
18-
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
18+
I: IntoIterator<Item = &'a dyn ToSql>,
1919
I::IntoIter: ExactSizeIterator,
2020
{
21-
let f = async move {
22-
let buf = encode(&statement, params)?;
23-
let responses = start(client, buf).await?;
24-
Ok(Query {
25-
statement,
26-
responses,
27-
})
28-
};
29-
f.try_flatten_stream()
21+
let buf = encode(&statement, params)?;
22+
let responses = start(client, buf).await?;
23+
Ok(RowStream {
24+
statement,
25+
responses,
26+
})
3027
}
3128

3229
pub fn query_portal<'a>(
@@ -41,7 +38,7 @@ pub fn query_portal<'a>(
4138

4239
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
4340

44-
Ok(Query {
41+
Ok(RowStream {
4542
statement: portal.statement().clone(),
4643
responses,
4744
})
@@ -145,12 +142,12 @@ where
145142
}
146143
}
147144

148-
struct Query {
145+
pub struct RowStream {
149146
statement: Statement,
150147
responses: Responses,
151148
}
152149

153-
impl Stream for Query {
150+
impl Stream for RowStream {
154151
type Item = Result<Row, Error>;
155152

156153
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {

0 commit comments

Comments
 (0)