|
5 | 5 | import types |
6 | 6 | import requests |
7 | 7 | import asyncio |
8 | | -from requests.adapters import HTTPAdapter |
9 | 8 | from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse |
10 | 9 | from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody |
11 | 10 | from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions |
@@ -37,71 +36,48 @@ def __init__(self, config: Configuration): |
37 | 36 | raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.TOKEN_PROVIDER_ERROR.value % ( |
38 | 37 | str(type(config.tokenProvider))), interface=interface) |
39 | 38 |
|
40 | | - self._create_session() |
41 | 39 | self.vaultID = config.vaultID |
42 | 40 | self.vaultURL = config.vaultURL.rstrip('/') |
43 | 41 | self.tokenProvider = config.tokenProvider |
44 | 42 | self.storedToken = '' |
45 | 43 | log_info(InfoMessages.CLIENT_INITIALIZED.value, interface=interface) |
46 | | - |
47 | | - def _create_session(self): |
48 | | - self.session = requests.Session() |
49 | | - adapter = HTTPAdapter(pool_connections=1, pool_maxsize=25, pool_block=True) |
50 | | - self.session.mount("https://", adapter) |
51 | | - |
52 | | - def __del__(self): |
53 | | - if (self.session is not None): |
54 | | - log_info(InfoMessages.CLOSING_SESSION.value, interface=InterfaceName.CLIENT.value) |
55 | | - self.session.close() |
56 | | - self.session = None |
57 | | - |
58 | | - def _get_session(self): |
59 | | - if (self.session is None): |
60 | | - self._create_session() |
61 | | - return self.session |
62 | 44 |
|
63 | 45 | def insert(self, records: dict, options: InsertOptions = InsertOptions()): |
64 | | - max_retries = 1 |
65 | 46 | interface = InterfaceName.INSERT.value |
66 | 47 | log_info(InfoMessages.INSERT_TRIGGERED.value, interface=interface) |
67 | 48 | self._checkConfig(interface) |
68 | 49 | jsonBody = getInsertRequestBody(records, options) |
69 | 50 | requestURL = self._get_complete_vault_url() |
70 | | - |
71 | | - for attempt in range(max_retries + 1): |
| 51 | + self.storedToken = tokenProviderWrapper( |
| 52 | + self.storedToken, self.tokenProvider, interface) |
| 53 | + headers = { |
| 54 | + "Authorization": "Bearer " + self.storedToken, |
| 55 | + "sky-metadata": json.dumps(getMetrics()) |
| 56 | + } |
| 57 | + max_retries = 3 |
| 58 | + # Use for-loop for retry logic, avoid code repetition |
| 59 | + for attempt in range(max_retries+1): |
72 | 60 | try: |
73 | | - self.storedToken = tokenProviderWrapper( |
74 | | - self.storedToken, self.tokenProvider, interface) |
75 | | - headers = { |
76 | | - "Authorization": "Bearer " + self.storedToken, |
77 | | - "sky-metadata": json.dumps(getMetrics()), |
78 | | - } |
79 | | - response = self._get_session().post( |
80 | | - requestURL, |
81 | | - data=jsonBody, |
82 | | - headers=headers, |
83 | | - ) |
| 61 | + # If jsonBody is a dict, use json=, else use data= |
| 62 | + response = requests.post(requestURL, data=jsonBody, headers=headers) |
84 | 63 | processedResponse = processResponse(response) |
85 | 64 | result, partial = convertResponse(records, processedResponse, options) |
86 | 65 | if partial: |
87 | 66 | log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface) |
88 | | - elif 'records' not in result: |
| 67 | + raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, result, interface=interface) |
| 68 | + if 'records' not in result: |
89 | 69 | log_error(SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, interface) |
90 | | - else: |
91 | | - log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface) |
| 70 | + raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, result, interface=interface) |
| 71 | + log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface) |
92 | 72 | return result |
93 | | - except requests.exceptions.ConnectionError as err: |
| 73 | + except Exception as err: |
94 | 74 | if attempt < max_retries: |
95 | | - continue |
96 | | - raise SkyflowError( |
97 | | - SkyflowErrorCodes.SERVER_ERROR, |
98 | | - SkyflowErrorMessages.NETWORK_ERROR.value % str(err), |
99 | | - interface=interface |
100 | | - ) |
101 | | - except SkyflowError as err: |
102 | | - if err.code != SkyflowErrorCodes.SERVER_ERROR or attempt >= max_retries: |
103 | | - raise err |
104 | | - continue |
| 75 | + continue |
| 76 | + else: |
| 77 | + if isinstance(err, SkyflowError): |
| 78 | + raise err |
| 79 | + else: |
| 80 | + raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, f"Error occurred: {err}", interface=interface) |
105 | 81 |
|
106 | 82 | def detokenize(self, records: dict, options: DetokenizeOptions = DetokenizeOptions()): |
107 | 83 | interface = InterfaceName.DETOKENIZE.value |
|
0 commit comments