Skip to content

Commit ecc852c

Browse files
committed
PYTHON-1673 Mongos pinning for sharded transactions
In a sharded transaction, a session is pinned to the mongos server selected for the initial command. All subsequent commands in the same transaction are routed to the pinned mongos server.
1 parent 1d8c739 commit ecc852c

14 files changed

+1002
-60
lines changed

pymongo/bulk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def execute(self, write_concern, session):
515515

516516
client = self.collection.database.client
517517
if not write_concern.acknowledged:
518-
with client._socket_for_writes() as sock_info:
518+
with client._socket_for_writes(session) as sock_info:
519519
self.execute_no_results(sock_info, generator)
520520
else:
521521
return self.execute_command(generator, write_concern, session)

pymongo/change_stream.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def _run_aggregation_cmd(self, session, explicit_session):
107107
"""
108108
read_preference = self._target._read_preference_for(session)
109109
client = self._database.client
110-
with client._socket_for_reads(read_preference) as (sock_info, slave_ok):
110+
with client._socket_for_reads(
111+
read_preference, session) as (sock_info, slave_ok):
111112
pipeline = self._full_pipeline()
112113
cmd = SON([("aggregate", self._aggregation_target),
113114
("pipeline", pipeline),

pymongo/client_session.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def __init__(self, opts):
246246
self.opts = opts
247247
self.state = _TxnState.NONE
248248
self.transaction_id = 0
249+
self.pinned_address = None
249250

250251
def active(self):
251252
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@@ -369,6 +370,7 @@ def start_transaction(self, read_concern=None, write_concern=None,
369370
self._transaction.state = _TxnState.STARTING
370371
self._start_retryable_write()
371372
self._transaction.transaction_id = self._server_session.transaction_id
373+
self._transaction.pinned_address = None
372374
return _TransactionContext(self)
373375

374376
def commit_transaction(self):
@@ -388,6 +390,10 @@ def commit_transaction(self):
388390
elif state is _TxnState.ABORTED:
389391
raise InvalidOperation(
390392
"Cannot call commitTransaction after calling abortTransaction")
393+
elif state is _TxnState.COMMITTED:
394+
# We're rerunning the commit; move the state back to "in progress"
395+
# so that _in_transaction returns true.
396+
self._transaction.state = _TxnState.IN_PROGRESS
391397

392398
try:
393399
self._finish_transaction_with_retry("commitTransaction")
@@ -441,12 +447,10 @@ def abort_transaction(self):
441447
self._transaction.state = _TxnState.ABORTED
442448

443449
def _finish_transaction(self, command_name):
444-
with self._client._socket_for_writes() as sock_info:
450+
with self._client._socket_for_writes(self) as sock_info:
445451
return self._client.admin._command(
446452
sock_info,
447453
command_name,
448-
txnNumber=self._transaction.transaction_id,
449-
autocommit=False,
450454
session=self,
451455
write_concern=self._transaction.opts.write_concern,
452456
parse_write_concern_error=True)
@@ -528,6 +532,17 @@ def _in_transaction(self):
528532
"""True if this session has an active multi-statement transaction."""
529533
return self._transaction.active()
530534

535+
@property
536+
def _pinned_address(self):
537+
"""The mongos address this transaction is pinned to, or None if no transaction is active."""
538+
if self._transaction.active():
539+
return self._transaction.pinned_address
540+
return None
541+
542+
def _pin_mongos(self, server):
543+
"""Pin this session to the given mongos Server."""
544+
self._transaction.pinned_address = server.description.address
545+
531546
def _txn_read_preference(self):
532547
"""Return read preference of this transaction or None."""
533548
if self._in_transaction:
@@ -542,6 +557,7 @@ def _apply_to(self, command, is_retryable, read_preference):
542557

543558
if not self._in_transaction:
544559
self._transaction.state = _TxnState.NONE
560+
self._transaction.pinned_address = None
545561

546562
if is_retryable:
547563
command['txnNumber'] = self._server_session.transaction_id

pymongo/collection.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -185,15 +185,16 @@ def __init__(self, database, name, create=False, codec_options=None,
185185

186186
def _socket_for_reads(self, session):
187187
return self.__database.client._socket_for_reads(
188-
self._read_preference_for(session))
188+
self._read_preference_for(session), session)
189189

190190
def _socket_for_primary_reads(self, session):
191191
read_pref = ((session and session._txn_read_preference())
192192
or ReadPreference.PRIMARY)
193-
return self.__database.client._socket_for_reads(read_pref), read_pref
193+
return self.__database.client._socket_for_reads(
194+
read_pref, session), read_pref
194195

195-
def _socket_for_writes(self):
196-
return self.__database.client._socket_for_writes()
196+
def _socket_for_writes(self, session):
197+
return self.__database.client._socket_for_writes(session)
197198

198199
def _command(self, sock_info, command, slave_ok=False,
199200
read_preference=None,
@@ -251,7 +252,7 @@ def __create(self, options, collation, session):
251252
if "size" in options:
252253
options["size"] = float(options["size"])
253254
cmd.update(options)
254-
with self._socket_for_writes() as sock_info:
255+
with self._socket_for_writes(session) as sock_info:
255256
self._command(
256257
sock_info, cmd, read_preference=ReadPreference.PRIMARY,
257258
write_concern=self._write_concern_for(session),
@@ -1804,7 +1805,7 @@ def create_indexes(self, indexes, session=None, **kwargs):
18041805
"""
18051806
common.validate_list('indexes', indexes)
18061807
names = []
1807-
with self._socket_for_writes() as sock_info:
1808+
with self._socket_for_writes(session) as sock_info:
18081809
supports_collations = sock_info.max_wire_version >= 5
18091810
def gen_indexes():
18101811
for index in indexes:
@@ -1844,7 +1845,7 @@ def __create_index(self, keys, index_options, session, **kwargs):
18441845
index_options.pop('collation', None))
18451846
index.update(index_options)
18461847

