Skip to content

Commit 23942dd

Browse files
committed
feat: Provide HTTP headers via Start event
1 parent 7f0fc62 commit 23942dd

File tree

7 files changed

+258
-17
lines changed

7 files changed

+258
-17
lines changed

ld_eventsource/actions.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Optional
2+
from typing import Any, Dict, Optional
33

44

55
class Action:
@@ -110,9 +110,28 @@ class Start(Action):
110110
Instances of this class are only available from :attr:`.SSEClient.all`.
111111
A ``Start`` is returned for the first successful connection. If the client reconnects
112112
after a failure, there will be a :class:`.Fault` followed by a ``Start``.
113+
114+
Each ``Start`` action may include HTTP response headers from the connection. These headers
115+
are available via the :attr:`headers` property. On reconnection, a new ``Start`` will be
116+
emitted with the headers from the new connection, which may differ from the previous one.
113117
"""
114118

115-
pass
119+
def __init__(self, headers: Optional[Dict[str, Any]] = None):
120+
self._headers = headers
121+
122+
@property
123+
def headers(self) -> Optional[Dict[str, Any]]:
124+
"""
125+
The HTTP response headers from the stream connection, if available.
126+
127+
For HTTP-based connections, this contains the headers from the SSE response.
128+
For non-HTTP connections, this will be ``None``.
129+
130+
The headers dict uses case-insensitive keys (via urllib3's HTTPHeaderDict).
131+
132+
:return: the response headers, or ``None`` if not available
133+
"""
134+
return self._headers
116135

117136

118137
class Fault(Action):

ld_eventsource/config/connect_strategy.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass
44
from logging import Logger
5-
from typing import Callable, Iterator, Optional, Union
5+
from typing import Any, Callable, Dict, Iterator, Optional, Union
66

77
from urllib3 import PoolManager
88

@@ -96,9 +96,10 @@ class ConnectionResult:
9696
The return type of :meth:`ConnectionClient.connect()`.
9797
"""
9898

99-
def __init__(self, stream: Iterator[bytes], closer: Optional[Callable]):
99+
def __init__(self, stream: Iterator[bytes], closer: Optional[Callable], headers: Optional[Dict[str, Any]] = None):
100100
self.__stream = stream
101101
self.__closer = closer
102+
self.__headers = headers
102103

103104
@property
104105
def stream(self) -> Iterator[bytes]:
@@ -107,6 +108,18 @@ def stream(self) -> Iterator[bytes]:
107108
"""
108109
return self.__stream
109110

111+
@property
112+
def headers(self) -> Optional[Dict[str, Any]]:
113+
"""
114+
The HTTP response headers, if available.
115+
116+
For HTTP connections, this contains the headers from the SSE stream response.
117+
For non-HTTP connections, this will be ``None``.
118+
119+
The headers dict uses case-insensitive keys (via urllib3's HTTPHeaderDict).
120+
"""
121+
return self.__headers
122+
110123
def close(self):
111124
"""
112125
Does whatever is necessary to release the connection.
@@ -139,8 +152,8 @@ def __init__(self, params: _HttpConnectParams, logger: Logger):
139152
self.__impl = _HttpClientImpl(params, logger)
140153

141154
def connect(self, last_event_id: Optional[str]) -> ConnectionResult:
142-
stream, closer = self.__impl.connect(last_event_id)
143-
return ConnectionResult(stream, closer)
155+
stream, closer, headers = self.__impl.connect(last_event_id)
156+
return ConnectionResult(stream, closer, headers)
144157

145158
def close(self):
146159
self.__impl.close()

ld_eventsource/http.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from logging import Logger
2-
from typing import Callable, Iterator, Optional, Tuple
2+
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, cast
33
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
44

55
from urllib3 import PoolManager
@@ -60,7 +60,7 @@ def __init__(self, params: _HttpConnectParams, logger: Logger):
6060
self.__should_close_pool = params.pool is not None
6161
self.__logger = logger
6262

63-
def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable]:
63+
def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable, Dict[str, Any]]:
6464
url = self.__params.url
6565
if self.__params.query_params is not None:
6666
qp = self.__params.query_params()
@@ -109,6 +109,7 @@ def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callab
109109
raise HTTPContentTypeError(content_type or '')
110110

