Skip to content

Commit

Permalink
enhance(client): add http_retry decorator (#1104)
Browse files Browse the repository at this point in the history
add http_retry decorator
  • Loading branch information
tianweidut authored Sep 2, 2022
1 parent 2c2c2df commit cbd0602
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 24 deletions.
36 changes: 17 additions & 19 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
import requests
import pyarrow.parquet as pq # type: ignore
from loguru import logger
from tenacity import retry
from tenacity.stop import stop_after_attempt
from tenacity.retry import retry_if_exception_type
from typing_extensions import Protocol

from starwhale.utils.fs import ensure_dir
from starwhale.consts.env import SWEnv
from starwhale.utils.error import MissingFieldError
from starwhale.utils.retry import http_retry
from starwhale.utils.config import SWCliConfigMixed

try:
Expand Down Expand Up @@ -872,11 +870,7 @@ def __init__(self, instance_uri: str) -> None:
if self.token is None:
raise RuntimeError("SW_TOKEN is not found in environment")

@retry(
reraise=True,
stop=stop_after_attempt(3),
retry=retry_if_exception_type(requests.exceptions.HTTPError),
)
@http_retry
def update_table(
self,
table_name: str,
Expand Down Expand Up @@ -923,6 +917,20 @@ def update_table(
)
resp.raise_for_status()

@http_retry
def _do_scan_table_request(self, post_data: Dict[str, Any]) -> Dict[str, Any]:
    """Send one scanTable request and return the ``data`` field of the reply.

    The request is a JSON POST to ``/api/v1/datastore/scanTable`` on
    ``self.instance_uri``, authorised with ``self.token``. A non-success
    status raises via ``raise_for_status()``; the ``http_retry`` decorator
    wraps the whole call.

    Args:
        post_data: the scan query, serialised to compact JSON.

    Returns:
        The value stored under ``"data"`` in the JSON response body.
    """
    endpoint = urllib.parse.urljoin(
        self.instance_uri, "/api/v1/datastore/scanTable"
    )
    # Compact separators keep the request body free of padding whitespace.
    body = json.dumps(post_data, separators=(",", ":"))
    request_headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": self.token,  # type: ignore
    }
    resp = requests.post(endpoint, data=body, headers=request_headers, timeout=60)
    resp.raise_for_status()
    payload = resp.json()
    return payload["data"]  # type: ignore

def scan_tables(
self,
tables: List[TableDesc],
Expand All @@ -942,17 +950,7 @@ def scan_tables(
post_data["keepNone"] = True
assert self.token is not None
while True:
resp = requests.post(
urllib.parse.urljoin(self.instance_uri, "/api/v1/datastore/scanTable"),
data=json.dumps(post_data, separators=(",", ":")),
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": self.token,
},
timeout=60,
)
resp.raise_for_status()
resp_json: Dict[str, Any] = resp.json()["data"]
resp_json = self._do_scan_table_request(post_data)
records = resp_json.get("records", None)
if records is None or len(records) == 0:
break
Expand Down
4 changes: 3 additions & 1 deletion client/starwhale/base/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from starwhale.utils.fs import ensure_dir
from starwhale.utils.http import ignore_error, wrap_sw_error_resp
from starwhale.utils.error import NoSupportError
from starwhale.utils.retry import http_retry

_TMP_FILE_BUFSIZE = 8192
_DEFAULT_TIMEOUT_SECS = 90
Expand Down Expand Up @@ -94,7 +95,7 @@ def _progress_bar(monitor: MultipartEncoderMonitor) -> None:
_headers["Content-Type"] = _encoder.content_type
_monitor = MultipartEncoderMonitor(_encoder, callback=_progress_bar)

return self.do_http_request(
return self.do_http_request( # type: ignore
url_path,
instance_uri=instance_uri,
method=HTTPMethod.POST,
Expand All @@ -121,6 +122,7 @@ def do_http_request_simple_ret(

return status, message

@http_retry
def do_http_request(
self,
path: str,
Expand Down
50 changes: 50 additions & 0 deletions client/starwhale/utils/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import typing as t
from inspect import iscoroutinefunction
from urllib.error import HTTPError

import requests
from tenacity import Retrying
from tenacity.stop import stop_after_attempt
from tenacity.retry import retry_if_exception
from tenacity._asyncio import AsyncRetrying
from requests.exceptions import HTTPError as RequestsHTTPError

# HTTP status codes that http_retry treats as retryable by default; see:
# https://docs.microsoft.com/en-us/azure/architecture/best-practices/retry-service-specific#general-rest-and-retry-guidelines
_RETRY_HTTP_STATUS_CODES = (408, 429, 500, 502, 503, 504)


class retry_if_http_exception(retry_if_exception):
    """Retry condition that matches HTTP errors with selected status codes.

    Two exception families are recognised:

    * ``requests.exceptions.HTTPError`` carrying a ``requests.Response``;
    * ``urllib.error.HTTPError``, whose status is read from ``code``.

    Any other exception is never retried.
    """

    def __init__(self, status_codes: t.Optional[t.Sequence[int]] = None) -> None:
        """Build the condition.

        Args:
            status_codes: status codes that should trigger a retry. A falsy
                value (``None`` or an empty sequence) selects
                ``_RETRY_HTTP_STATUS_CODES``.
        """
        self.status_codes = status_codes or _RETRY_HTTP_STATUS_CODES

        def _should_retry(exc: BaseException) -> bool:
            # Resolve the status code for whichever HTTP error type we were
            # given, then do a single membership test at the end.
            from_requests = isinstance(exc, RequestsHTTPError) and isinstance(
                exc.response, requests.Response
            )
            if from_requests:
                code = exc.response.status_code
            elif isinstance(exc, HTTPError):
                code = exc.code
            else:
                return False
            return code in self.status_codes

        super().__init__(_should_retry)


def http_retry(*args: t.Any, **kw: t.Any) -> t.Any:
    """Decorate a function so that retryable HTTP errors are retried.

    Works both bare (``@http_retry``) and called
    (``@http_retry(attempts=6)``). Coroutine functions are wrapped with
    ``AsyncRetrying``; everything else with ``Retrying``. Once the attempts
    are exhausted the last exception is re-raised unchanged
    (``reraise=True``).

    Keyword Args:
        attempts: maximum number of attempts. Defaults to 3.
        status_codes: optional sequence of HTTP status codes that trigger a
            retry. Defaults to ``_RETRY_HTTP_STATUS_CODES``.

    Returns:
        The wrapped function when used bare, otherwise a decorator.
    """
    # support http_retry and http_retry()
    if len(args) == 1 and callable(args[0]):
        # Forward the keyword options instead of dropping them, so that
        # http_retry(func, attempts=5) honours ``attempts``.
        return http_retry(**kw)(args[0])

    def wrap(f: t.Callable) -> t.Any:
        _attempts = kw.get("attempts", 3)
        # None falls back to the module default inside the condition class.
        _status_codes = kw.get("status_codes")
        _cls = AsyncRetrying if iscoroutinefunction(f) else Retrying
        return _cls(
            *args,
            reraise=True,
            stop=stop_after_attempt(_attempts),
            retry=retry_if_http_exception(_status_codes),
        ).wraps(f)

    return wrap
18 changes: 16 additions & 2 deletions client/tests/sdk/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import numpy as np
import pyarrow as pa # type: ignore
import requests
from requests_mock import Mocker

from starwhale.consts import HTTPMethod
from starwhale.api._impl import data_store

from .test_base import BaseTestCase
Expand Down Expand Up @@ -1473,7 +1475,13 @@ def test_insert_and_delete(self) -> None:
)
assert not self.writer.is_alive()

def test_run_thread_exception_limit(self) -> None:
@Mocker()
def test_run_thread_exception_limit(self, request_mock: Mocker) -> None:
request_mock.request(
HTTPMethod.POST,
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)
remote_store = data_store.RemoteDataStore("http://1.1.1.1")
remote_writer = data_store.TableWriter(
"p/test", "k", remote_store, run_exceptions_limits=0
Expand Down Expand Up @@ -1505,7 +1513,13 @@ def test_run_thread_exception_limit(self) -> None:
assert len(remote_writer._queue_run_exceptions) == 0
remote_writer.close()

def test_run_thread_exception(self) -> None:
@Mocker()
def test_run_thread_exception(self, request_mock: Mocker) -> None:
request_mock.request(
HTTPMethod.POST,
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)
remote_store = data_store.RemoteDataStore("http://1.1.1.1")
remote_writer = data_store.TableWriter("p/test", "k", remote_store)

Expand Down
12 changes: 11 additions & 1 deletion client/tests/sdk/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os

from requests_mock import Mocker

from starwhale.consts import HTTPMethod
from starwhale.api._impl import wrapper, data_store
from starwhale.consts.env import SWEnv

Expand Down Expand Up @@ -46,7 +49,14 @@ def test_log_metrics(self) -> None:
),
)

def test_exception_close(self) -> None:
@Mocker()
def test_exception_close(self, request_mock: Mocker) -> None:
request_mock.request(
HTTPMethod.POST,
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)

os.environ[SWEnv.instance_token] = "abcd"
os.environ[SWEnv.instance_uri] = "http://1.1.1.1"
eval = wrapper.Evaluation("test")
Expand Down
44 changes: 43 additions & 1 deletion client/tests/utils/test_common.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import os
import typing as t
import urllib.error
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

import pytest
import requests
from requests_mock import Mocker
from pyfakefs.fake_filesystem import FakeFilesystem
from pyfakefs.fake_filesystem_unittest import Patcher

from starwhale.utils import load_dotenv, pretty_merge_list, validate_obj_name
from starwhale.consts import ENV_LOG_LEVEL
from starwhale.consts import HTTPMethod, ENV_LOG_LEVEL
from starwhale.utils.debug import init_logger
from starwhale.utils.retry import http_retry


def test_valid_object_name() -> None:
Expand Down Expand Up @@ -70,3 +75,40 @@ def test_pretty_merge_list() -> None:

for in_lst, expected_str in _cases:
assert pretty_merge_list(in_lst) == expected_str


class TestRetry(TestCase):
    """Checks ``http_retry`` in its bare and parameterised decorator forms."""

    @http_retry
    def _do_request(self, url: str) -> None:
        # Even a successful response ends in a plain Exception, which lets
        # the test confirm that non-HTTP errors are attempted only once.
        response = requests.get(url, timeout=1)
        response.raise_for_status()
        raise Exception("dummy")

    @http_retry(attempts=6)
    def _do_urllib_raise(self):
        raise urllib.error.HTTPError("http://1.1.1.1", 500, "dummy", None, None)  # type: ignore

    @Mocker()
    def test_http_retry(self, request_mock: Mocker) -> None:
        # Each scenario: (mocked status code, expected attempts, raised type).
        scenarios = [
            (200, 1, Exception),
            (400, 1, requests.exceptions.HTTPError),
            (500, 3, requests.exceptions.HTTPError),
            (503, 3, requests.exceptions.HTTPError),
        ]

        for code, wanted_attempts, wanted_error in scenarios:
            target = f"http://1.1.1.1/{code}"
            request_mock.request(HTTPMethod.GET, target, status_code=code)

            with self.assertRaises(wanted_error):
                self._do_request(target)

            seen_attempts = self._do_request.retry.statistics["attempt_number"]
            assert seen_attempts == wanted_attempts, target

        # urllib's HTTPError is retried too, up to the custom attempt limit.
        with self.assertRaises(urllib.error.HTTPError):
            self._do_urllib_raise()

        urllib_attempts = self._do_urllib_raise.retry.statistics["attempt_number"]
        assert urllib_attempts == 6

0 comments on commit cbd0602

Please sign in to comment.