Skip to content

Commit a7c6051

Browse files
authored
Add support for 'requests.Session.auth' via 'NodeConfig._extras'
1 parent dae01ce commit a7c6051

File tree

4 files changed

+47
-14
lines changed

4 files changed

+47
-14
lines changed

docs/sphinx/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,5 @@
4545

4646
intersphinx_mapping = {
4747
"python": ("https://docs.python.org/3", None),
48+
"requests": ("https://docs.python-requests.org/en/master", None),
4849
}

elastic_transport/_models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class NodeConfig:
232232
connections_per_node: int = 10
233233

234234
#: Number of seconds to wait before a request should timeout.
235-
request_timeout: Optional[int] = 10
235+
request_timeout: Optional[float] = 10.0
236236

237237
#: Set to ``True`` to enable HTTP compression
238238
#: of request and response bodies via gzip.
@@ -278,10 +278,10 @@ class NodeConfig:
278278
#: issued when using ``verify_certs=False``.
279279
ssl_show_warn: bool = True
280280

281-
# Extras that can be set to anything, typically used
282-
# for annotating this node with additional information for
283-
# future decisions like sniffing, instance roles, etc.
284-
# Third-party keys should start with an underscore and prefix.
281+
#: Extras that can be set to anything, typically used
282+
#: for annotating this node with additional information for
283+
#: future decisions like sniffing, instance roles, etc.
284+
#: Third-party keys should start with an underscore and prefix.
285285
_extras: Dict[str, Any] = field(default_factory=dict, hash=False)
286286

287287
def replace(self, **kwargs: Any) -> "NodeConfig":

elastic_transport/_node/_http_requests.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
try:
3838
import requests
3939
from requests.adapters import HTTPAdapter
40+
from requests.auth import AuthBase
4041

4142
_REQUESTS_AVAILABLE = True
4243
_REQUESTS_META_VERSION = client_meta_version(requests.__version__)
@@ -79,7 +80,12 @@ def init_poolmanager(
7980

8081

8182
class RequestsHttpNode(BaseNode):
82-
"""Synchronous node using the ``requests`` library communicating via HTTP"""
83+
"""Synchronous node using the ``requests`` library communicating via HTTP.
84+
85+
Supports setting :attr:`requests.Session.auth` via the
86+
:attr:`elastic_transport.NodeConfig._extras`
87+
using the ``requests.session.auth`` key.
88+
"""
8389

8490
_CLIENT_META_HTTP_CLIENT = ("rq", _REQUESTS_META_VERSION)
8591

@@ -96,6 +102,16 @@ def __init__(self, config: NodeConfig):
96102
self.session.headers.clear() # Empty out all the default session headers
97103
self.session.verify = config.verify_certs
98104

105+
# Requests supports setting 'session.auth' via _extras['requests.session.auth'] = ...
106+
try:
107+
requests_session_auth: Optional[AuthBase] = config._extras.pop(
108+
"requests.session.auth", None
109+
)
110+
except AttributeError:
111+
requests_session_auth = None
112+
if requests_session_auth is not None:
113+
self.session.auth = requests_session_auth
114+
99115
# Client certificates
100116
if config.client_cert:
101117
if config.client_key:

tests/node/test_http_requests.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
import pytest
2323
import requests
2424
from mock import Mock, patch
25+
from requests.auth import HTTPBasicAuth
2526

2627
from elastic_transport import NodeConfig, RequestsHttpNode
2728
from elastic_transport._node._base import DEFAULT_USER_AGENT
2829

2930

3031
class TestRequestsHttpNode:
31-
def _get_mode_node(self, node_config, response_body=b"{}"):
32+
def _get_mock_node(self, node_config, response_body=b"{}"):
3233
node = RequestsHttpNode(node_config)
3334

3435
def _dummy_send(*args, **kwargs):
@@ -69,7 +70,7 @@ def test_ssl_context(self):
6970
assert adapter.poolmanager.connection_pool_kw["ssl_context"] is ctx
7071

7172
def test_merge_headers(self):
72-
node = self._get_mode_node(
73+
node = self._get_mock_node(
7374
NodeConfig("http", "localhost", 80, headers={"h1": "v1", "h2": "v2"})
7475
)
7576
req = self._get_request(node, "GET", "/", headers={"h2": "v2p", "h3": "v3"})
@@ -78,15 +79,15 @@ def test_merge_headers(self):
7879
assert req.headers["h3"] == "v3"
7980

8081
def test_default_headers(self):
81-
node = self._get_mode_node(NodeConfig("http", "localhost", 80))
82+
node = self._get_mock_node(NodeConfig("http", "localhost", 80))
8283
req = self._get_request(node, "GET", "/")
8384
assert req.headers == {
8485
"connection": "keep-alive",
8586
"user-agent": DEFAULT_USER_AGENT,
8687
}
8788

8889
def test_no_http_compression(self):
89-
node = self._get_mode_node(
90+
node = self._get_mock_node(
9091
NodeConfig("http", "localhost", 80, http_compress=False)
9192
)
9293
assert not node.config.http_compress
@@ -108,7 +109,7 @@ def test_no_http_compression(self):
108109

109110
@pytest.mark.parametrize("empty_body", [None, b""])
110111
def test_http_compression(self, empty_body):
111-
node = self._get_mode_node(
112+
node = self._get_mock_node(
112113
NodeConfig("http", "localhost", 80, http_compress=True)
113114
)
114115
assert node.config.http_compress is True
@@ -135,7 +136,7 @@ def test_http_compression(self, empty_body):
135136

136137
@pytest.mark.parametrize("request_timeout", [None, 15])
137138
def test_timeout_override_default(self, request_timeout):
138-
node = self._get_mode_node(
139+
node = self._get_mock_node(
139140
NodeConfig("http", "localhost", 80, request_timeout=request_timeout)
140141
)
141142
assert node.config.request_timeout == request_timeout
@@ -214,8 +215,23 @@ def test_ca_certs_is_used_as_session_verify(self):
214215

215216
def test_surrogatepass_into_bytes(self):
216217
data = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
217-
con = self._get_mode_node(
218+
node = self._get_mock_node(
218219
NodeConfig("http", "localhost", 80), response_body=data
219220
)
220-
_, data = con.perform_request("GET", "/")
221+
_, data = node.perform_request("GET", "/")
221222
assert b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" == data
223+
224+
@pytest.mark.parametrize("_extras", [None, {}, {"requests.session.auth": None}])
225+
def test_requests_no_session_auth(self, _extras):
226+
node = self._get_mock_node(NodeConfig("http", "localhost", 80, _extras=_extras))
227+
assert node.session.auth is None
228+
229+
def test_requests_custom_auth(self):
230+
auth = HTTPBasicAuth("username", "password")
231+
node = self._get_mock_node(
232+
NodeConfig("http", "localhost", 80, _extras={"requests.session.auth": auth})
233+
)
234+
assert node.session.auth is auth
235+
node.perform_request("GET", "/")
236+
(request,), _ = node.session.send.call_args
237+
assert request.headers["authorization"] == "Basic dXNlcm5hbWU6cGFzc3dvcmQ="

0 commit comments

Comments
 (0)