111111
stream = resp.stream(_CHUNK_SIZE)
112+
response_headers = cast(Dict[str, Any], resp.headers)
112113

113114
def close():
114115
try:
@@ -117,7 +118,7 @@ def close():
117118
pass
118119
resp.release_conn()
119120

120-
return stream, close
121+
return stream, close, response_headers
121122

122123
def close(self):
123124
if self.__should_close_pool:

ld_eventsource/sse_client.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ class SSEClient:
3939
:meth:`.RetryDelayStrategy.default()`, this delay will double with each subsequent retry,
4040
and will also have a pseudo-random jitter subtracted. You can customize this behavior with
4141
``retry_delay_strategy``.
42+
43+
**HTTP Response Headers:**
44+
When using HTTP-based connections, the response headers from each connection are available
45+
via the :attr:`.Start.headers` property when reading from :attr:`all`. Each time the client
46+
connects or reconnects, a :class:`.Start` action is emitted containing the headers from that
47+
specific connection. This allows you to access server metadata such as rate limits, session
48+
identifiers, or custom headers.
4249
"""
4350

4451
def __init__(
@@ -178,9 +185,10 @@ def all(self) -> Iterable[Action]:
178185
# Reading implies starting the stream if it isn't already started. We might also
179186
# be restarting since we could have been interrupted at any time.
180187
while self.__connection_result is None:
181-
fault = self._try_start(True)
188+
result = self._try_start(True)
182189
# return either a Start action or a Fault action
183-
yield Start() if fault is None else fault
190+
if result is not None:
191+
yield result
184192

185193
lines = _BufferedLineReader.lines_from(self.__connection_result.stream)
186194
reader = _SSEReader(lines, self.__last_event_id, None)
@@ -263,7 +271,7 @@ def _compute_next_retry_delay(self):
263271
self.__current_retry_delay_strategy.apply(self.__base_retry_delay)
264272
)
265273

266-
def _try_start(self, can_return_fault: bool) -> Optional[Fault]:
274+
def _try_start(self, can_return_fault: bool) -> Union[None, Start, Fault]:
267275
if self.__connection_result is not None:
268276
return None
269277
while True:
@@ -297,7 +305,7 @@ def _try_start(self, can_return_fault: bool) -> Optional[Fault]:
297305
self._retry_reset_baseline = time.time()
298306
self.__current_error_strategy = self.__base_error_strategy
299307
self.__interrupted = False
300-
return None
308+
return Start(self.__connection_result.headers)
301309

302310
@property
303311
def last_event_id(self) -> Optional[str]:

ld_eventsource/testing/helpers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,17 @@ def apply(self) -> ConnectionResult:
6666

6767

6868
class RespondWithStream(MockConnectionHandler):
69-
def __init__(self, stream: Iterable[bytes]):
69+
def __init__(self, stream: Iterable[bytes], headers: Optional[dict] = None):
7070
self.__stream = stream
71+
self.__headers = headers
7172

7273
def apply(self) -> ConnectionResult:
73-
return ConnectionResult(stream=self.__stream.__iter__(), closer=None)
74+
return ConnectionResult(stream=self.__stream.__iter__(), closer=None, headers=self.__headers)
7475

7576

7677
class RespondWithData(RespondWithStream):
77-
def __init__(self, data: str):
78-
super().__init__([bytes(data, 'utf-8')])
78+
def __init__(self, data: str, headers: Optional[dict] = None):
79+
super().__init__([bytes(data, 'utf-8')], headers)
7980

8081

8182
class ExpectNoMoreRequests(MockConnectionHandler):
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import pytest
2+
3+
from ld_eventsource import *
4+
from ld_eventsource.actions import *
5+
from ld_eventsource.config import *
6+
from ld_eventsource.errors import *
7+
from ld_eventsource.testing.helpers import *
8+
9+
10+
def test_start_action_with_no_headers():
11+
"""Test that Start action can be created without headers"""
12+
start = Start()
13+
assert start.headers is None
14+
15+
16+
def test_start_action_with_headers():
17+
"""Test that Start action can be created with headers"""
18+
headers = {'Content-Type': 'text/event-stream', 'X-Custom': 'value'}
19+
start = Start(headers)
20+
assert start.headers == headers
21+
22+
23+
def test_headers_exposed_in_start_action():
24+
"""Test that headers from connection are exposed in Start action"""
25+
headers = {'Content-Type': 'text/event-stream', 'X-Test-Header': 'test-value'}
26+
mock = MockConnectStrategy(
27+
RespondWithData("event: test\ndata: data1\n\n", headers=headers)
28+
)
29+
30+
with SSEClient(connect=mock) as client:
31+
all_items = list(client.all)
32+
33+
# First item should be Start with headers
34+
assert isinstance(all_items[0], Start)
35+
assert all_items[0].headers == headers
36+
37+
# Second item should be the event
38+
assert isinstance(all_items[1], Event)
39+
assert all_items[1].event == 'test'
40+
41+
# Third item should be Fault (end of stream)
42+
assert isinstance(all_items[2], Fault)
43+
assert all_items[2].error is None
44+
45+
46+
def test_headers_not_visible_in_events_iterator():
47+
"""Test that headers are only visible when using .all, not .events"""
48+
headers = {'X-Custom': 'value'}
49+
mock = MockConnectStrategy(
50+
RespondWithData("event: test\ndata: data1\n\n", headers=headers)
51+
)
52+
53+
with SSEClient(connect=mock) as client:
54+
events = list(client.events)
55+
56+
# Should only get the event, no Start action
57+
assert len(events) == 1
58+
assert isinstance(events[0], Event)
59+
assert events[0].event == 'test'
60+
61+
62+
def test_no_headers_when_not_provided():
63+
"""Test that Start action has None headers when connection doesn't provide them"""
64+
mock = MockConnectStrategy(
65+
RespondWithData("event: test\ndata: data1\n\n")
66+
)
67+
68+
with SSEClient(connect=mock) as client:
69+
all_items = list(client.all)
70+
71+
# First item should be Start with no headers
72+
assert isinstance(all_items[0], Start)
73+
assert all_items[0].headers is None
74+
75+
76+
def test_different_headers_on_reconnection():
77+
"""Test that reconnection yields new Start with potentially different headers"""
78+
headers1 = {'X-Connection': 'first'}
79+
headers2 = {'X-Connection': 'second'}
80+
81+
mock = MockConnectStrategy(
82+
RespondWithData("event: test1\ndata: data1\n\n", headers=headers1),
83+
RespondWithData("event: test2\ndata: data2\n\n", headers=headers2)
84+
)
85+
86+
with SSEClient(
87+
connect=mock,
88+
error_strategy=ErrorStrategy.from_lambda(lambda _: (ErrorStrategy.CONTINUE, None)),
89+
retry_delay_strategy=no_delay()
90+
) as client:
91+
items = []
92+
for item in client.all:
93+
items.append(item)
94+
# Stop after we get the second Start (from reconnection)
95+
if isinstance(item, Start) and len([i for i in items if isinstance(i, Start)]) == 2:
96+
break
97+
98+
# Find all Start actions
99+
starts = [item for item in items if isinstance(item, Start)]
100+
assert len(starts) >= 2
101+
102+
# First connection should have first headers
103+
assert starts[0].headers == headers1
104+
105+
# Second connection should have second headers
106+
assert starts[1].headers == headers2
107+
108+
109+
def test_headers_on_retry_after_error():
110+
"""Test that headers are provided on successful retry after an error"""
111+
error = HTTPStatusError(503)
112+
headers = {'X-Retry': 'success'}
113+
114+
mock = MockConnectStrategy(
115+
RejectConnection(error),
116+
RespondWithData("event: test\ndata: data1\n\n", headers=headers)
117+
)
118+
119+
with SSEClient(
120+
connect=mock,
121+
error_strategy=ErrorStrategy.from_lambda(lambda _: (ErrorStrategy.CONTINUE, None)),
122+
retry_delay_strategy=no_delay()
123+
) as client:
124+
items = []
125+
for item in client.all:
126+
items.append(item)
127+
if isinstance(item, Event):
128+
break
129+
130+
# Should have: Fault (from error), Start (from retry), Event
131+
assert isinstance(items[0], Fault)
132+
assert isinstance(items[0].error, HTTPStatusError)
133+
134+
assert isinstance(items[1], Start)
135+
assert items[1].headers == headers
136+
137+
assert isinstance(items[2], Event)
138+
139+
140+
def test_connection_result_headers_property():
141+
"""Test that ConnectionResult properly stores and returns headers"""
142+
headers = {'X-Test': 'value'}
143+
result = ConnectionResult(stream=iter([b'data']), closer=None, headers=headers)
144+
assert result.headers == headers
145+
146+
147+
def test_connection_result_no_headers():
148+
"""Test that ConnectionResult returns None when no headers provided"""
149+
result = ConnectionResult(stream=iter([b'data']), closer=None)
150+
assert result.headers is None

