Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance(client): add http_retry decorator #1104

Merged
merged 1 commit into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
import requests
import pyarrow.parquet as pq # type: ignore
from loguru import logger
from tenacity import retry
from tenacity.stop import stop_after_attempt
from tenacity.retry import retry_if_exception_type
from typing_extensions import Protocol

from starwhale.utils.fs import ensure_dir
from starwhale.consts.env import SWEnv
from starwhale.utils.error import MissingFieldError
from starwhale.utils.retry import http_retry
from starwhale.utils.config import SWCliConfigMixed

try:
Expand Down Expand Up @@ -872,11 +870,7 @@ def __init__(self, instance_uri: str) -> None:
if self.token is None:
raise RuntimeError("SW_TOKEN is not found in environment")

@retry(
reraise=True,
stop=stop_after_attempt(3),
retry=retry_if_exception_type(requests.exceptions.HTTPError),
)
@http_retry
def update_table(
self,
table_name: str,
Expand Down Expand Up @@ -923,6 +917,20 @@ def update_table(
)
resp.raise_for_status()

@http_retry
def _do_scan_table_request(self, post_data: Dict[str, Any]) -> Dict[str, Any]:
    """Send one ``scanTable`` query to the instance and return its ``data`` field.

    The request body is ``post_data`` serialized as compact JSON. Any HTTP
    error status is turned into an exception by ``raise_for_status`` so the
    ``http_retry`` decorator can decide whether the call is attempted again.
    """
    endpoint = urllib.parse.urljoin(self.instance_uri, "/api/v1/datastore/scanTable")
    payload = json.dumps(post_data, separators=(",", ":"))
    request_headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": self.token,  # type: ignore
    }
    resp = requests.post(
        endpoint,
        data=payload,
        headers=request_headers,
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()["data"]  # type: ignore

def scan_tables(
self,
tables: List[TableDesc],
Expand All @@ -942,17 +950,7 @@ def scan_tables(
post_data["keepNone"] = True
assert self.token is not None
while True:
resp = requests.post(
urllib.parse.urljoin(self.instance_uri, "/api/v1/datastore/scanTable"),
data=json.dumps(post_data, separators=(",", ":")),
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": self.token,
},
timeout=60,
)
resp.raise_for_status()
resp_json: Dict[str, Any] = resp.json()["data"]
resp_json = self._do_scan_table_request(post_data)
records = resp_json.get("records", None)
if records is None or len(records) == 0:
break
Expand Down
4 changes: 3 additions & 1 deletion client/starwhale/base/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from starwhale.utils.fs import ensure_dir
from starwhale.utils.http import ignore_error, wrap_sw_error_resp
from starwhale.utils.error import NoSupportError
from starwhale.utils.retry import http_retry

_TMP_FILE_BUFSIZE = 8192
_DEFAULT_TIMEOUT_SECS = 90
Expand Down Expand Up @@ -94,7 +95,7 @@ def _progress_bar(monitor: MultipartEncoderMonitor) -> None:
_headers["Content-Type"] = _encoder.content_type
_monitor = MultipartEncoderMonitor(_encoder, callback=_progress_bar)

return self.do_http_request(
return self.do_http_request( # type: ignore
url_path,
instance_uri=instance_uri,
method=HTTPMethod.POST,
Expand All @@ -121,6 +122,7 @@ def do_http_request_simple_ret(

return status, message

@http_retry
def do_http_request(
self,
path: str,
Expand Down
50 changes: 50 additions & 0 deletions client/starwhale/utils/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import typing as t
from inspect import iscoroutinefunction
from urllib.error import HTTPError

import requests
from tenacity import Retrying
from tenacity.stop import stop_after_attempt
from tenacity.retry import retry_if_exception
from tenacity._asyncio import AsyncRetrying
from requests.exceptions import HTTPError as RequestsHTTPError

# HTTP status codes that are retried by default:
#   408 Request Timeout, 429 Too Many Requests, 500 Internal Server Error,
#   502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout.
# Reference for the selection:
# https://docs.microsoft.com/en-us/azure/architecture/best-practices/retry-service-specific#general-rest-and-retry-guidelines
_RETRY_HTTP_STATUS_CODES = (408, 429, 500, 502, 503, 504)


class retry_if_http_exception(retry_if_exception):
    """Retry strategy that matches HTTP errors carrying a retryable status code.

    Both ``requests.exceptions.HTTPError`` (when it carries a real
    ``requests.Response``) and ``urllib.error.HTTPError`` are recognised.
    Every other exception is never retried.
    """

    def __init__(self, status_codes: t.Optional[t.Sequence[int]] = None) -> None:
        # A missing (or empty) sequence falls back to the module default set.
        self.status_codes = status_codes or _RETRY_HTTP_STATUS_CODES
        super().__init__(self._is_retryable)

    def _is_retryable(self, exc: BaseException) -> bool:
        """Return True when ``exc`` is an HTTP error with a status code in ``status_codes``."""
        if isinstance(exc, RequestsHTTPError) and isinstance(
            exc.response, requests.Response
        ):
            return exc.response.status_code in self.status_codes
        if isinstance(exc, HTTPError):
            return exc.code in self.status_codes
        return False


def http_retry(*args: t.Any, **kw: t.Any) -> t.Any:
    """Decorate a callable so retryable HTTP errors trigger another attempt.

    Works both bare (``@http_retry``) and called (``@http_retry(attempts=6)``).

    Keyword options:
        attempts: maximum number of calls before giving up (default 3).
        status_codes: HTTP status codes that trigger a retry; ``None``
            selects the default set ``_RETRY_HTTP_STATUS_CODES``.

    Once the attempts are exhausted the last exception is re-raised as-is
    (``reraise=True``). Coroutine functions are wrapped with ``AsyncRetrying``,
    all other callables with ``Retrying``.
    """
    # Bare usage: the decorated function arrives as the sole positional argument.
    # Forward the keyword options so http_retry(f, attempts=n) honours them.
    if len(args) == 1 and callable(args[0]):
        return http_retry(**kw)(args[0])

    attempts = kw.get("attempts", 3)
    status_codes = kw.get("status_codes")

    def wrap(f: t.Callable) -> t.Any:
        retrying_cls = AsyncRetrying if iscoroutinefunction(f) else Retrying
        return retrying_cls(
            *args,
            reraise=True,
            stop=stop_after_attempt(attempts),
            retry=retry_if_http_exception(status_codes),
        ).wraps(f)

    return wrap
18 changes: 16 additions & 2 deletions client/tests/sdk/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import numpy as np
import pyarrow as pa # type: ignore
import requests
from requests_mock import Mocker

from starwhale.consts import HTTPMethod
from starwhale.api._impl import data_store

from .test_base import BaseTestCase
Expand Down Expand Up @@ -1473,7 +1475,13 @@ def test_insert_and_delete(self) -> None:
)
assert not self.writer.is_alive()

def test_run_thread_exception_limit(self) -> None:
@Mocker()
def test_run_thread_exception_limit(self, request_mock: Mocker) -> None:
request_mock.request(
HTTPMethod.POST,
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)
remote_store = data_store.RemoteDataStore("http://1.1.1.1")
remote_writer = data_store.TableWriter(
"p/test", "k", remote_store, run_exceptions_limits=0
Expand Down Expand Up @@ -1505,7 +1513,13 @@ def test_run_thread_exception_limit(self) -> None:
assert len(remote_writer._queue_run_exceptions) == 0
remote_writer.close()

def test_run_thread_exception(self) -> None:
@Mocker()
def test_run_thread_exception(self, request_mock: Mocker) -> None:
request_mock.request(
HTTPMethod.POST,
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)
remote_store = data_store.RemoteDataStore("http://1.1.1.1")
remote_writer = data_store.TableWriter("p/test", "k", remote_store)

Expand Down
12 changes: 11 additions & 1 deletion client/tests/sdk/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os

from requests_mock import Mocker

from starwhale.consts import HTTPMethod
from starwhale.api._impl import wrapper, data_store
from starwhale.consts.env import SWEnv

Expand Down Expand Up @@ -46,7 +49,14 @@ def test_log_metrics(self) -> None:
),
)

def test_exception_close(self) -> None:
@Mocker()
def test_exception_close(self, request_mock: Mocker) -> None:
request_mock.request(
HTTPMethod.POST,
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)

os.environ[SWEnv.instance_token] = "abcd"
os.environ[SWEnv.instance_uri] = "http://1.1.1.1"
eval = wrapper.Evaluation("test")
Expand Down
44 changes: 43 additions & 1 deletion client/tests/utils/test_common.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import os
import typing as t
import urllib.error
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

import pytest
import requests
from requests_mock import Mocker
from pyfakefs.fake_filesystem import FakeFilesystem
from pyfakefs.fake_filesystem_unittest import Patcher

from starwhale.utils import load_dotenv, pretty_merge_list, validate_obj_name
from starwhale.consts import ENV_LOG_LEVEL
from starwhale.consts import HTTPMethod, ENV_LOG_LEVEL
from starwhale.utils.debug import init_logger
from starwhale.utils.retry import http_retry


def test_valid_object_name() -> None:
Expand Down Expand Up @@ -70,3 +75,40 @@ def test_pretty_merge_list() -> None:

for in_lst, expected_str in _cases:
assert pretty_merge_list(in_lst) == expected_str


class TestRetry(TestCase):
    """Checks how many attempts ``http_retry`` makes for different failures."""

    @http_retry
    def _do_request(self, url: str) -> None:
        # Always raises: HTTPError for error statuses, a plain Exception otherwise.
        response = requests.get(url, timeout=1)
        response.raise_for_status()
        raise Exception("dummy")

    @http_retry(attempts=6)
    def _do_urllib_raise(self) -> None:
        # urllib-flavoured HTTP error with a retryable status code.
        raise urllib.error.HTTPError("http://1.1.1.1", 500, "dummy", None, None)  # type: ignore

    @Mocker()
    def test_http_retry(self, request_mock: Mocker) -> None:
        # (mocked status code, expected number of attempts, expected exception type)
        scenarios = (
            (200, 1, Exception),
            (400, 1, requests.exceptions.HTTPError),
            (500, 3, requests.exceptions.HTTPError),
            (503, 3, requests.exceptions.HTTPError),
        )

        for status_code, expected_attempts, exception in scenarios:
            url = f"http://1.1.1.1/{status_code}"
            request_mock.request(HTTPMethod.GET, url, status_code=status_code)

            with self.assertRaises(exception):
                self._do_request(url)

            attempts_made = self._do_request.retry.statistics["attempt_number"]
            assert attempts_made == expected_attempts, url

        with self.assertRaises(urllib.error.HTTPError):
            self._do_urllib_raise()

        urllib_attempts = self._do_urllib_raise.retry.statistics["attempt_number"]
        assert urllib_attempts == 6