Skip to content

Commit ba42a3d

Browse files
Greg Leclercq (ggreg)
authored and committed
Use same transaction and HTTP session within a connection
When an HTTP gateway is used to balance queries across multiple Presto coordinators, queries within the same transaction must go to the same coordinator. A single transaction is started for a given connection. An HTTP cookie, such as `Set-Cookie: Presto-Gateway-Sticky={host}:{port};Version=1`, can support this behavior. The cookie is persisted by the `requests.Session` object referenced by `PrestoRequest._http_session`. Hence all queries should be sent with the same session data. The Presto client protocol is connection-less. Here the connection is used to persist session data and use the same connection parameters.
1 parent e9519d1 commit ba42a3d

File tree

4 files changed

+50
-54
lines changed

4 files changed

+50
-54
lines changed

integration_tests/test_dbapi.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,13 @@ def test_select_query(presto_connection):
6565

6666

6767
def test_select_query_result_iteration(presto_connection):
68-
cur = presto_connection.cursor()
69-
cur.execute('select custkey from tpch.sf1.customer LIMIT 10')
70-
rows0 = cur.genall()
71-
cur.execute('select custkey from tpch.sf1.customer LIMIT 10')
72-
rows1 = cur.fetchall()
68+
cur0 = presto_connection.cursor()
69+
cur0.execute('select custkey from tpch.sf1.customer LIMIT 10')
70+
rows0 = cur0.genall()
71+
72+
cur1 = presto_connection.cursor()
73+
cur1.execute('select custkey from tpch.sf1.customer LIMIT 10')
74+
rows1 = cur1.fetchall()
7375

7476
assert len(list(rows0)) == len(rows1)
7577

@@ -177,7 +179,7 @@ def test_transaction_single(presto_connection_with_transaction):
177179

178180
def test_transaction_rollback(presto_connection_with_transaction):
179181
connection = presto_connection_with_transaction
180-
for i in range(3):
182+
for _ in range(3):
181183
cur = connection.cursor()
182184
cur.execute('SELECT * FROM tpch.sf1.customer LIMIT 1000')
183185
rows = cur.fetchall()

