From 7674d417089d810cca5c55f756ca8d8c37088634 Mon Sep 17 00:00:00 2001 From: Jing Wang Date: Mon, 21 Mar 2016 16:58:28 -0700 Subject: [PATCH] Add test for session properties --- pyhive/presto.py | 5 +++-- pyhive/tests/test_presto.py | 44 +++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/pyhive/presto.py b/pyhive/presto.py index 6a7702c5..fd3fbfdc 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -45,8 +45,6 @@ class Connection(object): def __init__(self, *args, **kwargs): self._args = args self._kwargs = kwargs - if 'session_props' not in kwargs: - kwargs['session_props'] = {} def close(self): """Presto does not have anything to close""" @@ -212,6 +210,9 @@ def _process_response(self, response): assert self._state == self._STATE_RUNNING, "Should be running if processing response" self._nextUri = response_json.get('nextUri') self._columns = response_json.get('columns') + if 'X-Presto-Clear-Session' in response.headers: + propname = response.headers['X-Presto-Clear-Session'] + self._session_props.pop(propname, None) if 'X-Presto-Set-Session' in response.headers: propname, propval = response.headers['X-Presto-Set-Session'].split('=', 1) self._session_props[propname] = propval diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py index 6703042c..63963b7b 100644 --- a/pyhive/tests/test_presto.py +++ b/pyhive/tests/test_presto.py @@ -6,6 +6,9 @@ from __future__ import absolute_import from __future__ import unicode_literals + +import contextlib + from pyhive import exc from pyhive import presto from pyhive.tests.dbapi_test_case import DBAPITestCase @@ -103,3 +106,44 @@ def fail(*args, **kwargs): self.fail("Should not need requests.get after done polling") # pragma: no cover with mock.patch('requests.get', fail): self.assertEqual(cursor.fetchall(), [[1]]) + + @with_cursor + def test_set_session(self, cursor): + cursor.execute("SET SESSION query_max_run_time = '1234m'") + cursor.fetchall() + + cursor.execute('SHOW SESSION') + rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + assert len(rows) == 1 + session_prop = rows[0] + assert session_prop[1] == '1234m' + + cursor.execute('RESET SESSION query_max_run_time') + cursor.fetchall() + + cursor.execute('SHOW SESSION') + rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + assert len(rows) == 1 + session_prop = rows[0] + assert session_prop[1] != '1234m' + + def test_set_session_in_consructor(self): + conn = presto.connect( + host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'} + ) + with contextlib.closing(conn): + with contextlib.closing(conn.cursor()) as cursor: + cursor.execute('SHOW SESSION') + rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + assert len(rows) == 1 + session_prop = rows[0] + assert session_prop[1] == '1234m' + + cursor.execute('RESET SESSION query_max_run_time') + cursor.fetchall() + + cursor.execute('SHOW SESSION') + rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] + assert len(rows) == 1 + session_prop = rows[0] + assert session_prop[1] != '1234m'