@@ -119,6 +119,8 @@ type DataChannel struct {
119
119
120
120
// AgentVersion received during handshake
121
121
agentVersion string
122
+
123
+ mutex sync.Mutex
122
124
}
123
125
124
126
type ListMessageBuffer struct {
@@ -279,6 +281,9 @@ func (dataChannel *DataChannel) SendInputDataMessage(
279
281
payloadType message.PayloadType ,
280
282
inputData []byte ) (err error ) {
281
283
284
+ dataChannel .mutex .Lock ()
285
+ defer dataChannel .mutex .Unlock ()
286
+
282
287
var (
283
288
flag uint64 = 0
284
289
msg []byte
@@ -343,12 +348,16 @@ func (dataChannel *DataChannel) ResendStreamDataMessageScheduler(log log.T) (err
343
348
streamMessageElement := dataChannel .OutgoingMessageBuffer .Messages .Front ()
344
349
dataChannel .OutgoingMessageBuffer .Mutex .Unlock ()
345
350
351
+ dataChannel .mutex .Lock ()
352
+ localTimeout := dataChannel .RetransmissionTimeout
353
+ dataChannel .mutex .Unlock ()
354
+
346
355
if streamMessageElement == nil {
347
356
continue
348
357
}
349
358
350
359
streamMessage := streamMessageElement .Value .(StreamingMessage )
351
- if time .Since (streamMessage .LastSentTime ) > dataChannel . RetransmissionTimeout {
360
+ if time .Since (streamMessage .LastSentTime ) > localTimeout {
352
361
log .Debugf ("Resend stream data message %d for the %d attempt." , streamMessage .SequenceNumber , * streamMessage .ResendAttempt )
353
362
if * streamMessage .ResendAttempt >= config .ResendMaxAttempt {
354
363
log .Warnf ("Message %d was resent over %d times." , streamMessage .SequenceNumber , config .ResendMaxAttempt )
@@ -368,6 +377,9 @@ func (dataChannel *DataChannel) ResendStreamDataMessageScheduler(log log.T) (err
368
377
369
378
// ProcessAcknowledgedMessage processes acknowledge messages by deleting them from OutgoingMessageBuffer
370
379
func (dataChannel * DataChannel ) ProcessAcknowledgedMessage (log log.T , acknowledgeMessageContent message.AcknowledgeContent ) error {
380
+ dataChannel .mutex .Lock ()
381
+ defer dataChannel .mutex .Unlock ()
382
+
371
383
acknowledgeSequenceNumber := acknowledgeMessageContent .SequenceNumber
372
384
for streamMessageElement := dataChannel .OutgoingMessageBuffer .Messages .Front (); streamMessageElement != nil ; streamMessageElement = streamMessageElement .Next () {
373
385
streamMessage := streamMessageElement .Value .(StreamingMessage )
@@ -615,6 +627,8 @@ func (dataChannel *DataChannel) HandleOutputMessage(
615
627
outputMessage message.ClientMessage ,
616
628
rawMessage []byte ) (err error ) {
617
629
630
+ dataChannel .mutex .Lock ()
631
+
618
632
// On receiving expected stream data message, send acknowledgement, process it and increment expected sequence number by 1.
619
633
// Further process messages from IncomingMessageBuffer
620
634
if outputMessage .SequenceNumber == dataChannel .ExpectedSequenceNumber {
@@ -623,40 +637,51 @@ func (dataChannel *DataChannel) HandleOutputMessage(
623
637
case message .HandshakeRequestPayloadType :
624
638
{
625
639
if err = SendAcknowledgeMessageCall (log , dataChannel , outputMessage ); err != nil {
640
+ dataChannel .mutex .Unlock ()
626
641
return err
627
642
}
628
643
629
644
// PayloadType is HandshakeRequest so we call our own handler instead of the provided handler
630
645
log .Debugf ("Processing HandshakeRequest message %s" , outputMessage )
646
+
647
+ // The handler will eventually request the lock in `SendInputDataMessage`, so we'll unlock here to avoid deadlock
648
+ dataChannel .mutex .Unlock ()
631
649
if err = dataChannel .handleHandshakeRequest (log , outputMessage ); err != nil {
632
650
log .Errorf ("Unable to process incoming data payload, MessageType %s, " +
633
651
"PayloadType HandshakeRequestPayloadType, err: %s." , outputMessage .MessageType , err )
634
652
return err
635
653
}
654
+ dataChannel .mutex .Lock ()
636
655
}
637
656
case message .HandshakeCompletePayloadType :
638
657
{
639
658
if err = SendAcknowledgeMessageCall (log , dataChannel , outputMessage ); err != nil {
659
+ dataChannel .mutex .Unlock ()
640
660
return err
641
661
}
642
662
663
+ dataChannel .mutex .Unlock ()
643
664
if err = dataChannel .handleHandshakeComplete (log , outputMessage ); err != nil {
644
665
log .Errorf ("Unable to process incoming data payload, MessageType %s, " +
645
666
"PayloadType HandshakeCompletePayloadType, err: %s." , outputMessage .MessageType , err )
646
667
return err
647
668
}
669
+ dataChannel .mutex .Lock ()
648
670
}
649
671
case message .EncChallengeRequest :
650
672
{
651
673
if err = SendAcknowledgeMessageCall (log , dataChannel , outputMessage ); err != nil {
674
+ dataChannel .mutex .Unlock ()
652
675
return err
653
676
}
654
677
678
+ dataChannel .mutex .Unlock ()
655
679
if err = dataChannel .handleEncryptionChallengeRequest (log , outputMessage ); err != nil {
656
680
log .Errorf ("Unable to process incoming data payload, MessageType %s, " +
657
681
"PayloadType EncChallengeRequest, err: %s." , outputMessage .MessageType , err )
658
682
return err
659
683
}
684
+ dataChannel .mutex .Lock ()
660
685
}
661
686
default :
662
687
@@ -686,11 +711,13 @@ func (dataChannel *DataChannel) HandleOutputMessage(
686
711
} else {
687
712
// Acknowledge outputMessage only if session specific handler is ready
688
713
if err := SendAcknowledgeMessageCall (log , dataChannel , outputMessage ); err != nil {
714
+ dataChannel .mutex .Unlock ()
689
715
return err
690
716
}
691
717
}
692
718
}
693
719
dataChannel .ExpectedSequenceNumber = dataChannel .ExpectedSequenceNumber + 1
720
+ dataChannel .mutex .Unlock ()
694
721
return dataChannel .ProcessIncomingMessageBufferItems (log , outputMessage )
695
722
} else {
696
723
log .Debugf ("Unexpected sequence message received. Received Sequence Number: %d. Expected Sequence Number: %d" ,
@@ -703,6 +730,7 @@ func (dataChannel *DataChannel) HandleOutputMessage(
703
730
outputMessage .SequenceNumber , dataChannel .ExpectedSequenceNumber )
704
731
if len (dataChannel .IncomingMessageBuffer .Messages ) < dataChannel .IncomingMessageBuffer .Capacity {
705
732
if err = SendAcknowledgeMessageCall (log , dataChannel , outputMessage ); err != nil {
733
+ dataChannel .mutex .Unlock ()
706
734
return err
707
735
}
708
736
@@ -718,6 +746,7 @@ func (dataChannel *DataChannel) HandleOutputMessage(
718
746
}
719
747
}
720
748
}
749
+ dataChannel .mutex .Unlock ()
721
750
return nil
722
751
}
723
752
@@ -727,6 +756,9 @@ func (dataChannel *DataChannel) HandleOutputMessage(
727
756
func (dataChannel * DataChannel ) ProcessIncomingMessageBufferItems (log log.T ,
728
757
outputMessage message.ClientMessage ) (err error ) {
729
758
759
+ dataChannel .mutex .Lock ()
760
+ defer dataChannel .mutex .Unlock ()
761
+
730
762
for {
731
763
bufferedStreamMessage := dataChannel .IncomingMessageBuffer .Messages [dataChannel .ExpectedSequenceNumber ]
732
764
if bufferedStreamMessage .Content != nil {
@@ -900,12 +932,18 @@ func (dataChannel *DataChannel) IsSessionTypeSet() chan bool {
900
932
901
933
// IsSessionEnded check if session has ended
902
934
func (dataChannel * DataChannel ) IsSessionEnded () bool {
903
- return dataChannel .isSessionEnded
935
+ dataChannel .mutex .Lock ()
936
+ sessionEnded := dataChannel .isSessionEnded
937
+ dataChannel .mutex .Unlock ()
938
+ return sessionEnded
904
939
}
905
940
906
941
// IsSessionEnded check if session has ended
907
942
func (dataChannel * DataChannel ) EndSession () error {
943
+ dataChannel .mutex .Lock ()
908
944
dataChannel .isSessionEnded = true
945
+ dataChannel .mutex .Unlock ()
946
+
909
947
return nil
910
948
}
911
949
0 commit comments