Skip to content

Commit 5c51b4f

Browse files
Guy AmramGuy Amram
authored andcommitted
added tests and made fixes
1 parent 0e6d4c6 commit 5c51b4f

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

computation_graph/graph_runners.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import inspect
1+
import asyncio
22
from typing import Callable, Union
33

44
import gamla
@@ -78,16 +78,22 @@ def unary_with_state_and_expectations(
7878
def variadic_with_state_and_expectations(g, sink):
7979
f = run.to_callable_strict(g)
8080

81-
async def inner(turns):
82-
prev = {}
83-
for turn, expectation in turns:
84-
if inspect.iscoroutinefunction(f):
81+
if asyncio.iscoroutinefunction(f):
82+
async def inner(turns):
83+
prev = {}
84+
for turn, expectation in turns:
8585
prev = await f(prev, turn)
86-
else:
86+
assert (
87+
prev[graph.make_computation_node(sink)] == expectation
88+
), f"actual={prev[graph.make_computation_node(sink)]}\n expected: {expectation}"
89+
else:
90+
def inner(turns):
91+
prev = {}
92+
for turn, expectation in turns:
8793
prev = f(prev, turn)
88-
assert (
89-
prev[graph.make_computation_node(sink)] == expectation
90-
), f"actual={prev[graph.make_computation_node(sink)]}\n expected: {expectation}"
94+
assert (
95+
prev[graph.make_computation_node(sink)] == expectation
96+
), f"actual={prev[graph.make_computation_node(sink)]}\n expected: {expectation}"
9197

9298
return inner
9399

computation_graph/graph_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,11 @@ def _plus_1(y):
499499
return y + 1
500500

501501

502+
async def _plus_1_async(y):
503+
await asyncio.sleep(0.1)
504+
return y + 1
505+
506+
502507
def _times_2(x):
503508
return x * 2
504509

@@ -576,6 +581,26 @@ def test_compose_future():
576581
)(([[{a: 2, b: 2, c: 2}, 9], [{a: 2, b: 2, c: 2}, 25]]))
577582

578583

584+
def test_compose_future_async():
585+
a = graph.make_source()
586+
b = graph.make_source()
587+
c = graph.make_source()
588+
graph_runners.variadic_with_state_and_expectations(
589+
base_types.merge_graphs(
590+
composers.compose_source_unary(_plus_1_async, c),
591+
composers.compose_source_unary(_times_2, b),
592+
composers.compose_source(_multiply, "a", a),
593+
composers.make_compose_future(
594+
_multiply,
595+
composers.make_and([_plus_1_async, _times_2, _multiply], merge_fn=_sum),
596+
"b",
597+
None,
598+
),
599+
),
600+
_sum,
601+
)(([[{a: 2, b: 2, c: 2}, 9], [{a: 2, b: 2, c: 2}, 25]]))
602+
603+
579604
def test_dont_duplicate_sources():
580605
a = graph.make_source()
581606
assert (

0 commit comments

Comments
 (0)