Skip to content

Allow users to request TLS client-side enforcement #525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/cassandra.h
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,12 @@ typedef enum CassSslVerifyFlags_ {
CASS_SSL_VERIFY_PEER_IDENTITY_DNS = 0x04
} CassSslVerifyFlags;

typedef enum CassSslTlsVersion_ {
CASS_SSL_VERSION_TLS1 = 0x00,
CASS_SSL_VERSION_TLS1_1 = 0x01,
CASS_SSL_VERSION_TLS1_2 = 0x02
} CassSslTlsVersion;

typedef enum CassProtocolVersion_ {
CASS_PROTOCOL_VERSION_V1 = 0x01, /**< Deprecated */
CASS_PROTOCOL_VERSION_V2 = 0x02, /**< Deprecated */
Expand Down Expand Up @@ -4687,6 +4693,20 @@ cass_ssl_set_private_key_n(CassSsl* ssl,
const char* password,
size_t password_length);

/**
* Set minimum supported client-side protocol version. This will prevent the
* connection using protocol versions earlier than the specified one. Useful
* for preventing TLS downgrade attacks.
*
* @public @memberof CassSsl
*
* @param[in] ssl
* @param[in] min_version
* @return CASS_OK if successful, otherwise an error occurred.
*/
CASS_EXPORT CassError
cass_ssl_set_min_protocol_version(CassSsl* ssl, CassSslTlsVersion min_version);

/***********************************************************************************
*
* Authenticator
Expand Down
4 changes: 4 additions & 0 deletions src/ssl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ CassError cass_ssl_set_private_key_n(CassSsl* ssl, const char* key, size_t key_l
return ssl->set_private_key(key, key_length, password, password_length);
}

CassError cass_ssl_set_min_protocol_version(CassSsl* ssl, CassSslTlsVersion min_version) {
return ssl->set_min_protocol_version(min_version);
}

} // extern "C"

template <class T>
Expand Down
1 change: 1 addition & 0 deletions src/ssl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class SslContext : public RefCounted<SslContext> {
virtual CassError set_cert(const char* cert, size_t cert_length) = 0;
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
size_t password_length) = 0;
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version) = 0;

protected:
int verify_flags_;
Expand Down
4 changes: 4 additions & 0 deletions src/ssl/ssl_no_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,8 @@ CassError NoSslContext::set_private_key(const char* key, size_t key_length, cons
return CASS_ERROR_LIB_NOT_IMPLEMENTED;
}

CassError NoSslContext::set_min_protocol_version(CassSslTlsVersion min_version) {
return CASS_ERROR_LIB_NOT_IMPLEMENTED;
}

SslContext::Ptr NoSslContextFactory::create() { return SslContext::Ptr(new NoSslContext()); }
1 change: 1 addition & 0 deletions src/ssl/ssl_no_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class NoSslContext : public SslContext {
virtual CassError set_cert(const char* cert, size_t cert_length);
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
size_t password_length);
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version);
};

class NoSslContextFactory : public SslContextFactoryBase<NoSslContextFactory> {
Expand Down
44 changes: 44 additions & 0 deletions src/ssl/ssl_openssl_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@
!defined(LIBRESSL_VERSION_NUMBER) // Required as OPENSSL_VERSION_NUMBER for LibreSSL is defined
// as 2.0.0
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
#define SSL_CAN_SET_MIN_VERSION
#define SSL_CLIENT_METHOD TLS_client_method
#else
#define SSL_CLIENT_METHOD SSLv23_client_method
#endif
#else
#if (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
#define SSL_CAN_SET_MIN_VERSION
#define SSL_CLIENT_METHOD TLS_client_method
#else
#define SSL_CLIENT_METHOD SSLv23_client_method
Expand Down Expand Up @@ -611,6 +613,48 @@ CassError OpenSslContext::set_private_key(const char* key, size_t key_length, co
return CASS_OK;
}

CassError OpenSslContext::set_min_protocol_version(CassSslTlsVersion min_version) {
#ifdef SSL_CAN_SET_MIN_VERSION
int method;
switch (min_version) {
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1:
method = TLS1_VERSION;
break;
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_1:
method = TLS1_1_VERSION;
break;
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_2:
method = TLS1_2_VERSION;
break;
default:
// unsupported version
return CASS_ERROR_LIB_BAD_PARAMS;
}
SSL_CTX_set_min_proto_version(ssl_ctx_, method);
return CASS_OK;
#else
// If we don't have the `set_min_proto_version` function then we do this via
// the (deprecated in later versions) options function.
int options = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
switch (min_version) {
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1:
break;
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_1:
options |= SSL_OP_NO_TLSv1;
break;
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_2:
options |= SSL_OP_NO_TLSv1;
options |= SSL_OP_NO_TLSv1_1;
break;
default:
// unsupported version
return CASS_ERROR_LIB_BAD_PARAMS;
}
SSL_CTX_set_options(ssl_ctx_, options);
return CASS_OK;
#endif
}

SslContext::Ptr OpenSslContextFactory::create() { return SslContext::Ptr(new OpenSslContext()); }

namespace openssl {
Expand Down
1 change: 1 addition & 0 deletions src/ssl/ssl_openssl_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class OpenSslContext : public SslContext {
virtual CassError set_cert(const char* cert, size_t cert_length);
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
size_t password_length);
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version);

private:
SSL_CTX* ssl_ctx_;
Expand Down
17 changes: 17 additions & 0 deletions tests/src/unit/mockssandra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ using datastax::internal::core::UuidGen;
!defined(LIBRESSL_VERSION_NUMBER) // Required as OPENSSL_VERSION_NUMBER for LibreSSL is defined
// as 2.0.0
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
#define SSL_CAN_SET_MAX_VERSION
#define SSL_SERVER_METHOD TLS_server_method
#else
#define SSL_SERVER_METHOD SSLv23_server_method
#endif
#else
#if (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
#define SSL_CAN_SET_MAX_VERSION
#define SSL_SERVER_METHOD TLS_server_method
#else
#define SSL_SERVER_METHOD SSLv23_server_method
Expand Down Expand Up @@ -555,6 +557,21 @@ bool ServerConnection::use_ssl(const String& key, const String& cert,
return true;
}

// Weaken the SSL connection, enforcing that it can only use TLS1.0 at max.
// This is used for testing client-side enforcement of more secure TLS
// protocols.
void ServerConnection::weaken_ssl() {
if (!ssl_context_) {
return;
}

#ifdef SSL_CAN_SET_MAX_VERSION
SSL_CTX_set_max_proto_version(ssl_context_, TLS1_VERSION);
#else
SSL_CTX_set_options(ssl_context_, SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1_2);
#endif
}

using datastax::internal::core::Task;

class RunListen : public Task {
Expand Down
13 changes: 12 additions & 1 deletion tests/src/unit/mockssandra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class ServerConnection : public RefCounted<ServerConnection> {

bool use_ssl(const String& key, const String& cert, const String& ca_cert = "",
bool require_client_cert = false);
void weaken_ssl();

void listen(EventLoopGroup* event_loop_group);
int wait_listen();
Expand Down Expand Up @@ -1161,6 +1162,7 @@ class Cluster {
~Cluster();

String use_ssl(const String& cn = "");
void weaken_ssl();

int start_all(EventLoopGroup* event_loop_group);
void start_all_async(EventLoopGroup* event_loop_group);
Expand Down Expand Up @@ -1264,7 +1266,8 @@ class SimpleEchoServer {
public:
SimpleEchoServer()
: factory_(new EchoClientConnectionFactory())
, event_loop_group_(1) {}
, event_loop_group_(1)
, ssl_weaken_(false) {}

~SimpleEchoServer() { close(); }

Expand All @@ -1281,6 +1284,8 @@ class SimpleEchoServer {
return ssl_cert_;
}

void weaken_ssl() { ssl_weaken_ = true; }

void use_connection_factory(internal::ClientConnectionFactory* factory) {
factory_.reset(factory);
}
Expand All @@ -1290,6 +1295,11 @@ class SimpleEchoServer {
if (!ssl_key_.empty() && !ssl_cert_.empty() && !server_->use_ssl(ssl_key_, ssl_cert_)) {
return -1;
}

if (ssl_weaken_) {
server_->weaken_ssl();
}

server_->listen(&event_loop_group_);
return server_->wait_listen();
}
Expand All @@ -1316,6 +1326,7 @@ class SimpleEchoServer {
internal::ServerConnection::Ptr server_;
String ssl_key_;
String ssl_cert_;
bool ssl_weaken_;
};

} // namespace mockssandra
Expand Down
32 changes: 32 additions & 0 deletions tests/src/unit/tests/test_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class SocketUnitTest : public LoopTest {
return settings;
}

void weaken_ssl() { server_.weaken_ssl(); }

void listen(const Address& address = Address("127.0.0.1", 8888)) {
ASSERT_EQ(server_.listen(address), 0);
}
Expand Down Expand Up @@ -185,6 +187,17 @@ class SocketUnitTest : public LoopTest {
}
}

/* SSL handshake failures have different error codes on different versions of
* OpenSSL - this accounts for both of them
*/
static void on_socket_ssl_error(SocketConnector* connector, bool* is_error) {
SocketConnector::SocketError err = connector->error_code();
if ((err == SocketConnector::SOCKET_ERROR_CLOSE) ||
(err == SocketConnector::SOCKET_ERROR_SSL_HANDSHAKE)) {
*is_error = true;
}
}

static void on_socket_canceled(SocketConnector* connector, bool* is_canceled) {
if (connector->is_canceled()) {
*is_canceled = true;
Expand Down Expand Up @@ -409,3 +422,22 @@ TEST_F(SocketUnitTest, SslVerifyIdentityDns) {

EXPECT_EQ(result, "The socket is successfully connected and wrote data - Closed");
}

TEST_F(SocketUnitTest, SslEnforceTlsVersion) {
SocketSettings settings(use_ssl("127.0.0.1"));
weaken_ssl();

listen();

settings.ssl_context->set_min_protocol_version(CASS_SSL_VERSION_TLS1_2);

bool is_error;
SocketConnector::Ptr connector(new SocketConnector(
Address("127.0.0.1", 8888), bind_callback(on_socket_ssl_error, &is_error)));

connector->with_settings(settings)->connect(loop());

uv_run(loop(), UV_RUN_DEFAULT);

EXPECT_TRUE(is_error);
}