Skip to content

Commit c094eb0

Browse files
committed
Fix datachannel data races by locking
1 parent b885979 commit c094eb0

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

src/datachannel/streaming.go

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ type DataChannel struct {
119119

120120
// AgentVersion received during handshake
121121
agentVersion string
122+
123+
mutex sync.Mutex
122124
}
123125

124126
type ListMessageBuffer struct {
@@ -279,6 +281,9 @@ func (dataChannel *DataChannel) SendInputDataMessage(
279281
payloadType message.PayloadType,
280282
inputData []byte) (err error) {
281283

284+
dataChannel.mutex.Lock()
285+
defer dataChannel.mutex.Unlock()
286+
282287
var (
283288
flag uint64 = 0
284289
msg []byte
@@ -343,12 +348,16 @@ func (dataChannel *DataChannel) ResendStreamDataMessageScheduler(log log.T) (err
343348
streamMessageElement := dataChannel.OutgoingMessageBuffer.Messages.Front()
344349
dataChannel.OutgoingMessageBuffer.Mutex.Unlock()
345350

351+
dataChannel.mutex.Lock()
352+
localTimeout := dataChannel.RetransmissionTimeout
353+
dataChannel.mutex.Unlock()
354+
346355
if streamMessageElement == nil {
347356
continue
348357
}
349358

350359
streamMessage := streamMessageElement.Value.(StreamingMessage)
351-
if time.Since(streamMessage.LastSentTime) > dataChannel.RetransmissionTimeout {
360+
if time.Since(streamMessage.LastSentTime) > localTimeout {
352361
log.Debugf("Resend stream data message %d for the %d attempt.", streamMessage.SequenceNumber, *streamMessage.ResendAttempt)
353362
if *streamMessage.ResendAttempt >= config.ResendMaxAttempt {
354363
log.Warnf("Message %d was resent over %d times.", streamMessage.SequenceNumber, config.ResendMaxAttempt)
@@ -368,6 +377,9 @@ func (dataChannel *DataChannel) ResendStreamDataMessageScheduler(log log.T) (err
368377

369378
// ProcessAcknowledgedMessage processes acknowledge messages by deleting them from OutgoingMessageBuffer
370379
func (dataChannel *DataChannel) ProcessAcknowledgedMessage(log log.T, acknowledgeMessageContent message.AcknowledgeContent) error {
380+
dataChannel.mutex.Lock()
381+
defer dataChannel.mutex.Unlock()
382+
371383
acknowledgeSequenceNumber := acknowledgeMessageContent.SequenceNumber
372384
for streamMessageElement := dataChannel.OutgoingMessageBuffer.Messages.Front(); streamMessageElement != nil; streamMessageElement = streamMessageElement.Next() {
373385
streamMessage := streamMessageElement.Value.(StreamingMessage)
@@ -615,6 +627,8 @@ func (dataChannel *DataChannel) HandleOutputMessage(
615627
outputMessage message.ClientMessage,
616628
rawMessage []byte) (err error) {
617629

630+
dataChannel.mutex.Lock()
631+
618632
// On receiving expected stream data message, send acknowledgement, process it and increment expected sequence number by 1.
619633
// Further process messages from IncomingMessageBuffer
620634
if outputMessage.SequenceNumber == dataChannel.ExpectedSequenceNumber {
@@ -623,40 +637,51 @@ func (dataChannel *DataChannel) HandleOutputMessage(
623637
case message.HandshakeRequestPayloadType:
624638
{
625639
if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil {
640+
dataChannel.mutex.Unlock()
626641
return err
627642
}
628643

629644
// PayloadType is HandshakeRequest so we call our own handler instead of the provided handler
630645
log.Debugf("Processing HandshakeRequest message %s", outputMessage)
646+
647+
// The handler will eventually request the lock in `SendInputDataMessage`, so we'll unlock here to avoid deadlock
648+
dataChannel.mutex.Unlock()
631649
if err = dataChannel.handleHandshakeRequest(log, outputMessage); err != nil {
632650
log.Errorf("Unable to process incoming data payload, MessageType %s, "+
633651
"PayloadType HandshakeRequestPayloadType, err: %s.", outputMessage.MessageType, err)
634652
return err
635653
}
654+
dataChannel.mutex.Lock()
636655
}
637656
case message.HandshakeCompletePayloadType:
638657
{
639658
if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil {
659+
dataChannel.mutex.Unlock()
640660
return err
641661
}
642662

663+
dataChannel.mutex.Unlock()
643664
if err = dataChannel.handleHandshakeComplete(log, outputMessage); err != nil {
644665
log.Errorf("Unable to process incoming data payload, MessageType %s, "+
645666
"PayloadType HandshakeCompletePayloadType, err: %s.", outputMessage.MessageType, err)
646667
return err
647668
}
669+
dataChannel.mutex.Lock()
648670
}
649671
case message.EncChallengeRequest:
650672
{
651673
if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil {
674+
dataChannel.mutex.Unlock()
652675
return err
653676
}
654677

678+
dataChannel.mutex.Unlock()
655679
if err = dataChannel.handleEncryptionChallengeRequest(log, outputMessage); err != nil {
656680
log.Errorf("Unable to process incoming data payload, MessageType %s, "+
657681
"PayloadType EncChallengeRequest, err: %s.", outputMessage.MessageType, err)
658682
return err
659683
}
684+
dataChannel.mutex.Lock()
660685
}
661686
default:
662687

@@ -686,11 +711,13 @@ func (dataChannel *DataChannel) HandleOutputMessage(
686711
} else {
687712
// Acknowledge outputMessage only if session specific handler is ready
688713
if err := SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil {
714+
dataChannel.mutex.Unlock()
689715
return err
690716
}
691717
}
692718
}
693719
dataChannel.ExpectedSequenceNumber = dataChannel.ExpectedSequenceNumber + 1
720+
dataChannel.mutex.Unlock()
694721
return dataChannel.ProcessIncomingMessageBufferItems(log, outputMessage)
695722
} else {
696723
log.Debugf("Unexpected sequence message received. Received Sequence Number: %d. Expected Sequence Number: %d",
@@ -703,6 +730,7 @@ func (dataChannel *DataChannel) HandleOutputMessage(
703730
outputMessage.SequenceNumber, dataChannel.ExpectedSequenceNumber)
704731
if len(dataChannel.IncomingMessageBuffer.Messages) < dataChannel.IncomingMessageBuffer.Capacity {
705732
if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil {
733+
dataChannel.mutex.Unlock()
706734
return err
707735
}
708736

@@ -718,6 +746,7 @@ func (dataChannel *DataChannel) HandleOutputMessage(
718746
}
719747
}
720748
}
749+
dataChannel.mutex.Unlock()
721750
return nil
722751
}
723752

@@ -727,6 +756,9 @@ func (dataChannel *DataChannel) HandleOutputMessage(
727756
func (dataChannel *DataChannel) ProcessIncomingMessageBufferItems(log log.T,
728757
outputMessage message.ClientMessage) (err error) {
729758

759+
dataChannel.mutex.Lock()
760+
defer dataChannel.mutex.Unlock()
761+
730762
for {
731763
bufferedStreamMessage := dataChannel.IncomingMessageBuffer.Messages[dataChannel.ExpectedSequenceNumber]
732764
if bufferedStreamMessage.Content != nil {
@@ -900,12 +932,18 @@ func (dataChannel *DataChannel) IsSessionTypeSet() chan bool {
900932

901933
// IsSessionEnded check if session has ended
902934
func (dataChannel *DataChannel) IsSessionEnded() bool {
903-
return dataChannel.isSessionEnded
935+
dataChannel.mutex.Lock()
936+
sessionEnded := dataChannel.isSessionEnded
937+
dataChannel.mutex.Unlock()
938+
return sessionEnded
904939
}
905940

906941
// IsSessionEnded check if session has ended
907942
func (dataChannel *DataChannel) EndSession() error {
943+
dataChannel.mutex.Lock()
908944
dataChannel.isSessionEnded = true
945+
dataChannel.mutex.Unlock()
946+
909947
return nil
910948
}
911949

0 commit comments

Comments
 (0)