Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@

import decorator
import urllib3
from urllib3.connection import match_hostname as urllib3_match_hostname
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket

from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
from .utils import SSL_PROTOCOL, MocketSocketCore, hexdump, hexload, wrap_ssl_socket
from .utils import SSL_PROTOCOL, MocketSocketCore, hexdump, hexload

xxh32 = None
try:
Expand Down Expand Up @@ -49,22 +50,32 @@
true_inet_pton = socket.inet_pton
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname


class SuperFakeSSLContext(object):
""" For Python 3.6 """
"""For Python 3.6"""

class FakeSetter(int):
def __set__(self, *args):
pass

options = FakeSetter()
verify_mode = FakeSetter(ssl.CERT_OPTIONAL)
verify_mode = FakeSetter(ssl.CERT_NONE)


class FakeSSLContext(SuperFakeSSLContext):
sock = None
post_handshake_auth = None
_check_hostname = False

@property
def check_hostname(self):
return self._check_hostname

@check_hostname.setter
def check_hostname(self, *args):
self._check_hostname = False

def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs):
if isinstance(sock, MocketSocket):
Expand Down Expand Up @@ -141,16 +152,6 @@ def __init__(
self._truesocket_recording_dir = None
self.kwargs = kwargs

sock = kwargs.get("sock")
if sock is not None:
self.__dict__ = dict(sock.__dict__)

self.true_socket = wrap_ssl_socket(
true_ssl_socket,
self.true_socket,
true_ssl_context(protocol=SSL_PROTOCOL),
)

def __unicode__(self): # pragma: no cover
return str(self)

Expand Down Expand Up @@ -323,16 +324,10 @@ def true_sendall(self, data, *args, **kwargs):
host = true_gethostbyname(host)

if isinstance(self.true_socket, true_socket) and self._secure_socket:
try:
self = MocketSocket(sock=self)
except TypeError:
ssl_context = self.kwargs.get("ssl_context")
server_hostname = self.kwargs.get("server_hostname")
self.true_socket = true_ssl_context.wrap_socket(
self=ssl_context,
sock=self.true_socket,
server_hostname=server_hostname,
)
self.true_socket = true_urllib3_ssl_wrap_socket(
self.true_socket,
**self.kwargs,
)

try:
self.true_socket.connect((host, port))
Expand Down Expand Up @@ -388,7 +383,7 @@ def close(self):
self._fd = None

def __getattr__(self, name):
""" Do nothing catchall function, for methods like close() and shutdown() """
"""Do nothing catchall function, for methods like close() and shutdown()"""

def do_nothing(*args, **kwargs):
pass
Expand Down Expand Up @@ -479,6 +474,9 @@ def enable(namespace=None, truesocket_recording_dir=None):
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = FakeSSLContext.wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = lambda cert, hostname: None
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()
Expand Down Expand Up @@ -506,6 +504,9 @@ def disable():
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = true_urllib3_match_hostname
Mocket.reset()
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
Expand All @@ -521,7 +522,7 @@ def get_truesocket_recording_dir(cls):

@classmethod
def assert_fail_if_entries_not_served(cls):
""" Mocket checks that all entries have been served at least once. """
"""Mocket checks that all entries have been served at least once."""
assert all(
entry._served for entry in itertools.chain(*cls._entries.values())
), "Some Mocket entries have not been served"
Expand Down
29 changes: 0 additions & 29 deletions mocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,6 @@ def write(self, content):
os.write(Mocket.w_fd, content)


def wrap_ssl_socket(
cls,
sock,
context,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_NONE,
ssl_version=SSL_PROTOCOL,
ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
ciphers=None,
):
return cls(
sock=sock,
keyfile=keyfile,
certfile=certfile,
server_side=server_side,
cert_reqs=cert_reqs,
ssl_version=ssl_version,
ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers,
_context=context,
)


def hexdump(binary_string):
r"""
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
Expand Down