Skip to content

Commit

Permalink
Refactor Cancelling Logic To Use /cancel (#8370)
Browse files Browse the repository at this point in the history
* Cancel refactor

* add changeset

* add changeset

* types

* Add code

* Fix types

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot authored Jun 5, 2024
1 parent 96d8de2 commit 48eeea4
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 85 deletions.
7 changes: 7 additions & 0 deletions .changeset/deep-weeks-show.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/app": patch
"@gradio/client": patch
"gradio": patch
---

feat:Refactor Cancelling Logic To Use /cancel
2 changes: 1 addition & 1 deletion client/js/src/helpers/api_info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ export function transform_api_info(
dependencyIndex !== -1
? config.dependencies.find((dep) => dep.id == dependencyIndex)
?.types
: { continuous: false, generator: false };
: { continuous: false, generator: false, cancel: false };

if (
dependencyIndex !== -1 &&
Expand Down
13 changes: 8 additions & 5 deletions client/js/src/test/test_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export const transformed_api_info: ApiInfo<ApiData> = {
component: "Textbox"
}
],
type: { continuous: false, generator: false }
type: { continuous: false, generator: false, cancel: false }
}
},
unnamed_endpoints: {
Expand All @@ -68,7 +68,7 @@ export const transformed_api_info: ApiInfo<ApiData> = {
component: "Textbox"
}
],
type: { continuous: false, generator: false }
type: { continuous: false, generator: false, cancel: false }
}
}
};
Expand Down Expand Up @@ -395,7 +395,8 @@ export const config_response: Config = {
cancels: [],
types: {
continuous: false,
generator: false
generator: false,
cancel: false
},
collects_event_data: false,
trigger_after: null,
Expand All @@ -421,7 +422,8 @@ export const config_response: Config = {
cancels: [],
types: {
continuous: false,
generator: false
generator: false,
cancel: false
},
collects_event_data: false,
trigger_after: null,
Expand All @@ -447,7 +449,8 @@ export const config_response: Config = {
cancels: [],
types: {
continuous: false,
generator: false
generator: false,
cancel: false
},
collects_event_data: false,
trigger_after: null,
Expand Down
1 change: 1 addition & 0 deletions client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ export interface Dependency {
export interface DependencyTypes {
continuous: boolean;
generator: boolean;
cancel: boolean;
}

export interface Payload {
Expand Down
16 changes: 13 additions & 3 deletions client/js/src/utils/submit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ export function submit(
fn_index: fn_index
});

let reset_request = {};
let cancel_request = {};
if (protocol === "ws") {
if (websocket && websocket.readyState === 0) {
Expand All @@ -131,21 +132,30 @@ export function submit(
} else {
websocket.close();
}
cancel_request = { fn_index, session_hash };
reset_request = { fn_index, session_hash };
} else {
stream?.close();
cancel_request = { event_id };
reset_request = { event_id };
cancel_request = { event_id, session_hash, fn_index };
}

try {
if (!config) {
throw new Error("Could not resolve app config");
}

if ("event_id" in cancel_request) {
await fetch(`${config.root}/cancel`, {
headers: { "Content-Type": "application/json" },
method: "POST",
body: JSON.stringify(cancel_request)
});
}

await fetch(`${config.root}/reset`, {
headers: { "Content-Type": "application/json" },
method: "POST",
body: JSON.stringify(cancel_request)
body: JSON.stringify(reset_request)
});
} catch (e) {
console.warn(
Expand Down
24 changes: 6 additions & 18 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
TupleNoPrint,
check_function_inputs_match,
component_or_layout_class,
get_cancel_function,
get_cancelled_fn_indices,
get_continuous_fn,
get_package_version,
get_upload_folder,
Expand Down Expand Up @@ -541,12 +541,7 @@ def __init__(
self.rendered_in = rendered_in

# We need to keep track of which events are cancel events
# in two places:
# 1. So that we can skip postprocessing for cancel events.
# They return event_ids that have been cancelled but there
# are no output components
# 2. So that we can place the ProcessCompletedMessage in the
# event stream so that clients can close the stream when necessary
# so that the client can call the /cancel route directly
self.is_cancel_function = is_cancel_function

self.spaces_auto_wrap()
Expand Down Expand Up @@ -589,6 +584,7 @@ def get_config(self):
"types": {
"continuous": self.types_continuous,
"generator": self.types_generator,
"cancel": self.is_cancel_function,
},
"collects_event_data": self.collects_event_data,
"trigger_after": self.trigger_after,
Expand Down Expand Up @@ -1377,7 +1373,7 @@ def render(self):
updated_cancels = [
root_context.fns[i].get_config() for i in dependency.cancels
]
dependency.fn = get_cancel_function(updated_cancels)[0]
dependency.cancels = get_cancelled_fn_indices(updated_cancels)
root_context.fns[root_context.fn_id] = dependency
root_context.fn_id += 1
Context.root_block.temp_file_sets.extend(self.temp_file_sets)
Expand Down Expand Up @@ -1694,17 +1690,9 @@ async def postprocess_data(
block_fn: BlockFunction,
predictions: list | dict,
state: SessionState | None,
) -> Any:
) -> list[Any]:
state = state or SessionState(self)

# If the function is a cancel function, 'predictions' are the ids of
# the event in the queue that has been cancelled. We need these
# so that the server can put the ProcessCompleted message in the event stream
# Cancel events have no output components, so we need to return early otherise the output
# be None.
if block_fn.is_cancel_function:
return predictions

if isinstance(predictions, dict) and len(predictions) > 0:
predictions = convert_component_dict_to_list(
[block._id for block in block_fn.outputs], predictions
Expand Down Expand Up @@ -1920,7 +1908,7 @@ async def process_api(
for o in zip(*preds)
]
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
data = processing_utils.add_root_url(data, root_path, None) # type: ignore
data = list(zip(*data))
is_generating, iterator = None, None
else:
Expand Down
6 changes: 3 additions & 3 deletions gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from gradio.blocks import Block, Component

from gradio.context import get_blocks_context
from gradio.utils import get_cancel_function
from gradio.utils import get_cancelled_fn_indices


def set_cancel_events(
Expand All @@ -36,15 +36,15 @@ def set_cancel_events(
if cancels:
if not isinstance(cancels, list):
cancels = [cancels]
cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
fn_indices_to_cancel = get_cancelled_fn_indices(cancels)

root_block = get_blocks_context()
if root_block is None:
raise AttributeError("Cannot cancel outside of a gradio.Blocks context.")

root_block.set_event_trigger(
triggers,
cancel_fn,
fn=None,
inputs=None,
outputs=None,
queue=False,
Expand Down
34 changes: 14 additions & 20 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,13 +624,8 @@ async def file_deprecated(path: str, request: fastapi.Request):

@app.post("/reset/")
@app.post("/reset")
async def reset_iterator(body: ResetBody):
if body.event_id not in app.iterators:
return {"success": False}
async with app.lock:
del app.iterators[body.event_id]
app.iterators_to_reset.add(body.event_id)
await app.get_blocks()._queue.clean_events(event_id=body.event_id)
async def reset_iterator(body: ResetBody): # noqa: ARG001
# No-op, all the cancelling/reset logic handled by /cancel
return {"success": True}

@app.get("/heartbeat/{session_hash}")
Expand Down Expand Up @@ -739,18 +734,6 @@ async def predict(
fn=fn,
root_path=root_path,
)
if fn.is_cancel_function:
# Need to complete the job so that the client disconnects
blocks = app.get_blocks()
if body.session_hash in blocks._queue.pending_messages_per_session:
for event_id in output["data"]:
message = ProcessCompletedMessage(
output={}, success=True, event_id=event_id
)
blocks._queue.pending_messages_per_session[ # type: ignore
body.session_hash
].put_nowait(message)

except BaseException as error:
show_error = app.get_blocks().show_error or isinstance(error, Error)
traceback.print_exc()
Expand Down Expand Up @@ -823,13 +806,24 @@ async def cancel_event(body: CancelBody):
await cancel_tasks({f"{body.session_hash}_{body.fn_index}"})
blocks = app.get_blocks()
# Need to complete the job so that the client disconnects
if body.session_hash in blocks._queue.pending_messages_per_session:
session_open = (
body.session_hash in blocks._queue.pending_messages_per_session
)
event_running = (
body.event_id
in blocks._queue.pending_event_ids_session.get(body.session_hash, {})
)
if session_open and event_running:
message = ProcessCompletedMessage(
output={}, success=True, event_id=body.event_id
)
blocks._queue.pending_messages_per_session[
body.session_hash
].put_nowait(message)
if body.event_id in app.iterators:
async with app.lock:
del app.iterators[body.event_id]
app.iterators_to_reset.add(body.event_id)
return {"success": True}

@app.get("/call/{api_name}/{event_id}", dependencies=[Depends(login_check)])
Expand Down
20 changes: 6 additions & 14 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ def after_fn():

async def cancel_tasks(task_ids: set[str]) -> list[str]:
tasks = [(task, task.get_name()) for task in asyncio.all_tasks()]
event_ids = []
event_ids: list[str] = []
matching_tasks = []
for task, name in tasks:
if "<gradio-sep>" not in name:
Expand All @@ -891,27 +891,19 @@ def set_task_name(task, session_hash: str, fn_index: int, event_id: str, batch:
task.set_name(f"{session_hash}_{fn_index}<gradio-sep>{event_id}")


def get_cancel_function(
def get_cancelled_fn_indices(
dependencies: list[dict[str, Any]],
) -> tuple[Callable, list[int]]:
fn_to_comp = {}
) -> list[int]:
fn_indices = []
for dep in dependencies:
root_block = get_blocks_context()
if root_block:
fn_index = next(
i for i, d in root_block.fns.items() if d.get_config() == dep
)
fn_to_comp[fn_index] = [root_block.blocks[o] for o in dep["outputs"]]
fn_indices.append(fn_index)

async def cancel(session_hash: str) -> list[str]:
task_ids = {f"{session_hash}_{fn}" for fn in fn_to_comp}
event_ids = await cancel_tasks(task_ids)
return event_ids

return (
cancel,
list(fn_to_comp.keys()),
)
return fn_indices


def get_type_hints(fn):
Expand Down
17 changes: 8 additions & 9 deletions js/app/src/Blocks.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,6 @@
const current_status = loading_status.get_status_for_fn(dep_index);
messages = messages.filter(({ fn_index }) => fn_index !== dep_index);
if (dep.cancels) {
await Promise.all(
dep.cancels.map(async (fn_index) => {
const submission = submit_map.get(fn_index);
submission?.cancel();
return submission;
})
);
}
if (current_status === "pending" || current_status === "generating") {
dep.pending_request = true;
}
Expand All @@ -242,6 +233,14 @@
handle_update(v, dep_index);
}
});
} else if (dep.types.cancel && dep.cancels) {
await Promise.all(
dep.cancels.map(async (fn_index) => {
const submission = submit_map.get(fn_index);
submission?.cancel();
return submission;
})
);
} else {
if (dep.backend_fn) {
if (dep.trigger_mode === "once") {
Expand Down
1 change: 1 addition & 0 deletions js/app/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export interface ComponentMeta {
export interface DependencyTypes {
continuous: boolean;
generator: boolean;
cancel: boolean;
}

/** An event payload that is sent with an API request */
Expand Down
Loading

0 comments on commit 48eeea4

Please sign in to comment.