1847-
with self._socket_for_writes() as sock_info:
1848+
with self._socket_for_writes(session) as sock_info:
18481849
if collation is not None:
18491850
if sock_info.max_wire_version < 5:
18501851
raise ConfigurationError(
@@ -2070,7 +2071,7 @@ def drop_index(self, index_or_name, session=None, **kwargs):
20702071
self.__database.name, self.__name, name)
20712072
cmd = SON([("dropIndexes", self.__name), ("index", name)])
20722073
cmd.update(kwargs)
2073-
with self._socket_for_writes() as sock_info:
2074+
with self._socket_for_writes(session) as sock_info:
20742075
self._command(sock_info,
20752076
cmd,
20762077
read_preference=ReadPreference.PRIMARY,
@@ -2106,7 +2107,7 @@ def reindex(self, session=None, **kwargs):
21062107
"""
21072108
cmd = SON([("reIndex", self.__name)])
21082109
cmd.update(kwargs)
2109-
with self._socket_for_writes() as sock_info:
2110+
with self._socket_for_writes(session) as sock_info:
21102111
return self._command(
21112112
sock_info, cmd, read_preference=ReadPreference.PRIMARY,
21122113
session=session)
@@ -2606,7 +2607,7 @@ def rename(self, new_name, session=None, **kwargs):
26062607
cmd.update(kwargs)
26072608
write_concern = self._write_concern_for_cmd(cmd, session)
26082609

2609-
with self._socket_for_writes() as sock_info:
2610+
with self._socket_for_writes(session) as sock_info:
26102611
with self.__database.client._tmp_session(session) as s:
26112612
return sock_info.command(
26122613
'admin', cmd,

pymongo/database.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def command(self, command, value=1, check=True,
608608
read_preference = ((session and session._txn_read_preference())
609609
or ReadPreference.PRIMARY)
610610
with self.__client._socket_for_reads(
611-
read_preference) as (sock_info, slave_ok):
611+
read_preference, session) as (sock_info, slave_ok):
612612
return self._command(sock_info, command, slave_ok, value,
613613
check, allowable_errors, read_preference,
614614
codec_options, session=session, **kwargs)
@@ -671,7 +671,7 @@ def list_collections(self, session=None, **kwargs):
671671
read_pref = ((session and session._txn_read_preference())
672672
or ReadPreference.PRIMARY)
673673
with self.__client._socket_for_reads(
674-
read_pref) as (sock_info, slave_okay):
674+
read_pref, session) as (sock_info, slave_okay):
675675
return self._list_collections(
676676
sock_info, slave_okay, session, read_preference=read_pref,
677677
**kwargs)
@@ -745,7 +745,7 @@ def drop_collection(self, name_or_collection, session=None):
745745

746746
self.__client._purge_index(self.__name, name)
747747

748-
with self.__client._socket_for_writes() as sock_info:
748+
with self.__client._socket_for_writes(session) as sock_info:
749749
return self._command(
750750
sock_info, 'drop', value=_unicode(name),
751751
allowable_errors=['ns not found'],
@@ -826,7 +826,7 @@ def current_op(self, include_all=False, session=None):
826826
Added ``session`` parameter.
827827
"""
828828
cmd = SON([("currentOp", 1), ("$all", include_all)])
829-
with self.__client._socket_for_writes() as sock_info:
829+
with self.__client._socket_for_writes(session) as sock_info:
830830
if sock_info.max_wire_version >= 4:
831831
with self.__client._tmp_session(session) as s:
832832
return sock_info.command("admin", cmd, session=s,

pymongo/mongo_client.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,8 @@ def _end_sessions(self, session_ids):
10001000
# Use SocketInfo.command directly to avoid implicitly creating
10011001
# another session.
10021002
with self._socket_for_reads(
1003-
ReadPreference.PRIMARY_PREFERRED) as (sock_info, slave_ok):
1003+
ReadPreference.PRIMARY_PREFERRED,
1004+
None) as (sock_info, slave_ok):
10041005
if not sock_info.supports_sessions:
10051006
return
10061007

@@ -1102,12 +1103,40 @@ def _get_socket(self, server):
11021103
self.__reset_server(server.description.address)
11031104
raise
11041105

1105-
def _socket_for_writes(self):
1106-
server = self._get_topology().select_server(writable_server_selector)
1106+
def _select_server(self, server_selector, session, address=None):
1107+
"""Select a server to run an operation on this client.
1108+
1109+
:Parameters:
1110+
- `server_selector`: The server selector to use if the session is
1111+
not pinned and no address is given.
1112+
- `session`: The ClientSession for the next operation, or None. May
1113+
be pinned to a mongos server address.
1114+
- `address` (optional): Address when sending a message
1115+
to a specific server, used for getMore.
1116+
"""
1117+
topology = self._get_topology()
1118+
address = address or (session and session._pinned_address)
1119+
if address:
1120+
# We're running a getMore or this session is pinned to a mongos.
1121+
server = topology.select_server_by_address(address)
1122+
if not server:
1123+
raise AutoReconnect('server %s:%d no longer available'
1124+
% address)
1125+
else:
1126+
server = topology.select_server(server_selector)
1127+
# Pin this session to the selected server if it's performing a
1128+
# sharded transaction.
1129+
if server.description.mongos and (session and
1130+
session._in_transaction):
1131+
session._pin_mongos(server)
1132+
return server
1133+
1134+
def _socket_for_writes(self, session):
1135+
server = self._select_server(writable_server_selector, session)
11071136
return self._get_socket(server)
11081137

11091138
@contextlib.contextmanager
1110-
def _socket_for_reads(self, read_preference):
1139+
def _socket_for_reads(self, read_preference, session):
11111140
assert read_preference is not None, "read_preference must not be None"
11121141
# Get a socket for a server matching the read preference, and yield
11131142
# sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to
@@ -1117,7 +1146,7 @@ def _socket_for_reads(self, read_preference):
11171146
# Thread safe: if the type is single it cannot change.
11181147
topology = self._get_topology()
11191148
single = topology.description.topology_type == TOPOLOGY_TYPE.Single
1120-
server = topology.select_server(read_preference)
1149+
server = self._select_server(read_preference, session)
11211150

11221151
with self._get_socket(server) as sock_info:
11231152
slave_ok = (single and not sock_info.is_mongos) or (
@@ -1140,14 +1169,9 @@ def _send_message_with_response(self, operation, exhaust=False,
11401169
# If needed, restart kill-cursors thread after a fork.
11411170
self._kill_cursors_executor.open()
11421171

1172+
server = self._select_server(
1173+
operation.read_preference, operation.session, address=address)
11431174
topology = self._get_topology()
1144-
if address:
1145-
server = topology.select_server_by_address(address)
1146-
if not server:
1147-
raise AutoReconnect('server %s:%d no longer available'
1148-
% address)
1149-
else:
1150-
server = topology.select_server(operation.read_preference)
11511175

11521176
# If this is a direct connection to a mongod, *always* set the slaveOk
11531177
# bit. See bullet point 2 in server-selection.rst#topology-type-single.
@@ -1207,8 +1231,7 @@ def is_retrying():
12071231

12081232
while True:
12091233
try:
1210-
server = self._get_topology().select_server(
1211-
writable_server_selector)
1234+
server = self._select_server(writable_server_selector, session)
12121235
supports_session = (
12131236
session is not None and
12141237
server.description.retryable_writes_supported)
@@ -1737,7 +1760,7 @@ def drop_database(self, name_or_database, session=None):
17371760
"of %s or a Database" % (string_type.__name__,))
17381761

17391762
self._purge_index(name)
1740-
with self._socket_for_writes() as sock_info:
1763+
with self._socket_for_writes(session) as sock_info:
17411764
self[name]._command(
17421765
sock_info,
17431766
"dropDatabase",
@@ -1878,7 +1901,7 @@ def unlock(self, session=None):
18781901
Added ``session`` parameter.
18791902
"""
18801903
cmd = SON([("fsyncUnlock", 1)])
1881-
with self._socket_for_writes() as sock_info:
1904+
with self._socket_for_writes(session) as sock_info:
18821905
if sock_info.max_wire_version >= 4:
18831906
try:
18841907
with self._tmp_session(session) as s:

pymongo/server_description.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,10 @@ def is_writable(self):
187187
def is_readable(self):
188188
return self._is_readable
189189

190+
@property
191+
def mongos(self):
192+
return self._server_type == SERVER_TYPE.Mongos
193+
190194
@property
191195
def is_server_type_known(self):
192196
return self.server_type != SERVER_TYPE.Unknown

test/__init__.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def __init__(self):
162162
self.auth_enabled = False
163163
self.test_commands_enabled = False
164164
self.is_mongos = False
165+
self.mongoses = []
165166
self.is_rs = False
166167
self.has_ipv6 = False
167168
self.ssl = False
@@ -295,6 +296,17 @@ def _init_client(self):
295296

296297
self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid')
297298
self.has_ipv6 = self._server_started_with_ipv6()
299+
if self.is_mongos:
300+
# Check for another mongos on the next port.
301+
address = self.client.address
302+
next_address = address[0], address[1] + 1
303+
self.mongoses.append(address)
304+
mongos_client = self._connect(*next_address,
305+
**self.default_client_options)
306+
if mongos_client:
307+
ismaster = mongos_client.admin.command('ismaster')
308+
if ismaster.get('msg') == 'isdbgrid':
309+
self.mongoses.append(next_address)
298310

299311
def init(self):
300312
with self.conn_lock:
@@ -507,6 +519,13 @@ def require_mongos(self, func):
507519
"Must be connected to a mongos",
508520
func=func)
509521

522+
def require_multiple_mongoses(self, func):
523+
"""Run a test only if the client is connected to a sharded cluster
524+
that has at least 2 mongos nodes."""
525+
return self._require(lambda: len(self.mongoses) > 1,
526+
"Must have multiple mongoses available",
527+
func=func)
528+
510529
def require_standalone(self, func):
511530
"""Run a test only if the client is connected to a standalone."""
512531
return self._require(lambda: not (self.is_mongos or self.is_rs),
@@ -586,13 +605,25 @@ def require_sessions(self, func):
586605
"Sessions not supported",
587606
func=func)
588607

608+
def supports_transactions(self):
609+
if self.version.at_least(4, 1, 6):
610+
return self.is_mongos or self.is_rs
611+
612+
if self.version.at_least(4, 0):
613+
return self.is_rs
614+
return False
615+
589616
def require_transactions(self, func):
590617
"""Run a test only if the deployment might support transactions.
591618
592619
*Might* because this does not test the storage engine or FCV.
593620
"""
594-
new_func = self.require_version_min(4, 0, 0, -1)(func)
595-
return self.require_replica_set(new_func)
621+
return self._require(self.supports_transactions,
622+
"Transactions are not supported",
623+
func=func)
624+
625+
def mongos_seeds(self):
626+
return ','.join('%s:%s' % address for address in self.mongoses)
596627

597628
@property
598629
def supports_reindex(self):

test/test_read_preferences.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,9 @@ def __init__(self, *args, **kwargs):
318318
super(ReadPrefTester, self).__init__(*args, **client_options)
319319

320320
@contextlib.contextmanager
321-
def _socket_for_reads(self, read_preference):
321+
def _socket_for_reads(self, read_preference, session):
322322
context = super(ReadPrefTester, self)._socket_for_reads(
323-
read_preference)
323+
read_preference, session)
324324
with context as (sock_info, slave_ok):
325325
self.record_a_read(sock_info.address)
326326
yield sock_info, slave_ok

0 commit comments

Comments
 (0)