3030import time
3131
3232from restate .context import DurablePromise , AttemptFinishedEvent , HandlerType , ObjectContext , Request , RestateDurableCallFuture , RestateDurableFuture , RunAction , SendHandle , RestateDurableSleepFuture , RunOptions , P
33- from restate .exceptions import TerminalError
33+ from restate .exceptions import TerminalError , SdkInternalBaseException , SdkInternalException , SuspendedException
3434from restate .handler import Handler , handler_from_callable , invoke_handler
3535from restate .serde import BytesSerde , DefaultSerde , JsonSerde , Serde
3636from restate .server_types import ReceiveChannel , Send
@@ -273,19 +273,6 @@ def update_restate_context_is_replaying(vm: VMWrapper):
273273 """Update the context var 'restate_context_is_replaying'. This should be called after each vm.sys_*"""
274274 restate_context_is_replaying .set (vm .is_replaying ())
275275
276- async def cancel_current_task ():
277- """Cancel the current task"""
278- current_task = asyncio .current_task ()
279- if current_task is not None :
280- # Cancel through asyncio API
281- current_task .cancel (
282- "Cancelled by Restate SDK, you should not call any Context method after this exception is thrown."
283- )
284- # Sleep 0 will pop up the cancellation
285- await asyncio .sleep (0 )
286- else :
287- raise asyncio .CancelledError ("Cancelled by Restate SDK, you should not call any Context method after this exception is thrown." )
288-
289276# pylint: disable=R0902
290277class ServerInvocationContext (ObjectContext ):
291278 """This class implements the context for the restate framework based on the server."""
@@ -327,6 +314,8 @@ async def enter(self):
327314 # pylint: disable=W0718
328315 except asyncio .CancelledError :
329316 pass
317+ except SdkInternalBaseException :
318+ pass
330319 except DisconnectedException :
331320 raise
332321 except Exception as e :
@@ -393,11 +382,11 @@ async def must_take_notification(self, handle):
393382 await self .take_and_send_output ()
394383 # Print this exception, might be relevant for the user
395384 traceback .print_exception (res )
396- await cancel_current_task ()
385+ raise SdkInternalException () from res
397386 if isinstance (res , Suspended ):
398387 # We might need to write out something at this point.
399388 await self .take_and_send_output ()
400- await cancel_current_task ()
389+ raise SuspendedException ()
401390 if isinstance (res , NotReady ):
402391 raise ValueError (f"Unexpected value error: { handle } " )
403392 if res is None :
@@ -414,9 +403,9 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
414403 if isinstance (do_progress_response , BaseException ):
415404 # Print this exception, might be relevant for the user
416405 traceback .print_exception (do_progress_response )
417- await cancel_current_task ()
406+ raise SdkInternalException () from do_progress_response
418407 if isinstance (do_progress_response , Suspended ):
419- await cancel_current_task ()
408+ raise SuspendedException ()
420409 if isinstance (do_progress_response , DoProgressAnyCompleted ):
421410 # One of the handles completed
422411 return
@@ -565,6 +554,8 @@ async def create_run_coroutine(self,
565554 self .vm .propose_run_completion_failure (handle , failure )
566555 except asyncio .CancelledError as e :
567556 raise e from None
557+ except SdkInternalBaseException as e :
558+ raise e from None
568559 # pylint: disable=W0718
569560 except Exception as e :
570561 end = time .time ()
0 commit comments