|
5 | 5 | import types |
6 | 6 | import requests |
7 | 7 | import asyncio |
| 8 | +from requests.adapters import HTTPAdapter |
8 | 9 | from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse |
9 | 10 | from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody |
10 | 11 | from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions |
@@ -36,49 +37,71 @@ def __init__(self, config: Configuration): |
36 | 37 | raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.TOKEN_PROVIDER_ERROR.value % ( |
37 | 38 | str(type(config.tokenProvider))), interface=interface) |
38 | 39 |
|
| 40 | + self._create_session() |
39 | 41 | self.vaultID = config.vaultID |
40 | 42 | self.vaultURL = config.vaultURL.rstrip('/') |
41 | 43 | self.tokenProvider = config.tokenProvider |
42 | 44 | self.storedToken = '' |
43 | 45 | log_info(InfoMessages.CLIENT_INITIALIZED.value, interface=interface) |
| 46 | + |
| 47 | + def _create_session(self): |
| 48 | + self.session = requests.Session() |
| 49 | + adapter = HTTPAdapter(pool_connections=1, pool_maxsize=25, pool_block=True) |
| 50 | + self.session.mount("https://", adapter) |
| 51 | + |
| 52 | + def __del__(self): |
| 53 | + if (self.session is not None): |
| 54 | + log_info(InfoMessages.CLOSING_SESSION.value, interface=InterfaceName.CLIENT.value) |
| 55 | + self.session.close() |
| 56 | + self.session = None |
| 57 | + |
| 58 | + def _get_session(self): |
| 59 | + if (self.session is None): |
| 60 | + self._create_session() |
| 61 | + return self.session |
44 | 62 |
|
45 | 63 | def insert(self, records: dict, options: InsertOptions = InsertOptions()): |
| 64 | + max_retries = 1 |
46 | 65 | interface = InterfaceName.INSERT.value |
47 | 66 | log_info(InfoMessages.INSERT_TRIGGERED.value, interface=interface) |
48 | 67 | self._checkConfig(interface) |
49 | | - |
50 | 68 | jsonBody = getInsertRequestBody(records, options) |
51 | 69 | requestURL = self._get_complete_vault_url() |
52 | | - self.storedToken = tokenProviderWrapper( |
53 | | - self.storedToken, self.tokenProvider, interface) |
54 | | - headers = { |
55 | | - "Authorization": "Bearer " + self.storedToken, |
56 | | - "sky-metadata": json.dumps(getMetrics()) |
57 | | - } |
58 | | - max_retries = 3 |
59 | | - # Use for-loop for retry logic, avoid code repetition |
60 | | - for attempt in range(max_retries+1): |
| 70 | + |
| 71 | + for attempt in range(max_retries + 1): |
61 | 72 | try: |
62 | | - # If jsonBody is a dict, use json=, else use data= |
63 | | - response = requests.post(requestURL, data=jsonBody, headers=headers) |
| 73 | + self.storedToken = tokenProviderWrapper( |
| 74 | + self.storedToken, self.tokenProvider, interface) |
| 75 | + headers = { |
| 76 | + "Authorization": "Bearer " + self.storedToken, |
| 77 | + "sky-metadata": json.dumps(getMetrics()), |
| 78 | + } |
| 79 | + response = self._get_session().post( |
| 80 | + requestURL, |
| 81 | + data=jsonBody, |
| 82 | + headers=headers, |
| 83 | + ) |
64 | 84 | processedResponse = processResponse(response) |
65 | 85 | result, partial = convertResponse(records, processedResponse, options) |
66 | 86 | if partial: |
67 | 87 | log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface) |
68 | | - raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, result, interface=interface) |
69 | | - if 'records' not in result: |
| 88 | + elif 'records' not in result: |
70 | 89 | log_error(SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, interface) |
71 | | - raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, result, interface=interface) |
72 | | - log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface) |
| 90 | + else: |
| 91 | + log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface) |
73 | 92 | return result |
74 | | - except Exception as err: |
| 93 | + except requests.exceptions.ConnectionError as err: |
75 | 94 | if attempt < max_retries: |
76 | | - continue |
77 | | - else: |
78 | | - if isinstance(err, SkyflowError): |
79 | | - raise err |
80 | | - else: |
81 | | - raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, f"Error occurred: {err}", interface=interface) |
| 95 | + continue |
| 96 | + raise SkyflowError( |
| 97 | + SkyflowErrorCodes.SERVER_ERROR, |
| 98 | + SkyflowErrorMessages.NETWORK_ERROR.value % str(err), |
| 99 | + interface=interface |
| 100 | + ) |
| 101 | + except SkyflowError as err: |
| 102 | + if err.code != SkyflowErrorCodes.SERVER_ERROR or attempt >= max_retries: |
| 103 | + raise err |
| 104 | + continue |
82 | 105 |
|
83 | 106 | def detokenize(self, records: dict, options: DetokenizeOptions = DetokenizeOptions()): |
84 | 107 | interface = InterfaceName.DETOKENIZE.value |
|
0 commit comments