|
1 | 1 | #![allow(clippy::large_enum_variant)]
|
2 | 2 |
|
3 | 3 | use futures::{try_ready, Async, Future, Poll, Stream};
|
4 |
| -use futures_cpupool::{CpuFuture, CpuPool}; |
5 |
| -use lazy_static::lazy_static; |
6 | 4 | use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
7 | 5 | use std::io;
|
8 |
| -use std::net::{SocketAddr, ToSocketAddrs}; |
9 |
| -use std::time::Instant; |
10 |
| -use std::vec; |
11 |
| -use tokio_tcp::TcpStream; |
12 |
| -use tokio_timer::Delay; |
13 |
| -#[cfg(unix)] |
14 |
| -use tokio_uds::UnixStream; |
15 | 6 |
|
16 |
| -use crate::proto::{Client, Connection, HandshakeFuture, SimpleQueryStream}; |
17 |
| -use crate::{Config, Error, Host, Socket, TargetSessionAttrs, TlsMode}; |
18 |
| - |
19 |
| -lazy_static! { |
20 |
| - static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new() |
21 |
| - .name_prefix("postgres-dns-") |
22 |
| - .pool_size(2) |
23 |
| - .create(); |
24 |
| -} |
| 7 | +use crate::proto::{Client, ConnectSocketFuture, Connection, HandshakeFuture, SimpleQueryStream}; |
| 8 | +use crate::{Config, Error, Socket, TargetSessionAttrs, TlsMode}; |
25 | 9 |
|
26 | 10 | #[derive(StateMachineFuture)]
|
27 | 11 | pub enum ConnectOnce<T>
|
28 | 12 | where
|
29 | 13 | T: TlsMode<Socket>,
|
30 | 14 | {
|
31 |
| - #[state_machine_future(start)] |
32 |
| - #[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))] |
33 |
| - #[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))] |
| 15 | + #[state_machine_future(start, transitions(ConnectingSocket))] |
34 | 16 | Start {
|
35 | 17 | idx: usize,
|
36 | 18 | tls_mode: T,
|
37 | 19 | config: Config,
|
38 | 20 | },
|
39 |
| - #[cfg(unix)] |
40 |
| - #[state_machine_future(transitions(Handshaking))] |
41 |
| - ConnectingUnix { |
42 |
| - future: tokio_uds::ConnectFuture, |
43 |
| - timeout: Option<Delay>, |
44 |
| - tls_mode: T, |
45 |
| - config: Config, |
46 |
| - }, |
47 |
| - #[state_machine_future(transitions(ConnectingTcp))] |
48 |
| - ResolvingDns { |
49 |
| - future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>, |
50 |
| - timeout: Option<Delay>, |
51 |
| - tls_mode: T, |
52 |
| - config: Config, |
53 |
| - }, |
54 | 21 | #[state_machine_future(transitions(Handshaking))]
|
55 |
| - ConnectingTcp { |
56 |
| - future: tokio_tcp::ConnectFuture, |
57 |
| - addrs: vec::IntoIter<SocketAddr>, |
58 |
| - timeout: Option<Delay>, |
| 22 | + ConnectingSocket { |
| 23 | + future: ConnectSocketFuture, |
| 24 | + idx: usize, |
59 | 25 | tls_mode: T,
|
60 | 26 | config: Config,
|
61 | 27 | },
|
@@ -83,142 +49,23 @@ where
|
83 | 49 | fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
84 | 50 | let state = state.take();
|
85 | 51 |
|
86 |
| - let port = *state |
87 |
| - .config |
88 |
| - .0 |
89 |
| - .port |
90 |
| - .get(state.idx) |
91 |
| - .or_else(|| state.config.0.port.get(0)) |
92 |
| - .unwrap_or(&5432); |
93 |
| - |
94 |
| - let timeout = state |
95 |
| - .config |
96 |
| - .0 |
97 |
| - .connect_timeout |
98 |
| - .map(|d| Delay::new(Instant::now() + d)); |
99 |
| - |
100 |
| - match &state.config.0.host[state.idx] { |
101 |
| - Host::Tcp(host) => { |
102 |
| - let host = host.clone(); |
103 |
| - transition!(ResolvingDns { |
104 |
| - future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()), |
105 |
| - timeout, |
106 |
| - tls_mode: state.tls_mode, |
107 |
| - config: state.config, |
108 |
| - }) |
109 |
| - } |
110 |
| - #[cfg(unix)] |
111 |
| - Host::Unix(host) => { |
112 |
| - let path = host.join(format!(".s.PGSQL.{}", port)); |
113 |
| - transition!(ConnectingUnix { |
114 |
| - future: UnixStream::connect(path), |
115 |
| - timeout, |
116 |
| - tls_mode: state.tls_mode, |
117 |
| - config: state.config, |
118 |
| - }) |
119 |
| - } |
120 |
| - } |
121 |
| - } |
122 |
| - |
123 |
| - #[cfg(unix)] |
124 |
| - fn poll_connecting_unix<'a>( |
125 |
| - state: &'a mut RentToOwn<'a, ConnectingUnix<T>>, |
126 |
| - ) -> Poll<AfterConnectingUnix<T>, Error> { |
127 |
| - if let Some(timeout) = &mut state.timeout { |
128 |
| - match timeout.poll() { |
129 |
| - Ok(Async::Ready(())) => { |
130 |
| - return Err(Error::connect(io::Error::from(io::ErrorKind::TimedOut))) |
131 |
| - } |
132 |
| - Ok(Async::NotReady) => {} |
133 |
| - Err(e) => return Err(Error::connect(io::Error::new(io::ErrorKind::Other, e))), |
134 |
| - } |
135 |
| - } |
136 |
| - |
137 |
| - let stream = try_ready!(state.future.poll().map_err(Error::connect)); |
138 |
| - let stream = Socket::new_unix(stream); |
139 |
| - let state = state.take(); |
140 |
| - |
141 |
| - transition!(Handshaking { |
142 |
| - target_session_attrs: state.config.0.target_session_attrs, |
143 |
| - future: HandshakeFuture::new(stream, state.tls_mode, state.config), |
144 |
| - }) |
145 |
| - } |
146 |
| - |
147 |
| - fn poll_resolving_dns<'a>( |
148 |
| - state: &'a mut RentToOwn<'a, ResolvingDns<T>>, |
149 |
| - ) -> Poll<AfterResolvingDns<T>, Error> { |
150 |
| - if let Some(timeout) = &mut state.timeout { |
151 |
| - match timeout.poll() { |
152 |
| - Ok(Async::Ready(())) => { |
153 |
| - return Err(Error::connect(io::Error::from(io::ErrorKind::TimedOut))) |
154 |
| - } |
155 |
| - Ok(Async::NotReady) => {} |
156 |
| - Err(e) => return Err(Error::connect(io::Error::new(io::ErrorKind::Other, e))), |
157 |
| - } |
158 |
| - } |
159 |
| - |
160 |
| - let mut addrs = try_ready!(state.future.poll().map_err(Error::connect)); |
161 |
| - let state = state.take(); |
162 |
| - |
163 |
| - let addr = match addrs.next() { |
164 |
| - Some(addr) => addr, |
165 |
| - None => { |
166 |
| - return Err(Error::connect(io::Error::new( |
167 |
| - io::ErrorKind::InvalidData, |
168 |
| - "resolved 0 addresses", |
169 |
| - ))); |
170 |
| - } |
171 |
| - }; |
172 |
| - |
173 |
| - transition!(ConnectingTcp { |
174 |
| - future: TcpStream::connect(&addr), |
175 |
| - addrs, |
176 |
| - timeout: state.timeout, |
| 52 | + transition!(ConnectingSocket { |
| 53 | + future: ConnectSocketFuture::new(state.config.clone(), state.idx), |
| 54 | + idx: state.idx, |
177 | 55 | tls_mode: state.tls_mode,
|
178 | 56 | config: state.config,
|
179 | 57 | })
|
180 | 58 | }
|
181 | 59 |
|
182 |
| - fn poll_connecting_tcp<'a>( |
183 |
| - state: &'a mut RentToOwn<'a, ConnectingTcp<T>>, |
184 |
| - ) -> Poll<AfterConnectingTcp<T>, Error> { |
185 |
| - if let Some(timeout) = &mut state.timeout { |
186 |
| - match timeout.poll() { |
187 |
| - Ok(Async::Ready(())) => { |
188 |
| - return Err(Error::connect(io::Error::from(io::ErrorKind::TimedOut))) |
189 |
| - } |
190 |
| - Ok(Async::NotReady) => {} |
191 |
| - Err(e) => return Err(Error::connect(io::Error::new(io::ErrorKind::Other, e))), |
192 |
| - } |
193 |
| - } |
194 |
| - |
195 |
| - let stream = loop { |
196 |
| - match state.future.poll() { |
197 |
| - Ok(Async::Ready(stream)) => break stream, |
198 |
| - Ok(Async::NotReady) => return Ok(Async::NotReady), |
199 |
| - Err(e) => { |
200 |
| - let addr = match state.addrs.next() { |
201 |
| - Some(addr) => addr, |
202 |
| - None => return Err(Error::connect(e)), |
203 |
| - }; |
204 |
| - state.future = TcpStream::connect(&addr); |
205 |
| - } |
206 |
| - } |
207 |
| - }; |
| 60 | + fn poll_connecting_socket<'a>( |
| 61 | + state: &'a mut RentToOwn<'a, ConnectingSocket<T>>, |
| 62 | + ) -> Poll<AfterConnectingSocket<T>, Error> { |
| 63 | + let socket = try_ready!(state.future.poll()); |
208 | 64 | let state = state.take();
|
209 | 65 |
|
210 |
| - stream.set_nodelay(true).map_err(Error::connect)?; |
211 |
| - if state.config.0.keepalives { |
212 |
| - stream |
213 |
| - .set_keepalive(Some(state.config.0.keepalives_idle)) |
214 |
| - .map_err(Error::connect)?; |
215 |
| - } |
216 |
| - |
217 |
| - let stream = Socket::new_tcp(stream); |
218 |
| - |
219 | 66 | transition!(Handshaking {
|
220 | 67 | target_session_attrs: state.config.0.target_session_attrs,
|
221 |
| - future: HandshakeFuture::new(stream, state.tls_mode, state.config), |
| 68 | + future: HandshakeFuture::new(socket, state.tls_mode, state.config), |
222 | 69 | })
|
223 | 70 | }
|
224 | 71 |
|
|
0 commit comments