Skip to content

Commit 10350bc

Browse files
SK-2131 retry for errors in insert (#205)
1 parent b4bd129 commit 10350bc

File tree

1 file changed

+22
-46
lines changed

1 file changed

+22
-46
lines changed

skyflow/vault/_client.py

Lines changed: 22 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import types
66
import requests
77
import asyncio
8-
from requests.adapters import HTTPAdapter
98
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse
109
from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody
1110
from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions
@@ -37,71 +36,48 @@ def __init__(self, config: Configuration):
3736
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.TOKEN_PROVIDER_ERROR.value % (
3837
str(type(config.tokenProvider))), interface=interface)
3938

40-
self._create_session()
4139
self.vaultID = config.vaultID
4240
self.vaultURL = config.vaultURL.rstrip('/')
4341
self.tokenProvider = config.tokenProvider
4442
self.storedToken = ''
4543
log_info(InfoMessages.CLIENT_INITIALIZED.value, interface=interface)
46-
47-
def _create_session(self):
48-
self.session = requests.Session()
49-
adapter = HTTPAdapter(pool_connections=1, pool_maxsize=25, pool_block=True)
50-
self.session.mount("https://", adapter)
51-
52-
def __del__(self):
53-
if (self.session is not None):
54-
log_info(InfoMessages.CLOSING_SESSION.value, interface=InterfaceName.CLIENT.value)
55-
self.session.close()
56-
self.session = None
57-
58-
def _get_session(self):
59-
if (self.session is None):
60-
self._create_session()
61-
return self.session
6244

6345
def insert(self, records: dict, options: InsertOptions = InsertOptions()):
64-
max_retries = 1
6546
interface = InterfaceName.INSERT.value
6647
log_info(InfoMessages.INSERT_TRIGGERED.value, interface=interface)
6748
self._checkConfig(interface)
6849
jsonBody = getInsertRequestBody(records, options)
6950
requestURL = self._get_complete_vault_url()
70-
71-
for attempt in range(max_retries + 1):
51+
self.storedToken = tokenProviderWrapper(
52+
self.storedToken, self.tokenProvider, interface)
53+
headers = {
54+
"Authorization": "Bearer " + self.storedToken,
55+
"sky-metadata": json.dumps(getMetrics())
56+
}
57+
max_retries = 3
58+
# Use for-loop for retry logic, avoid code repetition
59+
for attempt in range(max_retries+1):
7260
try:
73-
self.storedToken = tokenProviderWrapper(
74-
self.storedToken, self.tokenProvider, interface)
75-
headers = {
76-
"Authorization": "Bearer " + self.storedToken,
77-
"sky-metadata": json.dumps(getMetrics()),
78-
}
79-
response = self._get_session().post(
80-
requestURL,
81-
data=jsonBody,
82-
headers=headers,
83-
)
61+
# If jsonBody is a dict, use json=, else use data=
62+
response = requests.post(requestURL, data=jsonBody, headers=headers)
8463
processedResponse = processResponse(response)
8564
result, partial = convertResponse(records, processedResponse, options)
8665
if partial:
8766
log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface)
88-
elif 'records' not in result:
67+
raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, result, interface=interface)
68+
if 'records' not in result:
8969
log_error(SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, interface)
90-
else:
91-
log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface)
70+
raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, result, interface=interface)
71+
log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface)
9272
return result
93-
except requests.exceptions.ConnectionError as err:
73+
except Exception as err:
9474
if attempt < max_retries:
95-
continue
96-
raise SkyflowError(
97-
SkyflowErrorCodes.SERVER_ERROR,
98-
SkyflowErrorMessages.NETWORK_ERROR.value % str(err),
99-
interface=interface
100-
)
101-
except SkyflowError as err:
102-
if err.code != SkyflowErrorCodes.SERVER_ERROR or attempt >= max_retries:
103-
raise err
104-
continue
75+
continue
76+
else:
77+
if isinstance(err, SkyflowError):
78+
raise err
79+
else:
80+
raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, f"Error occurred: {err}", interface=interface)
10581

10682
def detokenize(self, records: dict, options: DetokenizeOptions = DetokenizeOptions()):
10783
interface = InterfaceName.DETOKENIZE.value

0 commit comments

Comments
 (0)