Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
Beam now supports data enrichment capabilities using SQL databases, with built-in support for:
- Managed PostgreSQL, MySQL, and Microsoft SQL Server instances on CloudSQL
- Unmanaged SQL database instances not hosted on CloudSQL (e.g., self-hosted or on-premises databases)
* [Python] Added the `ReactiveThrottler` and `ThrottlingSignaler` classes to streamline throttling behavior in DoFns, expose throttling mechanisms for users ([#35984](https://github.com/apache/beam/pull/35984))
* Added a pipeline option to specify the processing timeout for a single element by any PTransform (Java/Python/Go) ([#35174](https://github.com/apache/beam/issues/35174)).
- When specified, the SDK harness automatically restarts if an element takes too long to process. Beam runner may then retry processing of the same work item.
- Use the `--element_processing_timeout_minutes` option to reduce the chance of having stalled pipelines due to unexpected cases of slow processing, where slowness might not happen again if processing of the same element is retried.
Expand Down
92 changes: 92 additions & 0 deletions sdks/python/apache_beam/io/components/adaptive_throttler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,32 @@

# pytype: skip-file

import logging
import random
import time

from apache_beam.io.components import util
from apache_beam.metrics.metric import Metrics

_SECONDS_TO_MILLISECONDS = 1_000


class ThrottlingSignaler(object):
"""A class that handles signaling throttling of remote requests to the
SDK harness.
"""
def __init__(self, namespace: str = ""):
self.throttling_metric = Metrics.counter(
namespace, "cumulativeThrottlingSeconds")

def signal_throttled(self, seconds: int):
"""Signals to the runner that requests have been throttled for some amount
of time.

Args:
seconds: int, duration of throttling in seconds.
"""
self.throttling_metric.inc(seconds)


class AdaptiveThrottler(object):
Expand Down Expand Up @@ -94,3 +117,72 @@ def successful_request(self, now):
now: int, time in ms since the epoch
"""
self._successful_requests.add(now, 1)


class ReactiveThrottler(AdaptiveThrottler):
""" A wrapper around the AdaptiveThrottler that also handles logging and
signaling throttling to the SDK harness using the provided namespace.

For usage, instantiate one instance of a ReactiveThrottler class for a
PTransform. When making remote calls to a service, preface that call with
the throttle() method to potentially pre-emptively throttle the request.
This will throttle future calls based on the failure rate of preceding calls,
with higher failure rates leading to longer periods of throttling to allow
system recovery. capture the timestamp of the attempted request, then execute
the request code. On a success, call successful_request(timestamp) to report
the success to the throttler. This flow looks like the following:

def remote_call():
throttler.throttle()

try:
timestamp = time.time()
result = make_request()
throttler.successful_request(timestamp)
return result
except Exception as e:
# do any error handling you want to do
raise
"""
def __init__(
self,
window_ms: int,
bucket_ms: int,
overload_ratio: float,
namespace: str = '',
throttle_delay_secs: int = 5):
"""Initializes the ReactiveThrottler.

Args:
window_ms: int, length of history to consider, in ms, to set
throttling.
bucket_ms: int, granularity of time buckets that we store data in, in
ms.
overload_ratio: float, the target ratio between requests sent and
successful requests. This is "K" in the formula in
https://landing.google.com/sre/book/chapters/handling-overload.html.
namespace: str, the namespace to use for logging and signaling
throttling is occurring
throttle_delay_secs: int, the amount of time in seconds to wait
after preemptively throttled requests
"""
self.throttling_signaler = ThrottlingSignaler(namespace=namespace)
self.logger = logging.getLogger(namespace)
self.throttle_delay_secs = throttle_delay_secs
super().__init__(
window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)

def throttle(self):
""" Stops request code from advancing while the underlying
AdaptiveThrottler is signaling to preemptively throttle the request.
Automatically handles logging the throttling and signaling to the SDK
harness that the request is being throttled. This should be called in any
context where a call to a remote service is being contacted prior to the
call being performed.
"""
while self.throttle_request(time.time() * _SECONDS_TO_MILLISECONDS):
self.logger.info(
"Delaying request for %d seconds due to previous failures",
self.throttle_delay_secs)
time.sleep(self.throttle_delay_secs)
self.throttling_signaler.signal_throttled(self.throttle_delay_secs)
43 changes: 19 additions & 24 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@
from typing import Union

import apache_beam as beam
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics.metric import Metrics
from apache_beam.io.components.adaptive_throttler import ReactiveThrottler
from apache_beam.utils import multi_process_shared
from apache_beam.utils import retry
from apache_beam.utils import shared
Expand Down Expand Up @@ -354,14 +353,16 @@ def __init__(
window_ms: int = 1 * _MILLISECOND_TO_SECOND,
bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
overload_ratio: float = 2):
"""Initializes metrics tracking + an AdaptiveThrottler class for enabling
client-side throttling for remote calls to an inference service.
"""Initializes a ReactiveThrottler class for enabling
client-side throttling for remote calls to an inference service. Also wraps
provided calls to the service with retry logic.

See https://s.apache.org/beam-client-side-throttling for more details
on the configuration of the throttling and retry
mechanics.

Args:
namespace: the metrics and logging namespace
namespace: the metrics and logging namespace
num_retries: the maximum number of times to retry a request on retriable
errors before failing
throttle_delay_secs: the amount of time to throttle when the client-side
Expand All @@ -372,19 +373,18 @@ def __init__(
window_ms: length of history to consider, in ms, to set throttling.
bucket_ms: granularity of time buckets that we store data in, in ms.
overload_ratio: the target ratio between requests sent and successful
requests. This is "K" in the formula in
requests. This is "K" in the formula in
https://landing.google.com/sre/book/chapters/handling-overload.html.
"""
# Configure AdaptiveThrottler and throttling metrics for client-side
# throttling behavior.
self.throttled_secs = Metrics.counter(
namespace, "cumulativeThrottlingSeconds")
self.throttler = AdaptiveThrottler(
window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)
# Configure ReactiveThrottler for client-side throttling behavior.
self.throttler = ReactiveThrottler(
window_ms=window_ms,
bucket_ms=bucket_ms,
overload_ratio=overload_ratio,
namespace=namespace,
throttle_delay_secs=throttle_delay_secs)
self.logger = logging.getLogger(namespace)

self.num_retries = num_retries
self.throttle_delay_secs = throttle_delay_secs
self.retry_filter = retry_filter

def __init_subclass__(cls):
Expand Down Expand Up @@ -434,12 +434,7 @@ def run_inference(
Returns:
An Iterable of Predictions.
"""
while self.throttler.throttle_request(time.time() * _MILLISECOND_TO_SECOND):
self.logger.info(
"Delaying request for %d seconds due to previous failures",
self.throttle_delay_secs)
time.sleep(self.throttle_delay_secs)
self.throttled_secs.inc(self.throttle_delay_secs)
self.throttler.throttle()

try:
req_time = time.time()
Expand Down Expand Up @@ -1642,7 +1637,7 @@ def next_model_index(self, num_models):

class _ModelStatus():
"""A class holding any metadata about a model required by RunInference.

Currently, this only includes whether or not the model is valid. Uses the
model tag to map models to metadata.
"""
Expand All @@ -1656,7 +1651,7 @@ def __init__(self, share_model_across_processes: bool):

def try_mark_current_model_invalid(self, min_model_life_seconds):
"""Mark the current model invalid.

Since we don't have sufficient information to say which model is being
marked invalid, but there may be multiple active models, we will mark all
models currently in use as inactive so that they all get reloaded. To
Expand All @@ -1678,7 +1673,7 @@ def try_mark_current_model_invalid(self, min_model_life_seconds):

def get_valid_tag(self, tag: str) -> str:
"""Takes in a proposed valid tag and returns a valid one.

Will always return a valid tag. If the passed in tag is valid, this
function will simply return it, otherwise it will deterministically
generate a new tag to use instead. The new tag will be the original tag
Expand Down Expand Up @@ -1747,7 +1742,7 @@ def load_model_status(

class _SharedModelWrapper():
"""A router class to map incoming calls to the correct model.

This allows us to round robin calls to models sitting in different
processes so that we can more efficiently use resources (e.g. GPUs).
"""
Expand Down
Loading