Skip to content

Commit ab49e7a

Browse files
committed
fix udp
1 parent 80b0ece commit ab49e7a

File tree

4 files changed

+58
-7
lines changed

4 files changed

+58
-7
lines changed

tests/test_context.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, cvar, *, loop=None):
2626
self.pipe_ctx = {0, 1, 2}
2727
self.pipe_connection_lost_fut = asyncio.Future(loop=loop)
2828
self.process_exited_fut = asyncio.Future(loop=loop)
29+
self.error_received_fut = asyncio.Future(loop=loop)
2930
self.connection_lost_ctx = None
3031
self.done = asyncio.Future(loop=loop)
3132

@@ -77,6 +78,14 @@ def buffer_updated(self, nbytes):
7778
)
7879

7980

81+
class _DatagramProtocol(_BaseProtocol, asyncio.DatagramProtocol):
82+
def datagram_received(self, data, addr):
83+
self.data_received_fut.set_result(self.cvar.get())
84+
85+
def error_received(self, exc):
86+
self.error_received_fut.set_result(self.cvar.get())
87+
88+
8089
class _SubprocessProtocol(_BaseProtocol, asyncio.SubprocessProtocol):
8190
def pipe_data_received(self, fd, data):
8291
self.data_received_fut.set_result(self.cvar.get())
@@ -703,6 +712,45 @@ async def test():
703712

704713
self.loop.run_until_complete(test())
705714

715+
def test_datagram_protocol(self):
716+
cvar = contextvars.ContextVar('cvar', default='outer')
717+
proto = _DatagramProtocol(cvar, loop=self.loop)
718+
server_addr = ('127.0.0.1', 8888)
719+
client_addr = ('127.0.0.1', 0)
720+
721+
async def run():
722+
self.assertEqual(cvar.get(), 'outer')
723+
cvar.set('inner')
724+
725+
def close():
726+
cvar.set('closing')
727+
proto.transport.close()
728+
729+
try:
730+
await self.loop.create_datagram_endpoint(
731+
lambda: proto, local_addr=server_addr)
732+
inner = await proto.connection_made_fut
733+
self.assertEqual(inner, "inner")
734+
735+
s = socket.socket(socket.AF_INET, type=socket.SOCK_DGRAM)
736+
s.bind(client_addr)
737+
s.sendto(b'data', server_addr)
738+
inner = await proto.data_received_fut
739+
self.assertEqual(inner, "inner")
740+
741+
self.loop.call_soon(close)
742+
await proto.done
743+
if self.implementation != 'asyncio':
744+
# bug in asyncio
745+
self.assertEqual(proto.connection_lost_ctx, "inner")
746+
finally:
747+
proto.transport.close()
748+
s.close()
749+
# let transports close
750+
await asyncio.sleep(0.1)
751+
752+
self.loop.run_until_complete(run())
753+
706754

707755
class Test_UV_Context(_ContextBaseTests, tb.UVTestCase):
708756
pass

uvloop/handles/udp.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ cdef class UDPTransport(UVBaseTransport):
1919
cdef _send(self, object data, object addr)
2020

2121
cdef _on_receive(self, bytes data, object exc, object addr)
22-
cdef _on_sent(self, object exc)
22+
cdef _on_sent(self, object exc, object context=*)

uvloop/handles/udp.pyx

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ cdef class UDPTransport(UVBaseTransport):
5656
self._family = uv.AF_UNSPEC
5757
self.__receiving = 0
5858
self._address = None
59+
self.context = Context_CopyCurrent()
5960

6061
cdef _init(self, Loop loop, unsigned int family):
6162
cdef int err
@@ -252,18 +253,20 @@ cdef class UDPTransport(UVBaseTransport):
252253
exc = convert_error(err)
253254
self._fatal_error(exc, True)
254255
else:
255-
self._on_sent(None)
256+
self._on_sent(None, self.context.copy())
256257

257258
cdef _on_receive(self, bytes data, object exc, object addr):
258259
if exc is None:
259-
self._protocol.datagram_received(data, addr)
260+
self.context.run(self._protocol.datagram_received, data, addr)
260261
else:
261-
self._protocol.error_received(exc)
262+
self.context.run(self._protocol.error_received, exc)
262263

263-
cdef _on_sent(self, object exc):
264+
cdef _on_sent(self, object exc, object context=None):
264265
if exc is not None:
265266
if isinstance(exc, OSError):
266-
self._protocol.error_received(exc)
267+
if context is None:
268+
context = self.context
269+
context.run(self._protocol.error_received, exc)
267270
else:
268271
self._fatal_error(
269272
exc, False, 'Fatal write error on datagram transport')

uvloop/sslproto.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ cdef class SSLProtocol:
739739
new_MethodHandle(self._loop,
740740
"SSLProtocol._do_read",
741741
<method_t>self._do_read,
742-
None,
742+
None, # current context is good
743743
self))
744744
except ssl_SSLAgainErrors as exc:
745745
pass

0 commit comments

Comments
 (0)