Skip to content

Commit

Permalink
[Serve] Avoid looping over all snapshot ids for each long poll request
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshKarpel authored Jul 8, 2024
1 parent 85eb05c commit 6beeacb
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions python/ray/serve/_private/long_poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ def __init__(
self.key_listeners = key_listeners
self.event_loop = call_in_event_loop
self.snapshot_ids: Dict[KeyType, int] = {
key: -1 for key in self.key_listeners.keys()
# The initial snapshot id for each key is < 0,
# but real snapshot ids in the long poll host are always >= 0,
# so this will always trigger an initial update.
key: -1
for key in self.key_listeners.keys()
}
self.is_running = True

Expand Down Expand Up @@ -191,11 +195,9 @@ def __init__(
] = LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S,
):
# Map object_key -> int
self.snapshot_ids: DefaultDict[KeyType, int] = defaultdict(
lambda: random.randint(0, 1_000_000)
)
self.snapshot_ids: Dict[KeyType, int] = {}
# Map object_key -> object
self.object_snapshots: Dict[KeyType, Any] = dict()
self.object_snapshots: Dict[KeyType, Any] = {}
# Map object_key -> set(asyncio.Event waiting for updates)
self.notifier_events: DefaultDict[KeyType, Set[asyncio.Event]] = defaultdict(
set
Expand Down Expand Up @@ -247,24 +249,32 @@ async def listen_for_change(
immediately if the snapshot_ids are outdated, otherwise it will block
until there's an update.
"""
watched_keys = keys_to_snapshot_ids.keys()
existent_keys = set(watched_keys).intersection(set(self.snapshot_ids.keys()))

# If there are any keys with outdated snapshot ids,
# return their updated values immediately.
updated_objects = {
key: UpdatedObject(self.object_snapshots[key], self.snapshot_ids[key])
for key in existent_keys
if self.snapshot_ids[key] != keys_to_snapshot_ids[key]
}
updated_objects = {}
for key, client_snapshot_id in keys_to_snapshot_ids.items():
try:
existing_id = self.snapshot_ids[key]
except KeyError:
# The caller may ask for keys that we don't know about (yet);
# just ignore them.
# This can happen when, for example,
# a deployment handle is manually created for an app
# that hasn't been deployed yet (by bypassing the safety checks).
continue

if existing_id != client_snapshot_id:
updated_objects[key] = UpdatedObject(
self.object_snapshots[key], existing_id
)
if len(updated_objects) > 0:
self._count_send(updated_objects)
return updated_objects

# Otherwise, register asyncio events to be waited.
async_task_to_events = {}
async_task_to_watched_keys = {}
for key in watched_keys:
for key in keys_to_snapshot_ids.keys():
# Create a new asyncio event for this key.
event = asyncio.Event()

Expand Down Expand Up @@ -398,10 +408,16 @@ def notify_changed(
object_key: KeyType,
updated_object: Any,
):
self.snapshot_ids[object_key] += 1
try:
self.snapshot_ids[object_key] += 1
except KeyError:
# Initial snapshot id must be >= 0, so that the long poll client
# can send a negative initial snapshot id to get a fast update.
# Initial snapshot ids should also be randomized;
# see https://github.com/ray-project/ray/pull/45881#discussion_r1645243485
self.snapshot_ids[object_key] = random.randint(0, 1_000_000)
self.object_snapshots[object_key] = updated_object
logger.debug(f"LongPollHost: Notify change for key {object_key}.")

if object_key in self.notifier_events:
for event in self.notifier_events.pop(object_key):
event.set()
for event in self.notifier_events.pop(object_key, set()):
event.set()

0 comments on commit 6beeacb

Please sign in to comment.