Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion iocore/net/P_SSLNetVConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ class SSLNetVConnection : public UnixNetVConnection,
ink_hrtime sslLastWriteTime = 0;
int64_t sslTotalBytesSent = 0;

SSL_SESSION *client_sess = nullptr;
std::shared_ptr<SSL_SESSION> client_sess = nullptr;

// The serverName is either a pointer to the (null-terminated) name fetched from the
// SSL object or the empty string.
Expand Down
28 changes: 11 additions & 17 deletions iocore/net/SSLNetVConnection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -953,10 +953,9 @@ SSLNetVConnection::clear()
// operation, e.g. by using d2i_SSL_SESSION(3). It must not be called
// on other SSL_SESSION objects, as this would cause incorrect
// reference counts and therefore program failures.
if (client_sess != nullptr) {
SSL_SESSION_free(client_sess);
client_sess = nullptr;
}
// Since we created the shared pointer with a custom deleter,
// resetting here will decrement the ref-counter.
client_sess.reset();

if (ssl != nullptr) {
SSL_free(ssl);
Expand Down Expand Up @@ -2089,31 +2088,26 @@ SSLNetVConnection::_ssl_connect()

Debug("ssl.origin_session_cache", "origin session cache lookup key = %s", lookup_key.c_str());

sess = this->getOriginSession(ssl, lookup_key);
if (sess) {
if (SSL_set_session(ssl, sess) == 0) {
SSL_SESSION_free(sess);
} else {
if (this->client_sess) {
SSL_SESSION_free(this->client_sess);
}
this->client_sess = sess;
}
std::shared_ptr<SSL_SESSION> shared_sess = this->getOriginSession(ssl, lookup_key);

if (shared_sess && SSL_set_session(ssl, shared_sess.get())) {
// Keep a reference of this shared pointer in the connection
this->client_sess = shared_sess;
}
}
}

int ret = SSL_connect(ssl);

if (ret > 0) {
if (sess && SSL_session_reused(ssl)) {
if (SSL_session_reused(ssl)) {
SSL_INCREMENT_DYN_STAT(ssl_origin_session_reused_count);
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "reused session to origin server = %p", sess);
Debug("ssl.origin_session_cache", "reused session to origin server");
}
} else {
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "new session to origin server = %p", sess);
Debug("ssl.origin_session_cache", "new session to origin server");
}
}
return SSL_ERROR_NONE;
Expand Down
56 changes: 34 additions & 22 deletions iocore/net/SSLSessionCache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,13 @@ SSLSessionBucket::removeSession(const SSLSessionID &id)
return;
}

// Custom deleter for shared origin sessions
void
SSLSessDeleter(SSL_SESSION *_p)
{
SSL_SESSION_free(_p);
}

/* Session Bucket */
SSLSessionBucket::SSLSessionBucket() {}

Expand All @@ -322,10 +329,6 @@ SSLOriginSessionCache::~SSLOriginSessionCache() {}
void
SSLOriginSessionCache::insert_session(const std::string &lookup_key, SSL_SESSION *sess, SSL *ssl)
{
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "insert session: %s = %p", lookup_key.c_str(), sess);
}

size_t len = i2d_SSL_SESSION(sess, nullptr); // make sure we're not going to need more than SSL_MAX_ORIG_SESSION_SIZE bytes

/* do not cache a session that's too big. */
Expand All @@ -338,32 +341,43 @@ SSLOriginSessionCache::insert_session(const std::string &lookup_key, SSL_SESSION
return;
}

Ptr<IOBufferData> buf;
buf = new_IOBufferData(buffer_size_to_index(len, MAX_BUFFER_SIZE_INDEX), MEMALIGNED);
ink_release_assert(static_cast<size_t>(buf->block_size()) >= len);
unsigned char *loc = reinterpret_cast<unsigned char *>(buf->data());
i2d_SSL_SESSION(sess, &loc);
// Duplicate the session from the connection, we'll be keeping track the ref-count with a shared pointer ourself
SSL_SESSION *sess_ptr = SSL_SESSION_dup(sess);

if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "insert session: %s = %p", lookup_key.c_str(), sess_ptr);
}

// Create the shared pointer to the session, with the custom deleter
std::shared_ptr<SSL_SESSION> shared_sess(sess_ptr, SSLSessDeleter);
ssl_curve_id curve = (ssl == nullptr) ? 0 : SSLGetCurveNID(ssl);
ats_scoped_obj<SSLOriginSession> ssl_orig_session(new SSLOriginSession(lookup_key, buf, len, curve));
ats_scoped_obj<SSLOriginSession> ssl_orig_session(new SSLOriginSession(lookup_key, curve, shared_sess));
auto new_node = ssl_orig_session.release();

std::unique_lock lock(mutex);
auto entry = orig_sess_map.find(lookup_key);
if (entry != orig_sess_map.end()) {
auto node = entry->second;
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "found duplicate key: %s, replacing %p with %p", lookup_key.c_str(),
node->shared_sess.get(), sess_ptr);
}
orig_sess_que.remove(node);
orig_sess_map.erase(entry);
delete node;
} else if (orig_sess_map.size() >= SSLConfigParams::origin_session_cache_size) {
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "origin session cache full, removing oldest session");
}
remove_oldest_session(lock);
}

orig_sess_que.enqueue(new_node);
orig_sess_map[lookup_key] = new_node;
}

