11import asyncio
2+ import logging
23from datetime import timedelta
34from typing import AsyncGenerator
45
56import pytest
7+ from grpc .aio import grpc
68
79from replit_river .client import Client
810from replit_river .error_schema import RiverError
11+ from replit_river .rpc import subscription_method_handler
912from replit_river .transport_options import MAX_MESSAGE_BUFFER_SIZE
1013from tests .common_handlers import (
1114 basic_rpc_method ,
1417 basic_upload ,
1518)
1619from tests .conftest import (
20+ HandlerMapping ,
1721 deserialize_error ,
22+ deserialize_request ,
1823 deserialize_response ,
1924 serialize_request ,
25+ serialize_response ,
2026)
2127
2228
@@ -101,6 +107,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
101107@pytest .mark .asyncio
102108@pytest .mark .parametrize ("handlers" , [{** basic_subscription }])
103109async def test_subscription_method (client : Client ) -> None :
110+ messages = []
104111 async for response in client .send_subscription (
105112 "test_service" ,
106113 "subscription_method" ,
@@ -110,7 +117,8 @@ async def test_subscription_method(client: Client) -> None:
110117 deserialize_error ,
111118 ):
112119 assert isinstance (response , str )
113- assert "Subscription message" in response
120+ messages .append (response )
121+ assert messages == [f"Subscription message { i } for Bob" for i in range (5 )]
114122
115123
116124@pytest .mark .asyncio
@@ -213,3 +221,52 @@ async def stream_data() -> AsyncGenerator[str, None]:
213221 "Stream response for Stream Data 1" ,
214222 "Stream response for Stream Data 2" ,
215223 ]
224+
225+
226+ async def flood_subscription_handler (
227+ request : str , context : grpc .aio .ServicerContext
228+ ) -> AsyncGenerator [str , None ]:
229+ for i in range (128 * 2 ):
230+ logging .warning (f"sending { i } " )
231+ yield f"Subscription message { i } for { request } "
232+
233+
234+ flood_subscription : HandlerMapping = {
235+ ("test_service" , "flood_subscription_method" ): (
236+ "subscription-stream" ,
237+ subscription_method_handler (
238+ flood_subscription_handler , deserialize_request , serialize_response
239+ ),
240+ ),
241+ }
242+
243+
244+ @pytest .mark .asyncio
245+ @pytest .mark .parametrize ("handlers" , [{** basic_rpc_method , ** flood_subscription }])
246+ async def test_ignore_flood_subscription (client : Client ) -> None :
247+ sub = client .send_subscription (
248+ "test_service" ,
249+ "flood_subscription_method" ,
250+ "Initial Subscription Data" ,
251+ serialize_request ,
252+ deserialize_response ,
253+ deserialize_error ,
254+ )
255+
256+ # read one entry to start the subscription
257+ await sub .__anext__ ()
258+ # close the subscription so we can signal that we're not
259+ # interested in the rest of the subscription.
260+ await sub .aclose ()
261+
262+ # ensure that subsequent RPCs still work
263+ response = await client .send_rpc (
264+ "test_service" ,
265+ "rpc_method" ,
266+ "Alice" ,
267+ serialize_request ,
268+ deserialize_response ,
269+ deserialize_error ,
270+ timedelta (seconds = 20 ),
271+ )
272+ assert response == "Hello, Alice!"
0 commit comments