Skip to content

bpo-18233: Add internal methods to access peer chain (GH-25467) #25467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 67 additions & 2 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ctypes = None

ssl = import_helper.import_module("ssl")
import _ssl

from ssl import TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType

Expand Down Expand Up @@ -297,7 +298,7 @@ def test_wrap_socket(sock, *,
return context.wrap_socket(sock, **kwargs)


def testing_context(server_cert=SIGNED_CERTFILE):
def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True):
"""Create context

client_context, server_context, hostname = testing_context()
Expand All @@ -316,7 +317,8 @@ def testing_context(server_cert=SIGNED_CERTFILE):

server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(server_cert)
server_context.load_verify_locations(SIGNING_CA)
if server_chain:
server_context.load_verify_locations(SIGNING_CA)

return client_context, server_context, hostname

Expand Down Expand Up @@ -2479,6 +2481,12 @@ def run(self):
elif stripped == b'GETCERT':
cert = self.sslconn.getpeercert()
self.write(repr(cert).encode("us-ascii") + b"\n")
elif stripped == b'VERIFIEDCHAIN':
certs = self.sslconn._sslobj.get_verified_chain()
self.write(len(certs).to_bytes(1, "big") + b"\n")
elif stripped == b'UNVERIFIEDCHAIN':
certs = self.sslconn._sslobj.get_unverified_chain()
self.write(len(certs).to_bytes(1, "big") + b"\n")
else:
if (support.verbose and
self.server.connectionchatty):
Expand Down Expand Up @@ -4565,6 +4573,63 @@ def test_bpo37428_pha_cert_none(self):
# server cert has not been validated
self.assertEqual(s.getpeercert(), {})

def test_internal_chain_client(self):
client_context, server_context, hostname = testing_context(
server_chain=False
)
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(
socket.socket(),
server_hostname=hostname
) as s:
s.connect((HOST, server.port))
vc = s._sslobj.get_verified_chain()
self.assertEqual(len(vc), 2)
ee, ca = vc
uvc = s._sslobj.get_unverified_chain()
self.assertEqual(len(uvc), 1)

self.assertEqual(ee, uvc[0])
self.assertEqual(hash(ee), hash(uvc[0]))
self.assertEqual(repr(ee), repr(uvc[0]))

self.assertNotEqual(ee, ca)
self.assertNotEqual(hash(ee), hash(ca))
self.assertNotEqual(repr(ee), repr(ca))
self.assertNotEqual(ee.get_info(), ca.get_info())
self.assertIn("CN=localhost", repr(ee))
self.assertIn("CN=our-ca-server", repr(ca))

pem = ee.public_bytes(_ssl.ENCODING_PEM)
der = ee.public_bytes(_ssl.ENCODING_DER)
self.assertIsInstance(pem, str)
self.assertIn("-----BEGIN CERTIFICATE-----", pem)
self.assertIsInstance(der, bytes)
self.assertEqual(
ssl.PEM_cert_to_DER_cert(pem), der
)

def test_internal_chain_server(self):
client_context, server_context, hostname = testing_context()
client_context.load_cert_chain(SIGNED_CERTFILE)
server_context.verify_mode = ssl.CERT_REQUIRED
server_context.maximum_version = ssl.TLSVersion.TLSv1_2

server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(
socket.socket(),
server_hostname=hostname
) as s:
s.connect((HOST, server.port))
s.write(b'VERIFIEDCHAIN\n')
res = s.recv(1024)
self.assertEqual(res, b'\x02\n')
s.write(b'UNVERIFIEDCHAIN\n')
res = s.recv(1024)
self.assertEqual(res, b'\x02\n')


HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
requires_keylog = unittest.skipUnless(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Certificate and PrivateKey classes were added to the ssl module.
Certificates and keys can now be loaded from memory buffer, too.
84 changes: 83 additions & 1 deletion Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,9 @@ _certificate_to_der(_sslmodulestate *state, X509 *certificate)
return retval;
}

#include "_ssl/misc.c"
#include "_ssl/cert.c"

/*[clinic input]
_ssl._test_decode_cert
path: object(converter="PyUnicode_FSConverter")
Expand Down Expand Up @@ -1798,6 +1801,70 @@ _ssl__SSLSocket_getpeercert_impl(PySSLSocket *self, int binary_mode)
return result;
}

/*[clinic input]
_ssl._SSLSocket.get_verified_chain

[clinic start generated code]*/

static PyObject *
_ssl__SSLSocket_get_verified_chain_impl(PySSLSocket *self)
/*[clinic end generated code: output=802421163cdc3110 input=5fb0714f77e2bd51]*/
{
/* borrowed reference */
STACK_OF(X509) *chain = SSL_get0_verified_chain(self->ssl);
if (chain == NULL) {
Py_RETURN_NONE;
}
return _PySSL_CertificateFromX509Stack(self->ctx->state, chain, 1);
}

/*[clinic input]
_ssl._SSLSocket.get_unverified_chain

[clinic start generated code]*/

static PyObject *
_ssl__SSLSocket_get_unverified_chain_impl(PySSLSocket *self)
/*[clinic end generated code: output=5acdae414e13f913 input=78c33c360c635cb5]*/
{
PyObject *retval;
/* borrowed reference */
/* TODO: include SSL_get_peer_certificate() for server-side sockets */
STACK_OF(X509) *chain = SSL_get_peer_cert_chain(self->ssl);
if (chain == NULL) {
Py_RETURN_NONE;
}
retval = _PySSL_CertificateFromX509Stack(self->ctx->state, chain, 1);
if (retval == NULL) {
return NULL;
}
/* OpenSSL does not include peer cert for server side connections */
if (self->socket_type == PY_SSL_SERVER) {
PyObject *peerobj = NULL;
X509 *peer = SSL_get_peer_certificate(self->ssl);

if (peer == NULL) {
peerobj = Py_None;
Py_INCREF(peerobj);
} else {
/* consume X509 reference on success */
peerobj = _PySSL_CertificateFromX509(self->ctx->state, peer, 0);
if (peerobj == NULL) {
X509_free(peer);
Py_DECREF(retval);
return NULL;
}
}
int res = PyList_Insert(retval, 0, peerobj);
Py_DECREF(peerobj);
if (res < 0) {
Py_DECREF(retval);
return NULL;
}
}
return retval;
}

static PyObject *
cipher_to_tuple(const SSL_CIPHER *cipher)
{
Expand Down Expand Up @@ -2809,6 +2876,8 @@ static PyMethodDef PySSLMethods[] = {
_SSL__SSLSOCKET_COMPRESSION_METHODDEF
_SSL__SSLSOCKET_SHUTDOWN_METHODDEF
_SSL__SSLSOCKET_VERIFY_CLIENT_POST_HANDSHAKE_METHODDEF
_SSL__SSLSOCKET_GET_UNVERIFIED_CHAIN_METHODDEF
_SSL__SSLSOCKET_GET_VERIFIED_CHAIN_METHODDEF
{NULL, NULL}
};

Expand Down Expand Up @@ -5784,6 +5853,10 @@ sslmodule_init_constants(PyObject *m)
X509_CHECK_FLAG_SINGLE_LABEL_SUBDOMAINS);
#endif

/* file types */
PyModule_AddIntConstant(m, "ENCODING_PEM", PY_SSL_ENCODING_PEM);
PyModule_AddIntConstant(m, "ENCODING_DER", PY_SSL_ENCODING_DER);

/* protocol versions */
PyModule_AddIntConstant(m, "PROTO_MINIMUM_SUPPORTED",
PY_PROTO_MINIMUM_SUPPORTED);
Expand Down Expand Up @@ -5986,6 +6059,12 @@ sslmodule_init_types(PyObject *module)
if (state->PySSLSession_Type == NULL)
return -1;

state->PySSLCertificate_Type = (PyTypeObject *)PyType_FromModuleAndSpec(
module, &PySSLCertificate_spec, NULL
);
if (state->PySSLCertificate_Type == NULL)
return -1;

if (PyModule_AddType(module, state->PySSLContext_Type))
return -1;
if (PyModule_AddType(module, state->PySSLSocket_Type))
Expand All @@ -5994,7 +6073,8 @@ sslmodule_init_types(PyObject *module)
return -1;
if (PyModule_AddType(module, state->PySSLSession_Type))
return -1;

if (PyModule_AddType(module, state->PySSLCertificate_Type))
return -1;
return 0;
}

Expand All @@ -6017,6 +6097,7 @@ sslmodule_traverse(PyObject *m, visitproc visit, void *arg)
Py_VISIT(state->PySSLSocket_Type);
Py_VISIT(state->PySSLMemoryBIO_Type);
Py_VISIT(state->PySSLSession_Type);
Py_VISIT(state->PySSLCertificate_Type);
Py_VISIT(state->PySSLErrorObject);
Py_VISIT(state->PySSLCertVerificationErrorObject);
Py_VISIT(state->PySSLZeroReturnErrorObject);
Expand All @@ -6041,6 +6122,7 @@ sslmodule_clear(PyObject *m)
Py_CLEAR(state->PySSLSocket_Type);
Py_CLEAR(state->PySSLMemoryBIO_Type);
Py_CLEAR(state->PySSLSession_Type);
Py_CLEAR(state->PySSLCertificate_Type);
Py_CLEAR(state->PySSLErrorObject);
Py_CLEAR(state->PySSLCertVerificationErrorObject);
Py_CLEAR(state->PySSLZeroReturnErrorObject);
Expand Down
31 changes: 30 additions & 1 deletion Modules/_ssl.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#ifndef Py_SSL_H
#define Py_SSL_H

/* OpenSSL header files */
#include "openssl/evp.h"
#include "openssl/x509.h"

/*
* ssl module state
*/
Expand All @@ -10,6 +14,7 @@ typedef struct {
PyTypeObject *PySSLSocket_Type;
PyTypeObject *PySSLMemoryBIO_Type;
PyTypeObject *PySSLSession_Type;
PyTypeObject *PySSLCertificate_Type;
/* SSL error object */
PyObject *PySSLErrorObject;
PyObject *PySSLCertVerificationErrorObject;
Expand Down Expand Up @@ -40,6 +45,30 @@ get_ssl_state(PyObject *module)
(get_ssl_state(_PyType_GetModuleByDef(type, &_sslmodule_def)))
#define get_state_ctx(c) (((PySSLContext *)(c))->state)
#define get_state_sock(s) (((PySSLSocket *)(s))->ctx->state)
#define get_state_mbio(b) ((_sslmodulestate *)PyType_GetModuleState(Py_TYPE(b)))
#define get_state_obj(o) ((_sslmodulestate *)PyType_GetModuleState(Py_TYPE(o)))
#define get_state_mbio(b) get_state_obj(b)
#define get_state_cert(c) get_state_obj(c)

/* ************************************************************************
* certificate
*/

enum py_ssl_encoding {
PY_SSL_ENCODING_PEM=X509_FILETYPE_PEM,
PY_SSL_ENCODING_DER=X509_FILETYPE_ASN1,
PY_SSL_ENCODING_PEM_AUX=X509_FILETYPE_PEM + 0x100,
};

typedef struct {
PyObject_HEAD
X509 *cert;
Py_hash_t hash;
} PySSLCertificate;

/* ************************************************************************
* helpers and utils
*/
static PyObject *_PySSL_BytesFromBIO(_sslmodulestate *state, BIO *bio);
static PyObject *_PySSL_UnicodeFromBIO(_sslmodulestate *state, BIO *bio, const char *error);

#endif /* Py_SSL_H */
Loading