@http_retry
def _do_scan_table_request(self, post_data: Dict[str, Any]) -> Dict[str, Any]:
    """Send one scanTable query to the instance and return its ``data`` payload.

    The request body is ``post_data`` serialized as compact JSON and the
    instance token is sent in the ``Authorization`` header. Any HTTP error
    status is turned into an exception by ``raise_for_status()``; the
    ``http_retry`` decorator wraps the whole call.

    Args:
        post_data: JSON-serializable scan request.

    Returns:
        The value stored under the ``"data"`` key of the JSON response.
    """
    endpoint = urllib.parse.urljoin(
        self.instance_uri, "/api/v1/datastore/scanTable"
    )
    # Compact separators keep the serialized body free of padding whitespace.
    payload = json.dumps(post_data, separators=(",", ":"))
    request_headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": self.token,  # type: ignore
    }
    response = requests.post(
        endpoint,
        data=payload,
        headers=request_headers,
        timeout=60,
    )
    response.raise_for_status()
    return response.json()["data"]  # type: ignore
import typing as t
from inspect import iscoroutinefunction
from urllib.error import HTTPError

import requests
from tenacity import Retrying, AsyncRetrying
from tenacity.stop import stop_after_attempt
from tenacity.retry import retry_if_exception
from requests.exceptions import HTTPError as RequestsHTTPError

# https://docs.microsoft.com/en-us/azure/architecture/best-practices/retry-service-specific#general-rest-and-retry-guidelines
_RETRY_HTTP_STATUS_CODES = (408, 429, 500, 502, 503, 504)
_DEFAULT_ATTEMPTS = 3


class retry_if_http_exception(retry_if_exception):
    """Retry strategy that fires only for HTTP errors with a retryable status.

    Two exception families are recognized:

    * ``requests.exceptions.HTTPError`` carrying a ``requests.Response``:
      the response's ``status_code`` is checked.
    * ``urllib.error.HTTPError``: its ``code`` is checked.

    Every other exception is never retried.

    Args:
        status_codes: Status codes that should trigger a retry. ``None`` or
            an empty sequence selects ``_RETRY_HTTP_STATUS_CODES``.
    """

    def __init__(self, status_codes: t.Optional[t.Sequence[int]] = None) -> None:
        self.status_codes = status_codes or _RETRY_HTTP_STATUS_CODES

        def _predicate(e: BaseException) -> bool:
            if isinstance(e, RequestsHTTPError) and isinstance(
                e.response, requests.Response
            ):
                return e.response.status_code in self.status_codes
            if isinstance(e, HTTPError):
                return e.code in self.status_codes
            return False

        super().__init__(_predicate)


def http_retry(*args: t.Any, **kw: t.Any) -> t.Any:
    """Decorate a sync or async callable so retryable HTTP errors are retried.

    Usable both bare (``@http_retry``) and parameterized
    (``@http_retry(attempts=6)``).

    Keyword Args:
        attempts: Maximum number of calls before giving up (default 3).
        status_codes: HTTP status codes that trigger a retry; defaults to
            ``_RETRY_HTTP_STATUS_CODES``.

    Returns:
        The wrapped callable for the bare form, otherwise a decorator.
        After the last attempt the original exception is re-raised
        (``reraise=True``).

    Raises:
        TypeError: If an unsupported keyword option is given.
    """
    # support http_retry and http_retry()
    if len(args) == 1 and callable(args[0]):
        # Forward the options so ``http_retry(func, attempts=n)`` honors them
        # instead of silently falling back to the defaults.
        return http_retry(**kw)(args[0])

    options = dict(kw)
    attempts = options.pop("attempts", _DEFAULT_ATTEMPTS)
    status_codes = options.pop("status_codes", None)
    if options:
        # A misspelled option would otherwise be ignored without any hint.
        raise TypeError(
            "http_retry() got unexpected keyword argument(s): "
            + ", ".join(sorted(options))
        )

    def wrap(f: t.Callable) -> t.Any:
        # Coroutine functions need the asyncio-aware retry controller.
        _cls = AsyncRetrying if iscoroutinefunction(f) else Retrying
        return _cls(
            *args,
            reraise=True,
            stop=stop_after_attempt(attempts),
            retry=retry_if_http_exception(status_codes),
        ).wraps(f)

    return wrap
