Skip to content

Commit 39597c8

Browse files
committed
more type hints around durable promise
1 parent d3f9b9c commit 39597c8

File tree

4 files changed

+16
-5
lines changed

4 files changed

+16
-5
lines changed

examples/virtual_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
@counter.handler()
2020
async def increment(ctx: ObjectContext, value: int) -> int:
21-
n = await ctx.get("counter") or 0
21+
n = await ctx.get("counter", type_hint=int) or 0
2222
n += value
2323
ctx.set("counter", n)
2424
return n

examples/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ def payment_gateway():
5656

5757
@payment.handler()
5858
async def payment_verified(ctx: WorkflowSharedContext, result: str):
59-
promise = ctx.promise("verify.payment")
59+
promise = ctx.promise("verify.payment", type_hint=str)
6060
await promise.resolve(result)

python/restate/context.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,13 +438,19 @@ def value(self) -> RestateDurableFuture[T]:
438438
Returns the value of the promise if it is resolved, None otherwise.
439439
"""
440440

441+
@abc.abstractmethod
442+
def __await__(self) -> typing.Generator[Any, Any, T]:
443+
"""
444+
Returns the value of the promise. This is a shortcut for calling value() and awaiting it.
445+
"""
446+
441447
class WorkflowContext(ObjectContext):
442448
"""
443449
Represents the context of the current workflow invocation.
444450
"""
445451

446452
@abc.abstractmethod
447-
def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[T]:
453+
def promise(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]:
448454
"""
449455
Returns a durable promise with the given name.
450456
"""
@@ -455,7 +461,7 @@ class WorkflowSharedContext(ObjectSharedContext):
455461
"""
456462

457463
@abc.abstractmethod
458-
def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[T]:
464+
def promise(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]:
459465
"""
460466
Returns a durable promise with the given name.
461467
"""

python/restate/server_context.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ def peek(self) -> Awaitable[Any | None]:
216216
assert serde is not None
217217
return self.server_context.create_future(handle, serde)
218218

219+
def __await__(self):
220+
return self.value().__await__()
221+
219222

220223
# disable too many public method
221224
# pylint: disable=R0904
@@ -684,8 +687,10 @@ def resolve_awakeable(self,
684687
def reject_awakeable(self, name: str, failure_message: str, failure_code: int = 500) -> None:
685688
return self.vm.sys_reject_awakeable(name, Failure(code=failure_code, message=failure_message))
686689

687-
def promise(self, name: str, serde: typing.Optional[Serde[T]] = JsonSerde()) -> DurablePromise[Any]:
690+
def promise(self, name: str, serde: typing.Optional[Serde[T]] = JsonSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]:
688691
"""Create a durable promise."""
692+
if isinstance(serde, DefaultSerde):
693+
serde = serde.with_maybe_type(type_hint)
689694
return ServerDurablePromise(self, name, serde)
690695

691696
def key(self) -> str:

0 commit comments

Comments
 (0)