55
55
56
56
from six import iteritems
57
57
58
+ import txredisapi as redis
58
59
from prometheus_client import Counter
59
60
60
61
from twisted .protocols .basic import LineOnlyReceiver
61
62
from twisted .python .failure import Failure
62
63
64
+ from synapse .logging .context import PreserveLoggingContext
63
65
from synapse .metrics import LaterGauge
64
66
from synapse .metrics .background_process_metrics import run_as_background_process
65
67
from synapse .replication .tcp .commands import (
@@ -420,6 +422,8 @@ class CommandHandler:
420
422
def __init__ (self , hs , handler ):
421
423
self .handler = handler
422
424
425
+ self .is_master = hs .config .worker .worker_app is None
426
+
423
427
self .clock = hs .get_clock ()
424
428
425
429
self .streams = {
@@ -458,11 +462,22 @@ def lost_connection(self, connection):
458
462
self .handler .lost_connection (connection )
459
463
460
464
async def on_USER_SYNC (self , cmd : UserSyncCommand ):
465
+ if not self .connection :
466
+ raise Exception ("Not connected" )
467
+
461
468
await self .handler .on_user_sync (
462
469
self .connection .conn_id , cmd .user_id , cmd .is_syncing , cmd .last_sync_ms
463
470
)
464
471
465
472
async def on_REPLICATE (self , cmd : ReplicateCommand ):
473
+ # We only want to announce positions by the writer of the streams.
474
+ # Currently this is just the master process.
475
+ if not self .is_master :
476
+ return
477
+
478
+ if not self .connection :
479
+ raise Exception ("Not connected" )
480
+
466
481
for stream_name , stream in self .streams .items ():
467
482
current_token = stream .current_token ()
468
483
self .connection .send_command (PositionCommand (stream_name , current_token ))
@@ -483,15 +498,14 @@ async def on_SYNC(self, cmd: SyncCommand):
483
498
self .handler .on_sync (cmd .data )
484
499
485
500
async def on_RDATA (self , cmd : RdataCommand ):
501
+
486
502
stream_name = cmd .stream_name
487
503
inbound_rdata_count .labels (stream_name ).inc ()
488
504
489
505
try :
490
506
row = STREAMS_MAP [stream_name ].parse_row (cmd .row )
491
507
except Exception :
492
- logger .exception (
493
- "[%s] Failed to parse RDATA: %r %r" , self .id (), stream_name , cmd .row
494
- )
508
+ logger .exception ("[%s] Failed to parse RDATA: %r" , stream_name , cmd .row )
495
509
raise
496
510
497
511
if cmd .token is None or stream_name in self .streams_connecting :
@@ -519,7 +533,7 @@ async def on_POSITION(self, cmd: PositionCommand):
519
533
return
520
534
521
535
# Fetch all updates between then and now.
522
- limited = True
536
+ limited = cmd . token != current_token
523
537
while limited :
524
538
updates , current_token , limited = await stream .get_updates_since (
525
539
current_token , cmd .token
@@ -582,7 +596,7 @@ def lost_connection(self, connection):
582
596
raise NotImplementedError ()
583
597
584
598
@abc .abstractmethod
585
- def on_user_sync (
599
+ async def on_user_sync (
586
600
self , conn_id : str , user_id : str , is_syncing : bool , last_sync_ms : int
587
601
):
588
602
"""A client has started/stopped syncing on a worker.
@@ -794,3 +808,112 @@ def transport_kernel_read_buffer_size(protocol, read=True):
794
808
inbound_rdata_count = Counter (
795
809
"synapse_replication_tcp_protocol_inbound_rdata_count" , "" , ["stream_name" ]
796
810
)
811
+
812
+
813
+ class RedisSubscriber (redis .SubscriberProtocol ):
814
+ def connectionMade (self ):
815
+ logger .info ("MADE CONNECTION" )
816
+ self .subscribe (self .stream_name )
817
+ self .send_command (ReplicateCommand ("ALL" ))
818
+
819
+ self .handler .new_connection (self )
820
+
821
+ def messageReceived (self , pattern , channel , message ):
822
+ if message .strip () == "" :
823
+ # Ignore blank lines
824
+ return
825
+
826
+ line = message
827
+ cmd_name , rest_of_line = line .split (" " , 1 )
828
+
829
+ cmd_cls = COMMAND_MAP [cmd_name ]
830
+ try :
831
+ cmd = cmd_cls .from_line (rest_of_line )
832
+ except Exception as e :
833
+ logger .exception (
834
+ "[%s] failed to parse line %r: %r" , self .id (), cmd_name , rest_of_line
835
+ )
836
+ self .send_error (
837
+ "failed to parse line for %r: %r (%r):" % (cmd_name , e , rest_of_line )
838
+ )
839
+ return
840
+
841
+ # Now lets try and call on_<CMD_NAME> function
842
+ run_as_background_process (
843
+ "replication-" + cmd .get_logcontext_id (), self .handle_command , cmd
844
+ )
845
+
846
+ async def handle_command (self , cmd : Command ):
847
+ """Handle a command we have received over the replication stream.
848
+
849
+ By default delegates to on_<COMMAND>, which should return an awaitable.
850
+
851
+ Args:
852
+ cmd: received command
853
+ """
854
+ # First call any command handlers on this instance. These are for TCP
855
+ # specific handling.
856
+ cmd_func = getattr (self , "on_%s" % (cmd .NAME ,), None )
857
+ if cmd_func :
858
+ await cmd_func (cmd )
859
+
860
+ # Then call out to the handler.
861
+ cmd_func = getattr (self .handler , "on_%s" % (cmd .NAME ,), None )
862
+ if cmd_func :
863
+ await cmd_func (cmd )
864
+
865
+ def connectionLost (self , reason ):
866
+ logger .info ("LOST CONNECTION" )
867
+ self .handler .lost_connection (self )
868
+
869
+ def send_command (self , cmd ):
870
+ """Send a command if connection has been established.
871
+
872
+ Args:
873
+ cmd (Command)
874
+ """
875
+ string = "%s %s" % (cmd .NAME , cmd .to_line ())
876
+ if "\n " in string :
877
+ raise Exception ("Unexpected newline in command: %r" , string )
878
+
879
+ encoded_string = string .encode ("utf-8" )
880
+
881
+ async def _send ():
882
+ with PreserveLoggingContext ():
883
+ await self .redis_connection .publish (self .stream_name , encoded_string )
884
+
885
+ run_as_background_process ("send-cmd" , _send )
886
+
887
+ def stream_update (self , stream_name , token , data ):
888
+ """Called when a new update is available to stream to clients.
889
+
890
+ We need to check if the client is interested in the stream or not
891
+ """
892
+ self .send_command (RdataCommand (stream_name , token , data ))
893
+
894
+ def send_sync (self , data ):
895
+ self .send_command (SyncCommand (data ))
896
+
897
+ def send_remote_server_up (self , server : str ):
898
+ self .send_command (RemoteServerUpCommand (server ))
899
+
900
+
901
+ class RedisFactory (redis .SubscriberFactory ):
902
+
903
+ maxDelay = 5
904
+ continueTrying = True
905
+ protocol = RedisSubscriber
906
+
907
+ def __init__ (self , hs , handler ):
908
+ super (RedisFactory , self ).__init__ ()
909
+
910
+ self .handler = CommandHandler (hs , handler )
911
+ self .stream_name = hs .hostname
912
+
913
+ def buildProtocol (self , addr ):
914
+ p = super (RedisFactory , self ).buildProtocol (addr )
915
+ p .handler = self .handler
916
+ p .redis_connection = redis .lazyConnection ("redis" )
917
+ p .conn_id = random_string (5 ) # TODO: FIXME
918
+ p .stream_name = self .stream_name
919
+ return p
0 commit comments