ld_eventsource/testing/test_http_connect_strategy.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from urllib3.exceptions import ProtocolError
44

5+
from ld_eventsource import *
6+
from ld_eventsource.actions import *
57
from ld_eventsource.config.connect_strategy import *
68
from ld_eventsource.testing.helpers import *
79
from ld_eventsource.testing.http_util import *
@@ -133,3 +135,50 @@ def test_sse_client_with_http_connect_strategy():
133135
stream.push("data: data1\n\n")
134136
event = next(client.events)
135137
assert event.data == 'data1'
138+
139+
140+
def test_http_response_headers_captured():
141+
"""Test that HTTP response headers are captured from the connection"""
142+
with start_server() as server:
143+
custom_headers = {
144+
'Content-Type': 'text/event-stream',
145+
'X-Custom-Header': 'custom-value',
146+
'X-Rate-Limit': '100'
147+
}
148+
with ChunkedResponse(custom_headers) as stream:
149+
server.for_path('/', stream)
150+
with ConnectStrategy.http(server.uri).create_client(logger()) as client:
151+
result = client.connect(None)
152+
assert result.headers is not None
153+
assert result.headers.get('X-Custom-Header') == 'custom-value'
154+
assert result.headers.get('X-Rate-Limit') == '100'
155+
# urllib3 should also include Content-Type
156+
assert 'Content-Type' in result.headers
157+
158+
159+
def test_http_response_headers_in_sse_client():
160+
"""Test that headers are exposed via Start action in SSEClient"""
161+
with start_server() as server:
162+
custom_headers = {
163+
'Content-Type': 'text/event-stream',
164+
'X-Session-Id': 'abc123'
165+
}
166+
with ChunkedResponse(custom_headers) as stream:
167+
server.for_path('/', stream)
168+
with SSEClient(connect=ConnectStrategy.http(server.uri)) as client:
169+
stream.push("event: test\ndata: data1\n\n")
170+
171+
# Read from .all to get Start action
172+
all_items = []
173+
for item in client.all:
174+
all_items.append(item)
175+
if isinstance(item, Event):
176+
break
177+
178+
# First item should be Start with headers
179+
assert isinstance(all_items[0], Start)
180+
assert all_items[0].headers is not None
181+
assert all_items[0].headers.get('X-Session-Id') == 'abc123'
182+
183+
# Second item should be the event
184+
assert isinstance(all_items[1], Event)

0 commit comments

Comments
 (0)