Skip to content

Commit a6535b4

Browse files
committed
Internal connect refactoring
1 parent e0d1137 commit a6535b4

File tree

3 files changed

+197
-167
lines changed

3 files changed

+197
-167
lines changed

tokio-postgres/src/proto/connect_once.rs

Lines changed: 14 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,27 @@
11
#![allow(clippy::large_enum_variant)]
22

33
use futures::{try_ready, Async, Future, Poll, Stream};
4-
use futures_cpupool::{CpuFuture, CpuPool};
5-
use lazy_static::lazy_static;
64
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
75
use std::io;
8-
use std::net::{SocketAddr, ToSocketAddrs};
9-
use std::time::Instant;
10-
use std::vec;
11-
use tokio_tcp::TcpStream;
12-
use tokio_timer::Delay;
13-
#[cfg(unix)]
14-
use tokio_uds::UnixStream;
156

16-
use crate::proto::{Client, Connection, HandshakeFuture, SimpleQueryStream};
17-
use crate::{Config, Error, Host, Socket, TargetSessionAttrs, TlsMode};
18-
19-
lazy_static! {
20-
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
21-
.name_prefix("postgres-dns-")
22-
.pool_size(2)
23-
.create();
24-
}
7+
use crate::proto::{Client, ConnectSocketFuture, Connection, HandshakeFuture, SimpleQueryStream};
8+
use crate::{Config, Error, Socket, TargetSessionAttrs, TlsMode};
259