remote_writer.close() - def test_run_thread_exception(self) -> None: + @Mocker() + def test_run_thread_exception(self, request_mock: Mocker) -> None: + request_mock.request( + HTTPMethod.POST, + url="http://1.1.1.1/api/v1/datastore/updateTable", + status_code=400, + ) remote_store = data_store.RemoteDataStore("http://1.1.1.1") remote_writer = data_store.TableWriter("p/test", "k", remote_store) diff --git a/client/tests/sdk/test_wrapper.py b/client/tests/sdk/test_wrapper.py index 04366cc56b..1000da2001 100644 --- a/client/tests/sdk/test_wrapper.py +++ b/client/tests/sdk/test_wrapper.py @@ -1,5 +1,8 @@ import os +from requests_mock import Mocker + +from starwhale.consts import HTTPMethod from starwhale.api._impl import wrapper, data_store from starwhale.consts.env import SWEnv @@ -46,7 +49,14 @@ def test_log_metrics(self) -> None: ), ) - def test_exception_close(self) -> None: + @Mocker() + def test_exception_close(self, request_mock: Mocker) -> None: + request_mock.request( + HTTPMethod.POST, + url="http://1.1.1.1/api/v1/datastore/updateTable", + status_code=400, + ) + os.environ[SWEnv.instance_token] = "abcd" os.environ[SWEnv.instance_uri] = "http://1.1.1.1" eval = wrapper.Evaluation("test") diff --git a/client/tests/utils/test_common.py b/client/tests/utils/test_common.py index c6e0e6d063..aab4fc41d0 100644 --- a/client/tests/utils/test_common.py +++ b/client/tests/utils/test_common.py @@ -1,15 +1,20 @@ import os import typing as t +import urllib.error from pathlib import Path +from unittest import TestCase from unittest.mock import patch import pytest +import requests +from requests_mock import Mocker from pyfakefs.fake_filesystem import FakeFilesystem from pyfakefs.fake_filesystem_unittest import Patcher from starwhale.utils import load_dotenv, pretty_merge_list, validate_obj_name -from starwhale.consts import ENV_LOG_LEVEL +from starwhale.consts import HTTPMethod, ENV_LOG_LEVEL from starwhale.utils.debug import init_logger +from starwhale.utils.retry import 
class TestRetry(TestCase):
    """Checks how many attempts ``http_retry`` makes for different failures."""

    @http_retry
    def _do_request(self, url: str) -> None:
        # Always ends in an exception: an HTTP error status raises from
        # raise_for_status(), any other response falls through to "dummy".
        response = requests.get(url, timeout=1)
        response.raise_for_status()
        raise Exception("dummy")

    @http_retry(attempts=6)
    def _do_urllib_raise(self) -> None:
        raise urllib.error.HTTPError("http://1.1.1.1", 500, "dummy", None, None)  # type: ignore

    @Mocker()
    def test_http_retry(self, request_mock: Mocker) -> None:
        # (mocked status code, expected attempt count, expected exception type)
        scenarios = [
            (200, 1, Exception),
            (400, 1, requests.exceptions.HTTPError),
            (500, 3, requests.exceptions.HTTPError),
            (503, 3, requests.exceptions.HTTPError),
        ]

        for status_code, expected_attempts, expected_error in scenarios:
            url = f"http://1.1.1.1/{status_code}"
            request_mock.request(HTTPMethod.GET, url, status_code=status_code)

            with self.assertRaises(expected_error):
                self._do_request(url)

            request_stats = self._do_request.retry.statistics
            assert request_stats["attempt_number"] == expected_attempts, url

        # urllib-style HTTP errors are retried up to the custom attempt limit.
        with self.assertRaises(urllib.error.HTTPError):
            self._do_urllib_raise()

        urllib_stats = self._do_urllib_raise.retry.statistics
        assert urllib_stats["attempt_number"] == 6