Skip to content

Commit

Permalink
Merge execute_entity_ids into execute.
Browse files Browse the repository at this point in the history
This merges the functions in py/vtdb/{vtgate_cursor,vtgatev2}.py.
  • Loading branch information
dumbunny committed Oct 22, 2015
1 parent 6c93a31 commit 2fbbec7
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 72 deletions.
30 changes: 7 additions & 23 deletions py/vtdb/vtgate_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,17 @@ def execute(self, sql, bind_variables, **kargs):
self._clear_batch_state()
if self._handle_transaction_sql(sql):
return
entity_keyspace_id_map = kargs.get('entity_keyspace_id_map')
entity_column_name = kargs.get('entity_column_name')
write_query = bool(write_sql_pattern.match(sql))
# NOTE: This check may also be done at higher layers but adding it
# here for completion.
if write_query:
if not self.is_writable():
raise dbexceptions.DatabaseError('DML on a non-writable cursor', sql)

if entity_keyspace_id_map:
raise dbexceptions.DatabaseError(
'entity_keyspace_id_map is not allowed for write queries')
self.results, self.rowcount, self.lastrowid, self.description = (
self.connection._execute(
sql,
Expand All @@ -85,28 +89,8 @@ def execute(self, sql, bind_variables, **kargs):
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges,
not_in_transaction=not self.is_writable(),
effective_caller_id=self.effective_caller_id))
return self.rowcount

def execute_entity_ids(
self, sql, bind_variables, entity_keyspace_id_map, entity_column_name):
self._clear_list_state()
self._clear_batch_state()

# This is by definition a scatter query, so raise exception.
write_query = bool(write_sql_pattern.match(sql))
if write_query:
raise dbexceptions.DatabaseError(
'execute_entity_ids is not allowed for write queries')
self.results, self.rowcount, self.lastrowid, self.description = (
self.connection._execute_entity_ids(
sql,
bind_variables,
self.keyspace,
self.tablet_type,
entity_keyspace_id_map,
entity_column_name,
entity_keyspace_id_map=entity_keyspace_id_map,
entity_column_name=entity_column_name,
not_in_transaction=not self.is_writable(),
effective_caller_id=self.effective_caller_id))
return self.rowcount
Expand Down
90 changes: 47 additions & 43 deletions py/vtdb/vtgatev2.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,64 +187,69 @@ def _get_rowset_from_query_result(self, query_result):

@vtgate_utils.exponential_backoff_retry((dbexceptions.TransientError))
def _execute(
self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None,
keyranges=None, not_in_transaction=False, effective_caller_id=None):
self, sql, bind_variables, keyspace, tablet_type,
keyspace_ids=None, keyranges=None,
entity_keyspace_id_map=None, entity_column_name=None,
not_in_transaction=False, effective_caller_id=None):

routing_kwargs = {}
exec_method = None
req = None
if keyspace_ids is not None:
routing_kwargs['keyspace_ids'] = keyspace_ids
req = _create_req_with_keyspace_ids(
sql, bind_variables, keyspace, tablet_type, keyspace_ids,
not_in_transaction)
exec_method = 'VTGate.ExecuteKeyspaceIds'
elif keyranges is not None:
routing_kwargs['keyranges'] = keyranges
req = _create_req_with_keyranges(
sql, bind_variables, keyspace, tablet_type, keyranges,
not_in_transaction)
exec_method = 'VTGate.ExecuteKeyRanges'
elif entity_keyspace_id_map is not None:
routing_kwargs['entity_keyspace_id_map'] = entity_keyspace_id_map
routing_kwargs['entity_column_name'] = entity_column_name
if entity_column_name is None:
raise dbexceptions.ProgrammingError(
'_execute called with entity_keyspace_id_map and no '
'entity_column_name')
sql, new_binds = dbapi.prepare_query_bind_vars(sql, bind_variables)
new_binds = field_types.convert_bind_vars(new_binds)
req = {
'Sql': sql,
'BindVariables': new_binds,
'Keyspace': keyspace,
'TabletType': tablet_type,
'EntityKeyspaceIDs': [
{'ExternalID': xid, 'KeyspaceID': kid}
for xid, kid in entity_keyspace_id_map.iteritems()],
'EntityColumnName': entity_column_name,
'NotInTransaction': not_in_transaction,
}
exec_method = 'VTGate.ExecuteEntityIds'
else:
raise dbexceptions.ProgrammingError(
'_execute called without specifying keyspace_ids or keyranges')