bool
SSLOriginSessionCache::get_session(const std::string &lookup_key, SSL_SESSION **sess, ssl_curve_id *curve)
std::shared_ptr<SSL_SESSION>
SSLOriginSessionCache::get_session(const std::string &lookup_key, ssl_curve_id *curve)
{
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "get session: %s", lookup_key.c_str());
Expand All @@ -372,27 +386,26 @@ SSLOriginSessionCache::get_session(const std::string &lookup_key, SSL_SESSION **
std::shared_lock lock(mutex);
auto entry = orig_sess_map.find(lookup_key);
if (entry == orig_sess_map.end()) {
return false;
return nullptr;
}

const unsigned char *loc = reinterpret_cast<const unsigned char *>(entry->second->asn1_data->data());
*sess = d2i_SSL_SESSION(nullptr, &loc, entry->second->len_asn1_data);
if (curve != nullptr) {
*curve = entry->second->curve_id;
}
return true;

return entry->second->shared_sess;
}

void
SSLOriginSessionCache::remove_oldest_session(const std::unique_lock<std::shared_mutex> &lock)
{
// Caller must hold the bucket shared_mutex with unique_lock.
ink_assert(lock.owns_lock());
ink_release_assert(lock.owns_lock());

while (orig_sess_que.head && orig_sess_que.size >= static_cast<int>(SSLConfigParams::origin_session_cache_size)) {
auto node = orig_sess_que.pop();
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "remove oldest session: %s", node->key.c_str());
Debug("ssl.origin_session_cache", "remove oldest session: %s, session ptr: %p", node->key.c_str(), node->shared_sess.get());
}
orig_sess_map.erase(node->key);
delete node;
Expand All @@ -403,14 +416,13 @@ void
SSLOriginSessionCache::remove_session(const std::string &lookup_key)
{
// We can't bail on contention here because this session MUST be removed.
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "remove session: %s", lookup_key.c_str());
}

std::unique_lock lock(mutex);
auto entry = orig_sess_map.find(lookup_key);
if (entry != orig_sess_map.end()) {
auto node = entry->second;
if (is_debug_tag_set("ssl.origin_session_cache")) {
Debug("ssl.origin_session_cache", "remove session: %s, session ptr: %p", lookup_key.c_str(), node->shared_sess.get());
}
orig_sess_que.remove(node);
orig_sess_map.erase(entry);
delete node;
Expand Down
9 changes: 4 additions & 5 deletions iocore/net/SSLSessionCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,11 @@ class SSLOriginSession
{
public:
std::string key;
Ptr<IOBufferData> asn1_data; /* this is the ASN1 representation of the SSL_CTX */
size_t len_asn1_data;
ssl_curve_id curve_id;
std::shared_ptr<SSL_SESSION> shared_sess = nullptr;

SSLOriginSession(const std::string &lookup_key, const Ptr<IOBufferData> &asn1, size_t len_asn1, ssl_curve_id curve)
: key(lookup_key), asn1_data(asn1), len_asn1_data(len_asn1), curve_id(curve)
SSLOriginSession(const std::string &lookup_key, ssl_curve_id curve, std::shared_ptr<SSL_SESSION> session)
: key(lookup_key), curve_id(curve), shared_sess(session)
{
}

Expand All @@ -207,7 +206,7 @@ class SSLOriginSessionCache
~SSLOriginSessionCache();

void insert_session(const std::string &lookup_key, SSL_SESSION *sess, SSL *ssl);
bool get_session(const std::string &lookup_key, SSL_SESSION **sess, ssl_curve_id *curve);
std::shared_ptr<SSL_SESSION> get_session(const std::string &lookup_key, ssl_curve_id *curve);
void remove_session(const std::string &lookup_key);

private:
Expand Down
15 changes: 6 additions & 9 deletions iocore/net/TLSSessionResumptionSupport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,20 +173,17 @@ TLSSessionResumptionSupport::getSession(SSL *ssl, const unsigned char *id, int l
return session;
}

SSL_SESSION *
std::shared_ptr<SSL_SESSION>
TLSSessionResumptionSupport::getOriginSession(SSL *ssl, const std::string &lookup_key)
{
SSL_SESSION *session = nullptr;
ssl_curve_id curve = 0;
if (origin_sess_cache->get_session(lookup_key, &session, &curve)) {
ink_assert(session);
ssl_curve_id curve = 0;
std::shared_ptr<SSL_SESSION> shared_sess = origin_sess_cache->get_session(lookup_key, &curve);

if (shared_sess != nullptr) {
// Double check the timeout
if (is_ssl_session_timed_out(session)) {
if (is_ssl_session_timed_out(shared_sess.get())) {
SSL_INCREMENT_DYN_STAT(ssl_origin_session_cache_miss);
origin_sess_cache->remove_session(lookup_key);
SSL_SESSION_free(session);
session = nullptr;
} else {
SSL_INCREMENT_DYN_STAT(ssl_origin_session_cache_hit);
this->_setSSLSessionCacheHit(true);
Expand All @@ -195,7 +192,7 @@ TLSSessionResumptionSupport::getOriginSession(SSL *ssl, const std::string &looku
} else {
SSL_INCREMENT_DYN_STAT(ssl_origin_session_cache_miss);
}
return session;
return shared_sess;
}

void
Expand Down
2 changes: 1 addition & 1 deletion iocore/net/TLSSessionResumptionSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class TLSSessionResumptionSupport
ssl_curve_id getSSLCurveNID() const;

SSL_SESSION *getSession(SSL *ssl, const unsigned char *id, int len, int *copy);
SSL_SESSION *getOriginSession(SSL *ssl, const std::string &lookup_key);
std::shared_ptr<SSL_SESSION> getOriginSession(SSL *ssl, const std::string &lookup_key);

protected:
void clear();
Expand Down