diff --git a/python/ray/workflow/examples/comparisons/prefect/compute_fib_workflow.py b/python/ray/workflow/examples/comparisons/prefect/compute_fib_workflow.py index b6ae0fed72e16..943d68cf9d093 100644 --- a/python/ray/workflow/examples/comparisons/prefect/compute_fib_workflow.py +++ b/python/ray/workflow/examples/comparisons/prefect/compute_fib_workflow.py @@ -1,31 +1,36 @@ +import tempfile + import ray from ray import workflow -import requests +from ray.actor import ActorHandle + +@ray.remote +class FibonacciActor: + def __init__(self): + self.cache = {} -def fibonacci(n): - assert n > 0 - a, b = 0, 1 - for _ in range(n - 1): - a, b = b, a + b - return b + def compute(self, n): + if n not in self.cache: + assert n > 0 + a, b = 0, 1 + for _ in range(n - 1): + a, b = b, a + b + self.cache[n] = b + return self.cache[n] @ray.remote -def compute_large_fib(M: int, n: int = 1, fib: int = 1): - try: - next_fib = requests.post( - "https://nemo.api.stdlib.com/fibonacci@0.0.1/", data={"nth": n} - ).json() - assert isinstance(next_fib, int) - except AssertionError: - # TODO(suquark): The web service would fail sometimes. This is a workaround. - next_fib = fibonacci(n) +def compute_large_fib(fibonacci_actor: ActorHandle, M: int, n: int = 1, fib: int = 1): + next_fib = ray.get(fibonacci_actor.compute.remote(n)) if next_fib > M: return fib else: - return workflow.continuation(compute_large_fib.bind(M, n + 1, next_fib)) + return workflow.continuation( + compute_large_fib.bind(fibonacci_actor, M, n + 1, next_fib) + ) if __name__ == "__main__": - assert workflow.run(compute_large_fib.bind(100)) == 89 + ray.init(storage=f"file://{tempfile.TemporaryDirectory().name}") + assert workflow.run(compute_large_fib.bind(FibonacciActor.remote(), 100)) == 89