self._add_caller_id(req, effective_caller_id)
self._add_session(req)
try:
response = self._get_client().call(exec_method, req)
self._update_session(response)
vtgate_utils.extract_rpc_error(exec_method, response)
reply = response.reply
return self._get_rowset_from_query_result(reply.get('Result'))
except (gorpc.GoRpcError, vtgate_utils.VitessError) as e:
self.logger_object.log_private_data(bind_variables)
raise self._convert_exception(
e, sql, keyspace_ids, keyranges,
keyspace=keyspace, tablet_type=tablet_type)
except Exception:
logging.exception('gorpc low-level error')
raise

@vtgate_utils.exponential_backoff_retry((dbexceptions.TransientError))
def _execute_entity_ids(
self, sql, bind_variables, keyspace, tablet_type,
entity_keyspace_id_map, entity_column_name, not_in_transaction=False,
effective_caller_id=None):
sql, new_binds = dbapi.prepare_query_bind_vars(sql, bind_variables)
new_binds = field_types.convert_bind_vars(new_binds)
req = {
'Sql': sql,
'BindVariables': new_binds,
'Keyspace': keyspace,
'TabletType': tablet_type,
'EntityKeyspaceIDs': [
{'ExternalID': xid, 'KeyspaceID': kid}
for xid, kid in entity_keyspace_id_map.iteritems()],
'EntityColumnName': entity_column_name,
'NotInTransaction': not_in_transaction,
}
'_execute called with no keyspace_ids, keyranges, or '
'entity_keyspace_id_map')

def check_incompatible_args(arg_name):
if arg_name not in routing_kwargs:
raise dbexceptions.ProgrammingError(
'_execute called with routing_args=%s, '
'incompatible routing arg=%s' % (
sorted(routing_kwargs), arg_name))

if keyranges is not None:
check_incompatible_args('keyranges')
if entity_column_name is not None:
check_incompatible_args('entity_column_name')
if entity_keyspace_id_map is not None:
check_incompatible_args('entity_keyspace_id_map')

self._add_caller_id(req, effective_caller_id)
self._add_session(req)
try:
exec_method = 'VTGate.ExecuteEntityIds'
response = self._get_client().call(exec_method, req)
self._update_session(response)
vtgate_utils.extract_rpc_error(exec_method, response)
Expand All @@ -253,8 +258,7 @@ def _execute_entity_ids(
except (gorpc.GoRpcError, vtgate_utils.VitessError) as e:
self.logger_object.log_private_data(bind_variables)
raise self._convert_exception(
e, sql, entity_keyspace_id_map,
keyspace=keyspace, tablet_type=tablet_type)
e, sql, keyspace=keyspace, tablet_type=tablet_type, **routing_kwargs)
except Exception:
logging.exception('gorpc low-level error')
raise
Expand Down Expand Up @@ -510,7 +514,7 @@ def _convert_exception(self, exc, *args, **kwargs):

new_args = exc.args + (str(self),) + args
if kwargs:
new_args += tuple(kwargs.itervalues())
new_args += tuple(sorted(kwargs.itervalues()))
new_exc = exc

if isinstance(exc, gorpc.TimeoutError):
Expand Down
4 changes: 2 additions & 2 deletions test/python_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _verify_exception_for_execute(self, query, exception):
# ExecuteEntityIds test
cursor = self.conn.cursor('keyspace', 'master')
with self.assertRaises(exception):
cursor.execute_entity_ids(
cursor.execute(
query, {},
entity_keyspace_id_map={1: self.KEYSPACE_ID_0X80},
entity_column_name='user_id')
Expand Down Expand Up @@ -324,7 +324,7 @@ def cursor_execute_key_ranges_method(cursor):
self._open_keyranges_cursor(), cursor_execute_key_ranges_method)

def cursor_execute_entity_ids_method(cursor):
cursor.execute_entity_ids(
cursor.execute(
effective_caller_id_test_query, {},
entity_keyspace_id_map={1: self.KEYSPACE_ID_0X80},
entity_column_name='user_id')
Expand Down
8 changes: 4 additions & 4 deletions test/vtgatev2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,10 @@ def test_execute_entity_ids(self):
'values (%(eid)s, %(id)s, %(keyspace_id)s)',
{'eid': x, 'id': x, 'keyspace_id': keyspace_id})
cursor.commit()
cursor = vtgate_conn.cursor(KEYSPACE_NAME, 'master',
keyspace_ids=eid_map.values())
rowcount = cursor.execute_entity_ids('select * from vt_a', {}, eid_map,
'id')
cursor = vtgate_conn.cursor(KEYSPACE_NAME, 'master', keyspace_ids=None)
rowcount = cursor.execute(
'select * from vt_a', {},
entity_keyspace_id_map=eid_map, entity_column_name='id')
self.assertEqual(rowcount, count, 'entity_ids works')

def test_batch_read(self):
Expand Down

0 comments on commit 2fbbec7

Please sign in to comment.