77#include < openssl/err.h>
88#include < openssl/asn1.h>
99
10- #include < iostream>
1110
1211namespace {
1312
1413std::string getCertificateInfo (X509* cert)
1514{
15+ if (!cert)
16+ return " No certificate" ;
17+
1618 std::unique_ptr<BIO, decltype (&BIO_free)> mem_bio (BIO_new (BIO_s_mem ()), &BIO_free);
1719 X509_print (mem_bio.get (), cert);
20+
1821 char * data = nullptr ;
19- size_t len = BIO_get_mem_data (mem_bio.get (), &data);
22+ auto len = BIO_get_mem_data (mem_bio.get (), &data);
23+ if (len < 0 )
24+ return " Can't get certificate info due to BIO error " + std::to_string (len);
2025
2126 return std::string (data, len);
2227}
2328
24- void throwSSLError (SSL * ssl, int error, const char * location, const char * statement) {
29+ void throwSSLError (SSL * ssl, int error, const char * /* location*/ , const char * /* statement*/ ) {
2530 const auto detail_error = ERR_get_error ();
2631 auto reason = ERR_reason_error_string (detail_error);
2732 reason = reason ? reason : " Unknown SSL error" ;
2833
2934 std::string reason_str = reason;
3035 if (ssl) {
31- if (auto ssl_session = SSL_get_session (ssl))
32- if (auto server_certificate = SSL_SESSION_get0_peer (ssl_session))
33- reason_str += " \n Server certificate: " + getCertificateInfo (server_certificate);
36+ // TODO: maybe print certificate only if handshake isn't completed (SSL_get_state(ssl) != TLS_ST_OK)
37+ if (auto ssl_session = SSL_get_session (ssl)) {
38+ reason_str += " \n Server certificate: " + getCertificateInfo (SSL_SESSION_get0_peer (ssl_session));
39+ }
3440 }
3541
36- std::cerr << " !!! SSL error at " << location
37- << " \n\t caused by " << statement
38- << " \n\t : " << reason_str << " (" << error << " )"
39- << " \n\t last err: " << ERR_peek_last_error ()
40- << std::endl;
42+ // std::cerr << "!!! SSL error at " << location
43+ // << "\n\tcaused by " << statement
44+ // << "\n\t: "<< reason_str << "(" << error << ")"
45+ // << "\n\t last err: " << ERR_peek_last_error()
46+ // << std::endl;
4147
4248 throw std::runtime_error (std::string (" OpenSSL error: " ) + std::to_string (error) + " : " + reason_str);
4349}
@@ -64,8 +70,8 @@ SSL_CTX * prepareSSLContext(const clickhouse::SSLParams & context_params) {
6470 throw std::runtime_error (" Failed to initialize SSL context" );
6571
6672#define HANDLE_SSL_CTX_ERROR (statement ) do { \
67- if (const auto ret_code = statement; !ret_code) \
68- throwSSLError (nullptr , ERR_peek_error (), LOCATION, STRINGIFY ( statement) ); \
73+ if (const auto ret_code = ( statement) ; !ret_code) \
74+ throwSSLError (nullptr , ERR_peek_error (), LOCATION, # statement); \
6975} while (false );
7076
7177 if (context_params.use_default_ca_locations )
@@ -91,47 +97,48 @@ SSL_CTX * prepareSSLContext(const clickhouse::SSLParams & context_params) {
9197 SSL_CTX_set_max_proto_version (ctx.get (), context_params.max_protocol_version ));
9298
9399 return ctx.release ();
100+ #undef HANDLE_SSL_CTX_ERROR
94101}
95102
96-
97-
98103}
99104
100- #define HANDLE_SSL_ERROR (statement ) do { \
101- if (const auto ret_code = statement; ret_code <= 0 ) \
102- throwSSLError (ssl_, SSL_get_error (ssl_, ret_code), LOCATION, STRINGIFY (statement)); \
103- } while (false );
104-
105105namespace clickhouse {
106106
107107SSLContext::SSLContext (SSL_CTX & context)
108- : context_(&context)
108+ : context_(&context, &SSL_CTX_free )
109109{
110- SSL_CTX_up_ref (context_);
110+ SSL_CTX_up_ref (context_. get () );
111111}
112112
113113SSLContext::SSLContext (const SSLParams & context_params)
114- : context_(prepareSSLContext(context_params))
114+ : context_(prepareSSLContext(context_params), &SSL_CTX_free )
115115{
116116}
117117
118- SSLContext::~SSLContext () {
119- SSL_CTX_free (context_);
120- }
121-
122118SSL_CTX * SSLContext::getContext () {
123- return context_;
119+ return context_. get () ;
124120}
125121
122+ // Allows caller to use returned value of `statement` if there was no error, throws exception otherwise.
123+ #define HANDLE_SSL_ERROR (statement ) [&] { \
124+ if (const auto ret_code = (statement); ret_code <= 0 ) { \
125+ throwSSLError (ssl_, SSL_get_error (ssl_, ret_code), LOCATION, #statement); \
126+ return static_cast <decltype (ret_code)>(0 ); \
127+ } \
128+ else \
129+ return ret_code; \
130+ }()
131+
126132/* // debug macro for tracing SSL state
127133#define LOG_SSL_STATE() std::cerr << "!!!!" << LOCATION << " @" << __FUNCTION__ \
128134 << "\t" << SSL_get_version(ssl_) << " state: " << SSL_state_string_long(ssl_) \
129135 << "\n\t handshake state: " << SSL_get_state(ssl_) \
130136 << std::endl
131137*/
132138SSLSocket::SSLSocket (const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context)
133- : Socket(addr),
134- ssl_ (SSL_new(context.getContext()))
139+ : Socket(addr)
140+ , ssl_ptr_(SSL_new(context.getContext()), &SSL_free)
141+ , ssl_(ssl_ptr_.get())
135142{
136143 if (!ssl_)
137144 throw std::runtime_error (" Failed to create SSL instance" );
@@ -143,41 +150,37 @@ SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, S
143150 SSL_set_connect_state (ssl_);
144151 HANDLE_SSL_ERROR (SSL_connect (ssl_));
145152 HANDLE_SSL_ERROR (SSL_set_mode (ssl_, SSL_MODE_AUTO_RETRY));
153+ auto peer_certificate = SSL_get_peer_certificate (ssl_);
146154
147- if (const auto verify_result = SSL_get_verify_result (ssl_); verify_result != X509_V_OK) {
148- auto error_message = X509_verify_cert_error_string (verify_result);
149- auto ssl_session = SSL_get_session (ssl_);
150- auto cert = SSL_SESSION_get0_peer (ssl_session);
155+ if (!peer_certificate)
156+ throw std::runtime_error (" Failed to verify SSL connection: server provided no ceritificate." );
151157
158+ if (const auto verify_result = SSL_get_verify_result (ssl_); verify_result != X509_V_OK) {
159+ auto error_message = X509_verify_cert_error_string (verify_result);
152160 throw std::runtime_error (" Failed to verify SSL connection, X509_v error: "
153- + std::to_string (verify_result)
154- + " " + error_message + " \n " + getCertificateInfo (cert));
161+ + std::to_string (verify_result)
162+ + " " + error_message
163+ + " \n Server certificate: " + getCertificateInfo (peer_certificate));
155164 }
156165
157166 if (ssl_params.use_SNI ) {
158- auto ssl_session = SSL_get_session (ssl_);
159- auto peer_cert = SSL_SESSION_get0_peer (ssl_session);
160167 auto hostname = addr.Host ();
161168 char * out_name = nullptr ;
162169
163170 std::unique_ptr<ASN1_OCTET_STRING, decltype (&ASN1_OCTET_STRING_free)> addr (a2i_IPADDRESS (hostname.c_str ()), &ASN1_OCTET_STRING_free);
164171 if (addr) {
165172 // if hostname is actually an IP address
166173 HANDLE_SSL_ERROR (X509_check_ip (
167- peer_cert ,
174+ peer_certificate ,
168175 ASN1_STRING_get0_data (addr.get ()),
169176 ASN1_STRING_length (addr.get ()),
170177 0 ));
171178 } else {
172- HANDLE_SSL_ERROR (X509_check_host (peer_cert , hostname.c_str (), hostname.length (), 0 , &out_name));
179+ HANDLE_SSL_ERROR (X509_check_host (peer_certificate , hostname.c_str (), hostname.length (), 0 , &out_name));
173180 }
174181 }
175182}
176183
177- SSLSocket::~SSLSocket () {
178- SSL_free (ssl_);
179- }
180-
181184std::unique_ptr<InputStream> SSLSocket::makeInputStream () const {
182185 return std::make_unique<SSLSocketInput>(ssl_);
183186}
@@ -190,8 +193,6 @@ SSLSocketInput::SSLSocketInput(SSL *ssl)
190193 : ssl_(ssl)
191194{}
192195
193- SSLSocketInput::~SSLSocketInput () = default ;
194-
195196size_t SSLSocketInput::DoRead (void * buf, size_t len) {
196197 size_t actually_read;
197198 HANDLE_SSL_ERROR (SSL_read_ex (ssl_, buf, len, &actually_read));
@@ -202,8 +203,6 @@ SSLSocketOutput::SSLSocketOutput(SSL *ssl)
202203 : ssl_(ssl)
203204{}
204205
205- SSLSocketOutput::~SSLSocketOutput () = default ;
206-
207206void SSLSocketOutput::DoWrite (const void * data, size_t len) {
208207 HANDLE_SSL_ERROR (SSL_write (ssl_, data, len));
209208}
0 commit comments