Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel server progress from the python client #8245

Merged
merged 7 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP using events
  • Loading branch information
freddyaboulton committed May 10, 2024
commit ee59f92e8730b3ac72eddea674ab48ff3231ce7a
31 changes: 23 additions & 8 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,16 @@ def submit(
>> 9.0
"""
inferred_fn_index = self._infer_fn_index(api_name, fn_index)
cancellable = False
cancel_fn_index = None
candidates: list[tuple[int, list[int]]] = []
for i, dep in enumerate(self.config["dependencies"]):
if inferred_fn_index in dep["cancels"]:
cancellable = True
cancel_fn_index = i
break
candidates.append(
(i, [d for d in dep["cancels"] if d != inferred_fn_index])
)
cancel_fn_index, other_cancelled = (
min(candidates, key=lambda x: len(x[1])) if candidates else (None, None)
)

endpoint = self.endpoints[inferred_fn_index]

Expand All @@ -530,8 +533,18 @@ def submit(
future = self.executor.submit(end_to_end_fn, *args)

cancel_fn = None
if cancellable:
cancel_fn = endpoint.make_cancel(cancel_fn_index)
cancel_msg = None
if other_cancelled:
other_api_names = [
"/" + self.config["dependencies"][i].get("api_name")
for i in other_cancelled
]
cancel_msg = (
f"Cancelled this job will also cancel any jobs for {', '.join(other_api_names)} "
"that are currently running."
)
if cancel_fn_index is not None and isinstance(endpoint, Endpoint):
cancel_fn = endpoint.make_cancel(cancel_fn_index, cancel_msg)

job = Job(
future,
Expand Down Expand Up @@ -1127,8 +1140,10 @@ def _inner(*data):

return _inner

def make_cancel(self, fn_index):
def make_cancel(self, fn_index: int, cancel_msg: str | None):
def _cancel():
if cancel_msg:
warnings.warn(cancel_msg)
data = {
"data": [],
"fn_index": fn_index,
Expand Down Expand Up @@ -1379,7 +1394,7 @@ def __init__(
communicator: Communicator | None = None,
verbose: bool = True,
space_id: str | None = None,
_cancel_fn: Callable[None, None] = None,
_cancel_fn: Callable[[], None] | None = None,
):
"""
Parameters:
Expand Down
9 changes: 9 additions & 0 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,7 @@ def iterate_over_children(children_list):
_targets = dependency.pop("targets")
trigger = dependency.pop("trigger", None)
is_then_event = False
is_internal_cancel = False

# This assumes that you cannot combine multiple .then() events in a single
# gr.on() event, which is true for now. If this changes, we will need to
Expand All @@ -999,6 +1000,11 @@ def iterate_over_children(children_list):
"This logic assumes that .then() events are not combined with other events in a single gr.on() event"
)
is_then_event = True
if (
not isinstance(_targets[0], int)
and _targets[0][1] == "cancel_internal"
):
is_internal_cancel = True

