Skip to content

Commit

Permalink
fix(components): GCP - Fixed the BrokenPipe error. Fixes #5746 (#5760)
Browse files Browse the repository at this point in the history
* Components - GCP - Fixed the BrokenPipe error

* Addressed review feedback and fixed bugs

* Added missing import

* Fixed syntax error in the error logging code
  • Loading branch information
Ark-kun committed Jul 30, 2021
1 parent 8a5e067 commit 5bdb5e5
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.

from ._utils import (normalize_name, dump_file,
check_resource_changed, wait_operation_done)
check_resource_changed, wait_operation_done, ClientWithRetries)
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import logging
import re
import os
import time
from functools import wraps
from typing import Any, Callable, Optional, Tuple

def normalize_name(name,
valid_first_char_pattern='a-zA-Z',
Expand Down Expand Up @@ -120,3 +123,55 @@ def wait_operation_done(get_operation, wait_interval):
))
return operation


def with_retries(
func: Callable,
on_error: Optional[Callable[[], Any]] = None,
errors: Tuple[Exception, ...] = Exception,
number_of_retries: int = 5,
delay: float = 1,
):
"""Retry decorator.
The decorator catches `errors`, calls `on_error` and retries after waiting `delay` seconds.
Args:
number_of_retries (int): Total number of retries if error is raised.
delay (float): Number of seconds to wait between consecutive retries.
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
remaining_retries = number_of_retries
while remaining_retries:
try:
return func(self, *args, **kwargs)
except errors as e:
remaining_retries -= 1
if not remaining_retries:
raise

logging.warning(
'Caught {}. Retrying in {} seconds...'.format(
e.__class__.__name__, delay
)
)

time.sleep(delay)
if on_error:
on_error()

return wrapper


class ClientWithRetries:

def __init__(self):
self._build_client()
for name, member in self.__dict__.items():
if callable(member) and not name.startswith("_"):
self.__dict__[name] = with_retries(func=member, errors=(BrokenPipeError, IOError), on_error=self._build_client)

@abc.abstractmethod
def _build_client():
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

import googleapiclient.discovery as discovery
from googleapiclient import errors
from ..common import ClientWithRetries


class DataflowClient:

def __init__(self):
class DataflowClient(ClientWithRetries):
def _build_client(self):
self._df = discovery.build('dataflow', 'v1b3', cache_discovery=False)

def launch_template(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
import time

import googleapiclient.discovery as discovery
from ..common import wait_operation_done
from ..common import wait_operation_done, ClientWithRetries

class DataprocClient:

class DataprocClient(ClientWithRetries):
""" Internal client for calling Dataproc APIs.
"""
def __init__(self):

def _create_client(self):
self._dataproc = discovery.build('dataproc', 'v1', cache_discovery=False)

def create_cluster(self, project_id, region, cluster, request_id):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,55 +18,16 @@

import googleapiclient.discovery as discovery
from googleapiclient import errors
from ..common import wait_operation_done
from ..common import wait_operation_done, ClientWithRetries


def _retry(func, tries=5, delay=1):
"""Retry decorator for methods in MLEngineClient class.
It bypasses the BrokenPipeError by directly accessing the `_build_client` method
and rebuilds `_ml_client` after `delay` seconds.
Args:
tries (int): Total number of retries if BrokenPipeError/IOError is raised.
delay (int): Number of seconds to wait between consecutive retries.
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
_tries, _delay = tries, delay
while _tries:
try:
return func(self, *args, **kwargs)
except (BrokenPipeError, IOError) as e:
_tries -= 1
if not _tries:
raise

logging.warning(
'Caught {}. Retrying in {} seconds...'.format(
e._class__.__name__, _delay
)
)

time.sleep(_delay)
# access _build_client method and rebuild Http Client
self._build_client()

return wrapper


class MLEngineClient:
class MLEngineClient(ClientWithRetries):
""" Client for calling MLEngine APIs.
"""

def __init__(self):
self._build_client()

def _build_client(self):
self._ml_client = discovery.build('ml', 'v1', cache_discovery=False)

@_retry
def create_job(self, project_id, job):
"""Create a new job.
Expand All @@ -82,7 +43,6 @@ def create_job(self, project_id, job):
body = job
).execute()

@_retry
def cancel_job(self, project_id, job_id):
"""Cancel the specified job.
Expand All @@ -98,7 +58,6 @@ def cancel_job(self, project_id, job_id):
},
).execute()

@_retry
def get_job(self, project_id, job_id):
"""Gets the job by ID.
Expand All @@ -112,7 +71,6 @@ def get_job(self, project_id, job_id):
return self._ml_client.projects().jobs().get(
name=job_name).execute()

@_retry
def create_model(self, project_id, model):
"""Creates a new model.
Expand All @@ -127,7 +85,6 @@ def create_model(self, project_id, model):
body = model
).execute()

@_retry
def get_model(self, model_name):
"""Gets a model.
Expand All @@ -140,7 +97,6 @@ def get_model(self, model_name):
name = model_name
).execute()

@_retry
def create_version(self, model_name, version):
"""Creates a new version.
Expand All @@ -156,7 +112,6 @@ def create_version(self, model_name, version):
body = version
).execute()

@_retry
def get_version(self, version_name):
"""Gets a version.
Expand All @@ -175,7 +130,6 @@ def get_version(self, version_name):
return None
raise

@_retry
def delete_version(self, version_name):
"""Deletes a version.
Expand All @@ -195,13 +149,11 @@ def delete_version(self, version_name):
return None
raise

@_retry
def set_default_version(self, version_name):
return self._ml_client.projects().models().versions().setDefault(
name = version_name
).execute()

@_retry
def get_operation(self, operation_name):
"""Gets an operation.
Expand Down Expand Up @@ -229,7 +181,6 @@ def wait_for_operation_done(self, operation_name, wait_interval):
return wait_operation_done(
lambda: self.get_operation(operation_name), wait_interval)

@_retry
def cancel_operation(self, operation_name):
"""Cancels an operation.
Expand Down

0 comments on commit 5bdb5e5

Please sign in to comment.