Skip to content

Commit 85acbf2

Browse files
authored
Merge pull request redis-rs#548 from particle-iot/cluster-tls
Adds support for TLS with Cluster mode
2 parents 0ee441c + 75457e0 commit 85acbf2

File tree

5 files changed

+278
-124
lines changed

5 files changed

+278
-124
lines changed

src/cluster.rs

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,23 @@ pub struct ClusterConnection {
7171
password: Option<String>,
7272
read_timeout: RefCell<Option<Duration>>,
7373
write_timeout: RefCell<Option<Duration>>,
74+
tls: Option<TlsMode>,
75+
}
76+
77+
#[derive(Clone, Copy)]
78+
enum TlsMode {
79+
Secure,
80+
Insecure,
81+
}
82+
83+
impl TlsMode {
84+
fn from_insecure_flag(insecure: bool) -> TlsMode {
85+
if insecure {
86+
TlsMode::Insecure
87+
} else {
88+
TlsMode::Secure
89+
}
90+
}
7491
}
7592

7693
impl ClusterConnection {
@@ -83,14 +100,33 @@ impl ClusterConnection {
83100
Self::create_initial_connections(&initial_nodes, readonly, password.clone())?;
84101

85102
let connection = ClusterConnection {
86-
initial_nodes,
87103
connections: RefCell::new(connections),
88104
slots: RefCell::new(SlotMap::new()),
89105
auto_reconnect: RefCell::new(true),
90106
readonly,
91107
password,
92108
read_timeout: RefCell::new(None),
93109
write_timeout: RefCell::new(None),
110+
#[cfg(feature = "tls")]
111+
tls: {
112+
if initial_nodes.is_empty() {
113+
None
114+
} else {
115+
// TODO: Maybe should run through whole list and make sure they're all matching?
116+
match &initial_nodes.get(0).unwrap().addr {
117+
ConnectionAddr::Tcp(_, _) => None,
118+
ConnectionAddr::TcpTls {
119+
host: _,
120+
port: _,
121+
insecure,
122+
} => Some(TlsMode::from_insecure_flag(*insecure)),
123+
_ => None,
124+
}
125+
}
126+
},
127+
#[cfg(not(feature = "tls"))]
128+
tls: None,
129+
initial_nodes,
94130
};
95131
connection.refresh_slots()?;
96132

@@ -166,6 +202,14 @@ impl ClusterConnection {
166202
for info in initial_nodes.iter() {
167203
let addr = match info.addr {
168204
ConnectionAddr::Tcp(ref host, port) => format!("redis://{}:{}", host, port),
205+
ConnectionAddr::TcpTls {
206+
ref host,
207+
port,
208+
insecure,
209+
} => {
210+
let tls_mode = TlsMode::from_insecure_flag(insecure);
211+
build_connection_string(host, Some(port), Some(tls_mode))
212+
}
169213
_ => panic!("No reach."),
170214
};
171215

@@ -180,7 +224,7 @@ impl ClusterConnection {
180224
if connections.is_empty() {
181225
return Err(RedisError::from((
182226
ErrorKind::IoError,
183-
"It is failed to check startup nodes.",
227+
"It failed to check startup nodes.",
184228
)));
185229
}
186230
Ok(connections)
@@ -246,7 +290,7 @@ impl ClusterConnection {
246290
let mut samples = connections.values_mut().choose_multiple(&mut rng, len);
247291

248292
for mut conn in samples.iter_mut() {
249-
if let Ok(mut slots_data) = get_slots(&mut conn) {
293+
if let Ok(mut slots_data) = get_slots(&mut conn, self.tls) {
250294
slots_data.sort_by_key(|s| s.start());
251295
let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| {
252296
if prev_end != slot_data.start() {
@@ -395,15 +439,19 @@ impl ClusterConnection {
395439
let kind = err.kind();
396440

397441
if kind == ErrorKind::Ask {
398-
redirected = err.redirect_node().map(|x| format!("redis://{}", x.0));
442+
redirected = err
443+
.redirect_node()
444+
.map(|(node, _slot)| build_connection_string(node, None, self.tls));
399445
is_asking = true;
400446
} else if kind == ErrorKind::Moved {
401447
// Refresh slots.
402448
self.refresh_slots()?;
403449
excludes.clear();
404450

405451
// Request again.
406-
redirected = err.redirect_node().map(|x| format!("redis://{}", x.0));
452+
redirected = err
453+
.redirect_node()
454+
.map(|(node, _slot)| build_connection_string(node, None, self.tls));
407455
is_asking = false;
408456
continue;
409457
} else if kind == ErrorKind::TryAgain || kind == ErrorKind::ClusterDown {
@@ -448,7 +496,7 @@ impl ClusterConnection {
448496
let mut results = vec![Value::Nil; cmds.len()];
449497

450498
let to_retry = self
451-
.send_all_commands(&cmds)
499+
.send_all_commands(cmds)
452500
.and_then(|node_cmds| self.recv_all_commands(&mut results, &node_cmds))?;
453501

454502
if to_retry.is_empty() {
@@ -505,7 +553,7 @@ impl ClusterConnection {
505553
let mut cmd_map: HashMap<String, NodeCmd> = HashMap::new();
506554

507555
for (idx, cmd) in cmds.iter().enumerate() {
508-
let addr = self.get_addr_for_cmd(&cmd)?;
556+
let addr = self.get_addr_for_cmd(cmd)?;
509557
let nc = cmd_map
510558
.entry(addr.clone())
511559
.or_insert_with(|| NodeCmd::new(addr));
@@ -682,7 +730,7 @@ fn get_random_connection<'a>(
682730
}
683731

684732
// Get slot data from connection.
685-
fn get_slots(connection: &mut Connection) -> RedisResult<Vec<Slot>> {
733+
fn get_slots(connection: &mut Connection, tls_mode: Option<TlsMode>) -> RedisResult<Vec<Slot>> {
686734
let mut cmd = Cmd::new();
687735
cmd.arg("CLUSTER").arg("SLOTS");
688736
let value = connection.req_command(&cmd)?;
@@ -728,11 +776,11 @@ fn get_slots(connection: &mut Connection) -> RedisResult<Vec<Slot>> {
728776
}
729777

730778
let port = if let Value::Int(port) = node[1] {
731-
port
779+
port as u16
732780
} else {
733781
return None;
734782
};
735-
Some(format!("redis://{}:{}", ip, port))
783+
Some(build_connection_string(&ip, Some(port), tls_mode))
736784
} else {
737785
None
738786
}
@@ -750,3 +798,17 @@ fn get_slots(connection: &mut Connection) -> RedisResult<Vec<Slot>> {
750798

751799
Ok(result)
752800
}
801+
802+
fn build_connection_string(host: &str, port: Option<u16>, tls_mode: Option<TlsMode>) -> String {
803+
let host_port = match port {
804+
Some(port) => format!("{}:{}", host, port),
805+
None => host.to_string(),
806+
};
807+
match tls_mode {
808+
None => format!("redis://{}", host_port),
809+
Some(TlsMode::Insecure) => {
810+
format!("rediss://{}/#insecure", host_port)
811+
}
812+
Some(TlsMode::Secure) => format!("rediss://{}", host_port),
813+
}
814+
}

src/cluster_routing.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ impl RoutingInfo {
5050
}
5151

5252
pub fn for_key(key: &[u8]) -> Option<RoutingInfo> {
53-
let key = match get_hashtag(&key) {
53+
let key = match get_hashtag(key) {
5454
Some(tag) => tag,
55-
None => &key,
55+
None => key,
5656
};
5757
Some(RoutingInfo::Slot(
5858
crc16::State::<crc16::XMODEM>::calculate(key) % SLOT_SIZE as u16,

src/connection.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,12 @@ impl ActualConnection {
385385
let tls = match timeout {
386386
None => {
387387
let tcp = TcpStream::connect((host, port))?;
388-
tls_connector.connect(host, tcp).unwrap()
388+
match tls_connector.connect(host, tcp) {
389+
Ok(res) => res,
390+
Err(e) => {
391+
fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
392+
}
393+
}
389394
}
390395
Some(timeout) => {
391396
let mut tcp = None;
@@ -1114,7 +1119,7 @@ mod tests {
11141119
("tcp://127.0.0.1", false),
11151120
];
11161121
for (url, expected) in cases.into_iter() {
1117-
let res = parse_redis_url(&url);
1122+
let res = parse_redis_url(url);
11181123
assert_eq!(
11191124
res.is_some(),
11201125
expected,

tests/support/cluster.rs

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,54 @@
22
#![allow(dead_code)]
33

44
use std::convert::identity;
5-
use std::fs;
5+
use std::env;
66
use std::process;
77
use std::thread::sleep;
88
use std::time::Duration;
99

10-
use std::path::PathBuf;
10+
use tempfile::TempDir;
11+
12+
use crate::support::build_keys_and_certs_for_tls;
1113

1214
use super::RedisServer;
1315

16+
const LOCALHOST: &str = "127.0.0.1";
17+
18+
enum ClusterType {
19+
Tcp,
20+
TcpTls,
21+
}
22+
23+
impl ClusterType {
24+
fn get_intended() -> ClusterType {
25+
match env::var("REDISRS_SERVER_TYPE")
26+
.ok()
27+
.as_ref()
28+
.map(|x| &x[..])
29+
{
30+
Some("tcp") => ClusterType::Tcp,
31+
Some("tcp+tls") => ClusterType::TcpTls,
32+
val => {
33+
panic!("Unknown server type {:?}", val);
34+
}
35+
}
36+
}
37+
38+
fn build_addr(port: u16) -> redis::ConnectionAddr {
39+
match ClusterType::get_intended() {
40+
ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
41+
ClusterType::TcpTls => redis::ConnectionAddr::TcpTls {
42+
host: "127.0.0.1".into(),
43+
port,
44+
insecure: true,
45+
},
46+
}
47+
}
48+
}
49+
1450
pub struct RedisCluster {
1551
pub servers: Vec<RedisServer>,
16-
pub folders: Vec<PathBuf>,
52+
pub folders: Vec<TempDir>,
1753
}
1854

1955
impl RedisCluster {
@@ -22,25 +58,48 @@ impl RedisCluster {
2258
let mut folders = vec![];
2359
let mut addrs = vec![];
2460
let start_port = 7000;
61+
let mut tls_paths = None;
62+
let mut is_tls = false;
63+
64+
if let ClusterType::TcpTls = ClusterType::get_intended() {
65+
// Create a shared set of keys in cluster mode
66+
let tempdir = tempfile::Builder::new()
67+
.prefix("redis")
68+
.tempdir()
69+
.expect("failed to create tempdir");
70+
let files = build_keys_and_certs_for_tls(&tempdir);
71+
folders.push(tempdir);
72+
tls_paths = Some(files);
73+
is_tls = true;
74+
}
75+
2576
for node in 0..nodes {
2677
let port = start_port + node;
2778

2879
servers.push(RedisServer::new_with_addr(
29-
redis::ConnectionAddr::Tcp("127.0.0.1".into(), port),
80+
ClusterType::build_addr(port),
81+
tls_paths.clone(),
3082
|cmd| {
31-
let (a, b) = rand::random::<(u64, u64)>();
32-
let path = PathBuf::from(format!("/tmp/redis-rs-cluster-test-{}-{}-dir", a, b));
33-
fs::create_dir_all(&path).unwrap();
83+
let tempdir = tempfile::Builder::new()
84+
.prefix("redis")
85+
.tempdir()
86+
.expect("failed to create tempdir");
3487
cmd.arg("--cluster-enabled")
3588
.arg("yes")
3689
.arg("--cluster-config-file")
37-
.arg(&path.join("nodes.conf"))
90+
.arg(&tempdir.path().join("nodes.conf"))
3891
.arg("--cluster-node-timeout")
3992
.arg("5000")
4093
.arg("--appendonly")
4194
.arg("yes");
42-
cmd.current_dir(&path);
43-
folders.push(path);
95+
if is_tls {
96+
cmd.arg("--tls-cluster").arg("yes");
97+
if replicas > 0 {
98+
cmd.arg("--tls-replication").arg("yes");
99+
}
100+
}
101+
cmd.current_dir(&tempdir.path());
102+
folders.push(tempdir);
44103
addrs.push(format!("127.0.0.1:{}", port));
45104
dbg!(&cmd);
46105
cmd.spawn().unwrap()
@@ -59,6 +118,9 @@ impl RedisCluster {
59118
cmd.arg("--cluster-replicas").arg(replicas.to_string());
60119
}
61120
cmd.arg("--cluster-yes");
121+
if is_tls {
122+
cmd.arg("--tls").arg("--insecure");
123+
}
62124
let status = dbg!(cmd).status().unwrap();
63125
assert!(status.success());
64126

@@ -71,9 +133,15 @@ impl RedisCluster {
71133

72134
fn wait_for_replicas(&self, replicas: u16) {
73135
'server: for server in &self.servers {
74-
let addr = format!("redis://{}/", server.get_client_addr());
75-
eprintln!("waiting until {} knows required number of replicas", addr);
76-
let client = redis::Client::open(addr).unwrap();
136+
let conn_info = redis::ConnectionInfo {
137+
addr: server.get_client_addr().clone(),
138+
redis: Default::default(),
139+
};
140+
eprintln!(
141+
"waiting until {:?} knows required number of replicas",
142+
conn_info.addr
143+
);
144+
let client = redis::Client::open(conn_info).unwrap();
77145
let mut con = client.get_connection().unwrap();
78146

79147
// retry 100 times
@@ -98,9 +166,6 @@ impl RedisCluster {
98166
for server in &mut self.servers {
99167
server.stop();
100168
}
101-
for folder in &self.folders {
102-
fs::remove_dir_all(&folder).unwrap();
103-
}
104169
}
105170

106171
pub fn iter_servers(&self) -> impl Iterator<Item = &RedisServer> {
@@ -136,7 +201,10 @@ impl TestClusterContext {
136201
let mut builder = redis::cluster::ClusterClientBuilder::new(
137202
cluster
138203
.iter_servers()
139-
.map(|x| format!("redis://{}/", x.get_client_addr()))
204+
.map(|server| redis::ConnectionInfo {
205+
addr: server.get_client_addr().clone(),
206+
redis: Default::default(),
207+
})
140208
.collect(),
141209
);
142210
builder = initializer(builder);

0 commit comments

Comments
 (0)