dependency.pop("backend_fn")
dependency.pop("documentation", None)
Expand All @@ -1019,6 +1025,9 @@ def iterate_over_children(children_list):
"trigger_only_on_success"
)
dependency["no_target"] = True
elif is_internal_cancel:
targets = [EventListenerMethod(None, "cancel_internal")]
dependency["no_target"] = True
else:
targets = [
getattr(
Expand Down
11 changes: 10 additions & 1 deletion gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,16 @@ def inner(*args, **kwargs):
)
if _callback:
_callback(block)
return Dependency(block, dep.get_config(), dep_index, fn)
dep_obj = Dependency(block, dep.get_config(), dep_index, fn)
if queue is not False and fn is not None:
try:
set_cancel_events(
[EventListenerMethod(None, "cancel_internal")], [dep_obj]
)
# For invalid blocks case
except KeyError:
pass
return dep_obj

event_trigger.event_name = _event_name
event_trigger.has_trigger = _has_trigger
Expand Down
35 changes: 18 additions & 17 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def greet(name, formatter):
o = gr.Textbox()
t.change(greet_upper_case, t, o)

assert len(demo.fns) == 1
# Cancel function automatically added
assert len(demo.fns) == 2
assert "fn" in str(demo.fns[0])

@pytest.mark.asyncio
Expand Down Expand Up @@ -242,7 +243,7 @@ def iteration(count: int):

with connect(demo) as client:
job_1 = client.submit(3, fn_index=0)
job_2 = client.submit(4, fn_index=1)
job_2 = client.submit(4, fn_index=2)
wait([job_1, job_2])

assert job_1.outputs()[-1] == 2
Expand Down Expand Up @@ -347,13 +348,13 @@ def generator_function():
demo.load(continuous_fn, inputs=None, outputs=[meaning_of_life], every=1)

for i, dependency in enumerate(demo.config["dependencies"]):
if i == 3:
assert dependency["types"] == {"continuous": True, "generator": True}
if i == 5:
assert dependency["types"] == {"continuous": False, "generator": False}
if i == 0:
assert dependency["types"] == {"continuous": False, "generator": False}
if i == 1:
assert dependency["types"] == {"continuous": False, "generator": True}
if i == 2:
assert dependency["types"] == {"continuous": False, "generator": True}
if i == 4:
assert dependency["types"] == {"continuous": True, "generator": True}

@patch(
Expand Down Expand Up @@ -885,9 +886,9 @@ async def test_call_multiple_functions(self):
output = demo("World")
assert output == "Hello, World"

output = await demo.call_function(1, ["World"])
output = await demo.call_function(2, ["World"])
assert output["prediction"] == "Hi, World"
output = demo("World", fn_index=1) # fn_index must be a keyword argument
output = demo("World", fn_index=2) # fn_index must be a keyword argument
assert output == "Hi, World"

@pytest.mark.asyncio
Expand All @@ -903,7 +904,7 @@ def test(x):

output = await demo.call_function(0, ["Adam"])
assert output["prediction"] == "Hello Adam"
output = await demo.call_function(1, ["Adam"])
output = await demo.call_function(2, ["Adam"])
assert output["prediction"] == "Hello Adam"

@pytest.mark.asyncio
Expand Down Expand Up @@ -961,16 +962,16 @@ def generator(x):
output = await demo.call_function(0, [-1])
assert output["prediction"] == -2

output = await demo.call_function(1, [3])
output = await demo.call_function(2, [3])
assert output["prediction"] == (0, 3)
output = await demo.call_function(1, [3], iterator=output["iterator"])
output = await demo.call_function(2, [3], iterator=output["iterator"])
assert output["prediction"] == (1, 3)
output = await demo.call_function(1, [3], iterator=output["iterator"])
output = await demo.call_function(2, [3], iterator=output["iterator"])
assert output["prediction"] == (2, 3)
output = await demo.call_function(1, [3], iterator=output["iterator"])
output = await demo.call_function(2, [3], iterator=output["iterator"])
assert output["prediction"] == (gr.components._Keywords.FINISHED_ITERATING,) * 2
assert output["iterator"] is None
output = await demo.call_function(1, [3], iterator=output["iterator"])
output = await demo.call_function(2, [3], iterator=output["iterator"])
assert output["prediction"] == (0, 3)


Expand Down Expand Up @@ -1053,12 +1054,12 @@ def batch_fn(words, lengths):
output = demo("Abubakar", "Abid")
assert output

output = await demo.call_function(1, [["Adam", "Mary"], [3, 5]])
output = await demo.call_function(2, [["Adam", "Mary"], [3, 5]])
assert output["prediction"] == (
["Ada", "Mary"],
[True, False],
)
output = demo("Abubakar", 3, fn_index=1)
output = demo("Abubakar", 3, fn_index=2)
assert output == ["Abu", True]

@pytest.mark.asyncio
Expand Down Expand Up @@ -1143,7 +1144,7 @@ async def test_accordion_update(self):
"__type__": "update",
}
result = await demo.process_api(
fn_index=1, inputs=[None], request=None, state=None
fn_index=2, inputs=[None], request=None, state=None
)
assert result["data"][0] == {
"open": False,
Expand Down