@@ -23,6 +23,12 @@ use tokio::{
23
23
#[ cfg( feature = "tokio-comp" ) ]
24
24
use tokio:: net:: TcpStream as TcpStreamTokio ;
25
25
26
+ #[ cfg( feature = "tokio-tls-comp" ) ]
27
+ use tokio_tls:: TlsStream as TlsStreamTokio ;
28
+
29
+ #[ cfg( feature = "tls" ) ]
30
+ use native_tls:: TlsConnector ;
31
+
26
32
#[ cfg( any( feature = "tokio-comp" , feature = "async-std-comp" ) ) ]
27
33
use tokio_util:: codec:: Decoder ;
28
34
@@ -52,7 +58,16 @@ use crate::aio_async_std;
52
58
pub ( crate ) trait Connect {
53
59
/// Performs a TCP connection
54
60
async fn connect_tcp ( socket_addr : SocketAddr ) -> RedisResult < ActualConnection > ;
55
- /// Performans an UNIX connection
61
+
62
+ // Performs a TCP TLS connection
63
+ #[ cfg( feature = "tls" ) ]
64
+ async fn connect_tcp_tls (
65
+ hostname : & str ,
66
+ socket_addr : SocketAddr ,
67
+ insecure : bool ,
68
+ ) -> RedisResult < ActualConnection > ;
69
+
70
+ /// Performs a UNIX connection
56
71
#[ cfg( unix) ]
57
72
async fn connect_unix ( path : & Path ) -> RedisResult < ActualConnection > ;
58
73
}
@@ -61,6 +76,9 @@ pub(crate) trait Connect {
61
76
mod tokio_aio {
62
77
use super :: { async_trait, ActualConnection , Connect , RedisResult , SocketAddr , TcpStreamTokio } ;
63
78
79
+ #[ cfg( feature = "tls" ) ]
80
+ use super :: TlsConnector ;
81
+
64
82
#[ cfg( unix) ]
65
83
use super :: { Path , UnixStreamTokio } ;
66
84
@@ -73,6 +91,27 @@ mod tokio_aio {
73
91
. await
74
92
. map ( ActualConnection :: TcpTokio ) ?)
75
93
}
94
+ #[ cfg( feature = "tls" ) ]
95
+ async fn connect_tcp_tls (
96
+ hostname : & str ,
97
+ socket_addr : SocketAddr ,
98
+ insecure : bool ,
99
+ ) -> RedisResult < ActualConnection > {
100
+ let tls_connector: tokio_tls:: TlsConnector = if insecure {
101
+ TlsConnector :: builder ( )
102
+ . danger_accept_invalid_certs ( true )
103
+ . danger_accept_invalid_hostnames ( true )
104
+ . use_sni ( false )
105
+ . build ( ) ?
106
+ } else {
107
+ TlsConnector :: new ( ) ?
108
+ }
109
+ . into ( ) ;
110
+ Ok ( tls_connector
111
+ . connect ( hostname, TcpStreamTokio :: connect ( & socket_addr) . await ?)
112
+ . await
113
+ . map ( ActualConnection :: TcpTlsTokio ) ?)
114
+ }
76
115
#[ cfg( unix) ]
77
116
async fn connect_unix ( path : & Path ) -> RedisResult < ActualConnection > {
78
117
Ok ( UnixStreamTokio :: connect ( path)
@@ -87,13 +126,19 @@ pub(crate) enum ActualConnection {
87
126
/// Represents a Tokio TCP connection.
88
127
#[ cfg( feature = "tokio-comp" ) ]
89
128
TcpTokio ( TcpStreamTokio ) ,
129
+ /// Represents a Tokio TLS encrypted TCP connection
130
+ #[ cfg( feature = "tokio-tls-comp" ) ]
131
+ TcpTlsTokio ( TlsStreamTokio < TcpStreamTokio > ) ,
90
132
/// Represents a Tokio Unix connection.
91
133
#[ cfg( unix) ]
92
134
#[ cfg( feature = "tokio-comp" ) ]
93
135
UnixTokio ( UnixStreamTokio ) ,
94
136
/// Represents an Async_std TCP connection.
95
137
#[ cfg( feature = "async-std-comp" ) ]
96
138
TcpAsyncStd ( aio_async_std:: TcpStreamAsyncStdWrapped ) ,
139
+ /// Represents an Async_std TLS encrypted TCP connection.
140
+ #[ cfg( feature = "async-std-tls-comp" ) ]
141
+ TcpTlsAsyncStd ( aio_async_std:: TlsStreamAsyncStdWrapped ) ,
97
142
/// Represents an Async_std Unix connection.
98
143
#[ cfg( feature = "async-std-comp" ) ]
99
144
#[ cfg( unix) ]
@@ -109,11 +154,15 @@ impl AsyncWrite for ActualConnection {
109
154
match & mut * self {
110
155
#[ cfg( feature = "tokio-comp" ) ]
111
156
ActualConnection :: TcpTokio ( r) => Pin :: new ( r) . poll_write ( cx, buf) ,
157
+ #[ cfg( feature = "tokio-tls-comp" ) ]
158
+ ActualConnection :: TcpTlsTokio ( r) => Pin :: new ( r) . poll_write ( cx, buf) ,
112
159
#[ cfg( unix) ]
113
160
#[ cfg( feature = "tokio-comp" ) ]
114
161
ActualConnection :: UnixTokio ( r) => Pin :: new ( r) . poll_write ( cx, buf) ,
115
162
#[ cfg( feature = "async-std-comp" ) ]
116
163
ActualConnection :: TcpAsyncStd ( r) => Pin :: new ( r) . poll_write ( cx, buf) ,
164
+ #[ cfg( feature = "async-std-tls-comp" ) ]
165
+ ActualConnection :: TcpTlsAsyncStd ( r) => Pin :: new ( r) . poll_write ( cx, buf) ,
117
166
#[ cfg( feature = "async-std-comp" ) ]
118
167
#[ cfg( unix) ]
119
168
ActualConnection :: UnixAsyncStd ( r) => Pin :: new ( r) . poll_write ( cx, buf) ,
@@ -124,11 +173,15 @@ impl AsyncWrite for ActualConnection {
124
173
match & mut * self {
125
174
#[ cfg( feature = "tokio-comp" ) ]
126
175
ActualConnection :: TcpTokio ( r) => Pin :: new ( r) . poll_flush ( cx) ,
176
+ #[ cfg( feature = "tokio-tls-comp" ) ]
177
+ ActualConnection :: TcpTlsTokio ( r) => Pin :: new ( r) . poll_flush ( cx) ,
127
178
#[ cfg( unix) ]
128
179
#[ cfg( feature = "tokio-comp" ) ]
129
180
ActualConnection :: UnixTokio ( r) => Pin :: new ( r) . poll_flush ( cx) ,
130
181
#[ cfg( feature = "async-std-comp" ) ]
131
182
ActualConnection :: TcpAsyncStd ( r) => Pin :: new ( r) . poll_flush ( cx) ,
183
+ #[ cfg( feature = "async-std-tls-comp" ) ]
184
+ ActualConnection :: TcpTlsAsyncStd ( r) => Pin :: new ( r) . poll_flush ( cx) ,
132
185
#[ cfg( feature = "async-std-comp" ) ]
133
186
#[ cfg( unix) ]
134
187
ActualConnection :: UnixAsyncStd ( r) => Pin :: new ( r) . poll_flush ( cx) ,
@@ -139,11 +192,15 @@ impl AsyncWrite for ActualConnection {
139
192
match & mut * self {
140
193
#[ cfg( feature = "tokio-comp" ) ]
141
194
ActualConnection :: TcpTokio ( r) => Pin :: new ( r) . poll_shutdown ( cx) ,
195
+ #[ cfg( feature = "tokio-tls-comp" ) ]
196
+ ActualConnection :: TcpTlsTokio ( r) => Pin :: new ( r) . poll_shutdown ( cx) ,
142
197
#[ cfg( unix) ]
143
198
#[ cfg( feature = "tokio-comp" ) ]
144
199
ActualConnection :: UnixTokio ( r) => Pin :: new ( r) . poll_shutdown ( cx) ,
145
200
#[ cfg( feature = "async-std-comp" ) ]
146
201
ActualConnection :: TcpAsyncStd ( r) => Pin :: new ( r) . poll_shutdown ( cx) ,
202
+ #[ cfg( feature = "async-std-tls-comp" ) ]
203
+ ActualConnection :: TcpTlsAsyncStd ( r) => Pin :: new ( r) . poll_shutdown ( cx) ,
147
204
#[ cfg( feature = "async-std-comp" ) ]
148
205
#[ cfg( unix) ]
149
206
ActualConnection :: UnixAsyncStd ( r) => Pin :: new ( r) . poll_shutdown ( cx) ,
@@ -160,11 +217,15 @@ impl AsyncRead for ActualConnection {
160
217
match & mut * self {
161
218
#[ cfg( feature = "tokio-comp" ) ]
162
219
ActualConnection :: TcpTokio ( r) => Pin :: new ( r) . poll_read ( cx, buf) ,
220
+ #[ cfg( feature = "tokio-tls-comp" ) ]
221
+ ActualConnection :: TcpTlsTokio ( r) => Pin :: new ( r) . poll_read ( cx, buf) ,
163
222
#[ cfg( unix) ]
164
223
#[ cfg( feature = "tokio-comp" ) ]
165
224
ActualConnection :: UnixTokio ( r) => Pin :: new ( r) . poll_read ( cx, buf) ,
166
225
#[ cfg( feature = "async-std-comp" ) ]
167
226
ActualConnection :: TcpAsyncStd ( r) => Pin :: new ( r) . poll_read ( cx, buf) ,
227
+ #[ cfg( feature = "async-std-tls-comp" ) ]
228
+ ActualConnection :: TcpTlsAsyncStd ( r) => Pin :: new ( r) . poll_read ( cx, buf) ,
168
229
#[ cfg( feature = "async-std-comp" ) ]
169
230
#[ cfg( unix) ]
170
231
ActualConnection :: UnixAsyncStd ( r) => Pin :: new ( r) . poll_read ( cx, buf) ,
@@ -414,10 +475,27 @@ async fn connect_simple<T: Connect>(
414
475
Ok ( match * connection_info. addr {
415
476
ConnectionAddr :: Tcp ( ref host, port) => {
416
477
let socket_addr = get_socket_addrs ( host, port) ?;
417
-
418
478
<T >:: connect_tcp ( socket_addr) . await ?
419
479
}
420
480
481
+ #[ cfg( feature = "tls" ) ]
482
+ ConnectionAddr :: TcpTls {
483
+ ref host,
484
+ port,
485
+ insecure,
486
+ } => {
487
+ let socket_addr = get_socket_addrs ( host, port) ?;
488
+ <T >:: connect_tcp_tls ( host, socket_addr, insecure) . await ?
489
+ }
490
+
491
+ #[ cfg( not( feature = "tls" ) ) ]
492
+ ConnectionAddr :: TcpTls { .. } => {
493
+ fail ! ( (
494
+ ErrorKind :: InvalidClientConfig ,
495
+ "Cannot connect to TCP with TLS without the tls feature"
496
+ ) ) ;
497
+ }
498
+
421
499
#[ cfg( unix) ]
422
500
ConnectionAddr :: Unix ( ref path) => <T >:: connect_unix ( path) . await ?,
423
501
@@ -796,12 +874,24 @@ impl MultiplexedConnection {
796
874
let ( pipeline, driver) = Pipeline :: new ( codec) ;
797
875
( pipeline, boxed ( driver) )
798
876
}
877
+ #[ cfg( feature = "tokio-tls-comp" ) ]
878
+ ActualConnection :: TcpTlsTokio ( tls) => {
879
+ let codec = ValueCodec :: default ( ) . framed ( tls) ;
880
+ let ( pipeline, driver) = Pipeline :: new ( codec) ;
881
+ ( pipeline, boxed ( driver) )
882
+ }
799
883
#[ cfg( feature = "async-std-comp" ) ]
800
884
ActualConnection :: TcpAsyncStd ( tcp) => {
801
885
let codec = ValueCodec :: default ( ) . framed ( tcp) ;
802
886
let ( pipeline, driver) = Pipeline :: new ( codec) ;
803
887
( pipeline, boxed ( driver) )
804
888
}
889
+ #[ cfg( feature = "async-std-tls-comp" ) ]
890
+ ActualConnection :: TcpTlsAsyncStd ( tcp) => {
891
+ let codec = ValueCodec :: default ( ) . framed ( tcp) ;
892
+ let ( pipeline, driver) = Pipeline :: new ( codec) ;
893
+ ( pipeline, boxed ( driver) )
894
+ }
805
895
#[ cfg( unix) ]
806
896
#[ cfg( feature = "tokio-comp" ) ]
807
897
ActualConnection :: UnixTokio ( unix) => {
0 commit comments