2610
#[derive(StateMachineFuture)]
2711
pub enum ConnectOnce<T>
2812
where
2913
T: TlsMode<Socket>,
3014
{
31-
#[state_machine_future(start)]
32-
#[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))]
33-
#[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))]
15+
#[state_machine_future(start, transitions(ConnectingSocket))]
3416
Start {
3517
idx: usize,
3618
tls_mode: T,
3719
config: Config,
3820
},
39-
#[cfg(unix)]
40-
#[state_machine_future(transitions(Handshaking))]
41-
ConnectingUnix {
42-
future: tokio_uds::ConnectFuture,
43-
timeout: Option<Delay>,
44-
tls_mode: T,
45-
config: Config,
46-
},
47-
#[state_machine_future(transitions(ConnectingTcp))]
48-
ResolvingDns {
49-
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
50-
timeout: Option<Delay>,
51-
tls_mode: T,
52-
config: Config,
53-
},
5421
#[state_machine_future(transitions(Handshaking))]
55-
ConnectingTcp {
56-
future: tokio_tcp::ConnectFuture,
57-
addrs: vec::IntoIter<SocketAddr>,
58-
timeout: Option<Delay>,
22+
ConnectingSocket {
23+
future: ConnectSocketFuture,
24+
idx: usize,
5925
tls_mode: T,
6026
config: Config,
6127
},
@@ -83,142 +49,23 @@ where
8349
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
8450
let state = state.take();
8551

86-
let port = *state
87-
.config
88-
.0
89-
.port
90-
.get(state.idx)
91-
.or_else(|| state.config.0.port.get(0))
92-
.unwrap_or(&5432);
93-
94-
let timeout = state
95-
.config
96-
.0
97-
.connect_timeout
98-
.map(|d| Delay::new(Instant::now() + d));
99-
100-
match &state.config.0.host[state.idx] {
101-
Host::Tcp(host) => {
102-
let host = host.clone();
103-
transition!(ResolvingDns {
104-
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
105-
timeout,
106-
tls_mode: state.tls_mode,
107-
config: state.config,
108-
})
109-
}
110-
#[cfg(unix)]
111-
Host::Unix(host) => {
112-
let path = host.join(format!(".s.PGSQL.{}", port));
113-
transition!(ConnectingUnix {
114-
future: UnixStream::connect(path),
115-
timeout,
116-
tls_mode: state.tls_mode,
117-
config: state.config,
118-
})
119-
}
120-
}
121-
}
122-
123-
#[cfg(unix)]
124-
fn poll_connecting_unix<'a>(
125-
state: &'a mut RentToOwn<'a, ConnectingUnix<T>>,
126-
) -> Poll<AfterConnectingUnix<T>, Error> {
127-
if let Some(timeout) = &mut state.timeout {
128-
match timeout.poll() {
129-
Ok(Async::Ready(())) => {
130-
return Err(Error::connect(io::Error::from(io::ErrorKind::TimedOut)))
131-
}
132-
Ok(Async::NotReady) => {}
133-
Err(e) => return Err(Error::connect(io::Error::new(io::ErrorKind::Other, e))),
134-
}
135-
}
136-
137-
let stream = try_ready!(state.future.poll().map_err(Error::connect));
138-
let stream = Socket::new_unix(stream);
139-
let state = state.take();
140-
141-
transition!(Handshaking {
142-
target_session_attrs: state.config.0.target_session_attrs,
143-
future: HandshakeFuture::new(stream, state.tls_mode, state.config),
144-
})
145-
}
146-
147-
fn poll_resolving_dns<'a>(
148-
state: &'a mut RentToOwn<'a, ResolvingDns<T>>,
149-
) -> Poll<AfterResolvingDns<T>, Error> {
150-
if let Some(timeout) = &mut state.timeout {
151-
match timeout.poll() {
152-
Ok(Async::Ready(())) => {
153-
return Err(Error::connect(io::Error::from(io::ErrorKind::TimedOut)))
154-
}
155-
Ok(Async::NotReady) => {}
156-
Err(e) => return Err(Error::connect(io::Error::new(io::ErrorKind::Other, e))),
157-
}
158-
}
159-
160-
let mut addrs = try_ready!(state.future.poll().map_err(Error::connect));
161-
let state = state.take();
162-
163-
let addr = match addrs.next() {
164-
Some(addr) => addr,
165-
None => {
166-
return Err(Error::connect(io::Error::new(
167-
io::ErrorKind::InvalidData,
168-
"resolved 0 addresses",
169-
)));
170-
}
171-
};
172-
173-
transition!(ConnectingTcp {
174-
future: TcpStream::connect(&addr),
175-
addrs,
176-
timeout: state.timeout,
52+
transition!(ConnectingSocket {
53+
future: ConnectSocketFuture::new(state.config.clone(), state.idx),
54+
idx: state.idx,
17755
tls_mode: state.tls_mode,
17856
config: state.config,
17957
})
18058
}
18159

182-
fn poll_connecting_tcp<'a>(
183-
state: &'a mut RentToOwn<'a, ConnectingTcp<T>>,
184-
) -> Poll<AfterConnectingTcp<T>, Error> {
185-
if let Some(timeout) = &mut state.timeout {
186-
match timeout.poll() {
187-
Ok(Async::Ready(())) => {
188-
return Err(Error::connect(io::Error::from(io::ErrorKind::TimedOut)))
189-
}
190-
Ok(Async::NotReady) => {}
191-
Err(e) => return Err(Error::connect(io::Error::new(io::ErrorKind::Other, e))),
192-
}
193-
}
194-
195-
let stream = loop {
196-
match state.future.poll() {
197-
Ok(Async::Ready(stream)) => break stream,
198-
Ok(Async::NotReady) => return Ok(Async::NotReady),
199-
Err(e) => {
200-
let addr = match state.addrs.next() {
201-
Some(addr) => addr,
202-
None => return Err(Error::connect(e)),
203-
};
204-
state.future = TcpStream::connect(&addr);
205-
}
206-
}
207-
};
60+
fn poll_connecting_socket<'a>(
61+
state: &'a mut RentToOwn<'a, ConnectingSocket<T>>,
62+
) -> Poll<AfterConnectingSocket<T>, Error> {
63+
let socket = try_ready!(state.future.poll());
20864
let state = state.take();
20965

210-
stream.set_nodelay(true).map_err(Error::connect)?;
211-
if state.config.0.keepalives {
212-
stream
213-
.set_keepalive(Some(state.config.0.keepalives_idle))
214-
.map_err(Error::connect)?;
215-
}
216-
217-
let stream = Socket::new_tcp(stream);
218-
21966
transition!(Handshaking {
22067
target_session_attrs: state.config.0.target_session_attrs,
221-
future: HandshakeFuture::new(stream, state.tls_mode, state.config),
68+
future: HandshakeFuture::new(socket, state.tls_mode, state.config),
22269
})
22370
}
22471

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
use futures::{try_ready, Async, Future, Poll};
2+
use futures_cpupool::{CpuFuture, CpuPool};
3+
use lazy_static::lazy_static;
4+
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
5+
use std::io;
6+
use std::net::{SocketAddr, ToSocketAddrs};
7+
use std::time::Instant;
8+
use std::vec;
9+
use tokio_tcp::TcpStream;
10+
use tokio_timer::Delay;
11+
#[cfg(unix)]
12+
use tokio_uds::UnixStream;
13+
14+
use crate::{Config, Error, Host, Socket};
15+
16+
lazy_static! {
17+
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
18+
.name_prefix("postgres-dns-")
19+
.pool_size(2)
20+
.create();
21+
}
22+
23+
#[derive(StateMachineFuture)]
24+
pub enum ConnectSocket {
25+
#[state_machine_future(start)]
26+
#[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))]
27+
#[cfg_attr(not(unix), state_machine_future(transitions(ResolvingDns)))]
28+
Start { config: Config, idx: usize },
29+
#[cfg(unix)]
30+
#[state_machine_future(transitions(Finished))]
31+
ConnectingUnix {
32+
future: tokio_uds::ConnectFuture,
33+
timeout: Option<Delay>,
34+
},
35+
#[state_machine_future(transitions(ConnectingTcp))]
36+
ResolvingDns {
37+
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
38+
config: Config,
39+
},
40+
#[state_machine_future(transitions(Finished))]
41+
ConnectingTcp {
42+
future: tokio_tcp::ConnectFuture,
43+
timeout: Option<Delay>,
44+
addrs: vec::IntoIter<SocketAddr>,
45+
config: Config,
46+
},
47+
#[state_machine_future(ready)]
48+
Finished(Socket),
49+
#[state_machine_future(error)]
50+
Failed(Error),
51+
}
52+
53+
impl PollConnectSocket for ConnectSocket {
54+
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
55+
let state = state.take();
56+
57+
let port = *state
58+
.config
59+
.0
60+
.port
61+
.get(state.idx)
62+
.or_else(|| state.config.0.port.get(0))
63+
.unwrap_or(&5432);
64+
65+
match &state.config.0.host[state.idx] {
66+
Host::Tcp(host) => transition!(ResolvingDns {
67+
future: DNS_POOL.spawn_fn({
68+
let host = host.clone();
69+
move || (&*host, port).to_socket_addrs()
70+
}),
71+
config: state.config,
72+
}),
73+
#[cfg(unix)]
74+
Host::Unix(host) => {
75+
let path = host.join(format!(".s.PGSQL.{}", port));
76+
let timeout = state
77+
.config
78+
.0
79+
.connect_timeout
80+
.map(|d| Delay::new(Instant::now() + d));
81+
transition!(ConnectingUnix {
82+
future: UnixStream::connect(path),
83+
timeout,
84+
})
85+
}
86+
}
87+
}
88+
89+
#[cfg(unix)]
90+
fn poll_connecting_unix<'a>(
91+
state: &'a mut RentToOwn<'a, ConnectingUnix>,
92+
) -> Poll<AfterConnectingUnix, Error> {
93+
if let Some(timeout) = &mut state.timeout {
94+
match timeout.poll() {
95+
Ok(Async::Ready(())) => {
96+
return Err(Error::connect(io::Error::from(io::ErrorKind::TimedOut)));
97+
}
98+
Ok(Async::NotReady) => {}
99+
Err(e) => return Err(Error::connect(io::Error::new(io::ErrorKind::Other, e))),
100+
}
101+
}
102+
let socket = try_ready!(state.future.poll().map_err(Error::connect));
103+
104+
transition!(Finished(Socket::new_unix(socket)))
105+
}
106+
107+
fn poll_resolving_dns<'a>(
108+
state: &'a mut RentToOwn<'a, ResolvingDns>,
109+
) -> Poll<AfterResolvingDns, Error> {
110+
let mut addrs = try_ready!(state.future.poll().map_err(Error::connect));
111+
let state = state.take();
112+
113+
let addr = match addrs.next() {
114+
Some(addr) => addr,
115+
None => {
116+
return Err(Error::connect(io::Error::new(
117+
io::ErrorKind::InvalidData,
118+
"resolved 0 addresses",
119+
)));
120+
}
121+
};
122+
123+
let timeout = state
124+
.config
125+
.0
126+
.connect_timeout
127+
.map(|d| Delay::new(Instant::now() + d));
128+
129+
transition!(ConnectingTcp {
130+
future: TcpStream::connect(&addr),
131+
addrs,
132+
timeout: timeout,
133+
config: state.config,
134+
})
135+
}
136+
137+
fn poll_connecting_tcp<'a>(
138+
state: &'a mut RentToOwn<'a, ConnectingTcp>,
139+
) -> Poll<AfterConnectingTcp, Error> {
140+
let stream = loop {
141+
let error = match state.future.poll() {
142+
Ok(Async::Ready(stream)) => break stream,
143+
Ok(Async::NotReady) => match &mut state.timeout {
144+
Some(timeout) => {
145+
try_ready!(timeout
146+
.poll()
147+
.map_err(|e| Error::connect(io::Error::new(io::ErrorKind::Other, e))));
148+
io::Error::from(io::ErrorKind::TimedOut)
149+
}
150+
None => return Ok(Async::NotReady),
151+
},
152+
Err(e) => e,
153+
};
154+
155+
let addr = state.addrs.next().ok_or_else(|| Error::connect(error))?;
156+
state.future = TcpStream::connect(&addr);
157+
state.timeout = state
158+
.config
159+
.0
160+
.connect_timeout
161+
.map(|d| Delay::new(Instant::now() + d));
162+
};
163+
164+
stream.set_nodelay(true).map_err(Error::connect)?;
165+
if state.config.0.keepalives {
166+
stream
167+
.set_keepalive(Some(state.config.0.keepalives_idle))
168+
.map_err(Error::connect)?;
169+
}
170+
171+
transition!(Finished(Socket::new_tcp(stream)));
172+
}
173+
}
174+
175+
impl ConnectSocketFuture {
176+
pub fn new(config: Config, idx: usize) -> ConnectSocketFuture {
177+
ConnectSocket::start(config, idx)
178+
}
179+
}

0 commit comments

Comments
 (0)