prestodb/client.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def __init__(
203203
catalog=None, # type: Text
204204
schema=None, # type: Text
205205
session_properties=None, # type: Optional[Dict[Text, Any]]
206+
http_session=None, # type: Any
206207
http_headers=None, # type: Optional[Dict[Text, Text]]
207208
transaction_id=NO_TRANSACTION, # type: Optional[Text]
208209
http_scheme=constants.HTTP, # type: Text
@@ -226,8 +227,12 @@ def __init__(
226227
self._host = host
227228
self._port = port
228229
self._next_uri = None # type: Optional[Text]
229-
# mypy cannot follow module import
230-
self._http_session = self.http.Session() # type: ignore
230+
231+
if http_session is not None:
232+
self._http_session = http_session
233+
else:
234+
# mypy cannot follow module import
235+
self._http_session = self.http.Session() # type: ignore
231236
self._http_session.headers.update(self.http_headers)
232237
self._auth = auth
233238
if self._auth:
@@ -413,9 +418,9 @@ def process(self, http_response):
413418

414419
if constants.HEADER_SET_SESSION in http_response.headers:
415420
for key, value in get_session_property_values(
416-
response.headers,
417-
constants.HEADER_SET_SESSION,
418-
):
421+
response.headers,
422+
constants.HEADER_SET_SESSION,
423+
):
419424
self._client_session.properties[key] = value
420425

421426
self._next_uri = response.get('nextUri')

prestodb/dbapi.py

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def __init__(
100100
self.catalog = catalog
101101
self.schema = schema
102102
self.session_properties = session_properties
103+
# mypy cannot follow module import
104+
self._http_session = prestodb.client.PrestoRequest.http.Session()
103105
self.http_headers = http_headers
104106
self.http_scheme = http_scheme
105107
self.auth = auth
@@ -108,6 +110,7 @@ def __init__(
108110
self.request_timeout = request_timeout
109111

110112
self._isolation_level = isolation_level
113+
self._request = None
111114
self._transaction = None
112115

113116
@property
@@ -124,7 +127,7 @@ def __enter__(self):
124127
def __exit__(self, exc_type, exc_value, traceback):
125128
try:
126129
self.commit()
127-
except:
130+
except Exception:
128131
self.rollback()
129132
else:
130133
self.close()
@@ -135,14 +138,32 @@ def close(self):
135138
pass
136139

137140
def start_transaction(self):
138-
request = prestodb.client.PrestoRequest(
141+
self._transaction = Transaction(self._create_request())
142+
self._transaction.begin()
143+
return self._transaction
144+
145+
def commit(self):
146+
if self.transaction is None:
147+
return
148+
self._transaction.commit()
149+
self._transaction = None
150+
151+
def rollback(self):
152+
if self.transaction is None:
153+
raise RuntimeError('no transaction was started')
154+
self._transaction.rollback()
155+
self._transaction = None
156+
157+
def _create_request(self):
158+
return prestodb.client.PrestoRequest(
139159
self.host,
140160
self.port,
141161
self.user,
142162
self.source,
143163
self.catalog,
144164
self.schema,
145165
self.session_properties,
166+
self._http_session,
146167
self.http_headers,
147168
NO_TRANSACTION,
148169
self.http_scheme,
@@ -151,25 +172,10 @@ def start_transaction(self):
151172
self.max_attempts,
152173
self.request_timeout,
153174
)
154-
self._transaction = Transaction(request)
155-
self._transaction.begin()
156-
return self._transaction
157-
158-
def commit(self):
159-
if self.transaction is None:
160-
return
161-
self._transaction.commit()
162-
self._transaction = None
163-
164-
def rollback(self):
165-
if self.transaction is None:
166-
raise RuntimeError('no transaction was started')
167-
self._transaction.rollback()
168-
self._transaction = None
169175

170176
def cursor(self):
171177
"""Return a new :py:class:`Cursor` object using the connection."""
172-
return Cursor(self)
178+
return Cursor(self, self._create_request())
173179

174180

175181
class Cursor(object):
@@ -179,13 +185,14 @@ class Cursor(object):
179185
cursor are immediately visible by other cursors or connections.
180186
181187
"""
182-
def __init__(self, connection):
188+
def __init__(self, connection, request):
183189
if not isinstance(connection, Connection):
184190
raise ValueError(
185191
'connection must be a Connection object: {}'.format(
186192
type(connection)
187193
))
188194
self._connection = connection
195+
self._request = request
189196

190197
self.arraysize = 1
191198
self._iterator = None
@@ -195,7 +202,6 @@ def __init__(self, connection):
195202
def connection(self):
196203
return self._connection
197204

198-
199205
@property
200206
def description(self):
201207
if self._query.columns is None:
@@ -234,28 +240,8 @@ def execute(self, operation, params=None):
234240
if self.connection.isolation_level != IsolationLevel.AUTOCOMMIT:
235241
if self.connection.transaction is None:
236242
self.connection.start_transaction()
237-
transaction_id = self.connection.transaction.id
238-
else:
239-
transaction_id = 'NONE'
240-
241-
request = prestodb.client.PrestoRequest(
242-
self.connection.host,
243-
self.connection.port,
244-
self.connection.user,
245-
self.connection.source,
246-
self.connection.catalog,
247-
self.connection.schema,
248-
self.connection.session_properties,
249-
self.connection.http_headers,
250-
transaction_id,
251-
self.connection.http_scheme,
252-
self.connection.auth,
253-
self.connection.redirect_handler,
254-
self.connection.max_attempts,
255-
self.connection.request_timeout,
256-
)
257243

258-
self._query = prestodb.client.PrestoQuery(request, sql=operation)
244+
self._query = prestodb.client.PrestoQuery(self._request, sql=operation)
259245
result = self._query.execute()
260246
self._iterator = iter(result)
261247
return result

prestodb/transaction.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,22 +87,25 @@ def begin(self):
8787
response.headers[constants.HEADER_STARTED_TRANSACTION]
8888
)
8989
status = self._request.process(response)
90+
self._request.transaction_id = self._id
9091
logger.info('transaction started: ' + self._id)
9192

9293
def commit(self):
93-
self._request.transaction_id = self._id
9494
query = prestodb.client.PrestoQuery(self._request, COMMIT)
9595
try:
9696
list(query.execute())
9797
except Exception as err:
9898
raise prestodb.exceptions.DatabaseError(
9999
'failed to commit transaction {}: {}'.format(self._id, err))
100+
self._id = NO_TRANSACTION
101+
self._request.transaction_id = self._id
100102

101103
def rollback(self):
102-
self._request.transaction_id = self._id
103104
query = prestodb.client.PrestoQuery(self._request, ROLLBACK)
104105
try:
105106
list(query.execute())
106107
except Exception as err:
107108
raise prestodb.exceptions.DatabaseError(
108109
'failed to rollback transaction {}: {}'.format(self._id, err))
110+
self._id = NO_TRANSACTION
111+
self._request.transaction_id = self._id

0 commit comments

Comments
 (0)