@@ -10,6 +10,7 @@ import (
1010 "github.com/google/go-cmp/cmp"
1111 "go.mongodb.org/mongo-driver/bson/bsontype"
1212 "go.mongodb.org/mongo-driver/bson/primitive"
13+ "go.mongodb.org/mongo-driver/internal/testutil/assert"
1314 "go.mongodb.org/mongo-driver/mongo/readconcern"
1415 "go.mongodb.org/mongo-driver/mongo/readpref"
1516 "go.mongodb.org/mongo-driver/mongo/writeconcern"
@@ -518,6 +519,93 @@ func TestOperation(t *testing.T) {
518519 })
519520 }
520521 })
522+ t .Run ("ExecuteExhaust" , func (t * testing.T ) {
523+ t .Run ("errors if connection is not streaming" , func (t * testing.T ) {
524+ conn := & mockConnection {
525+ rStreaming : false ,
526+ }
527+ err := Operation {}.ExecuteExhaust (context .TODO (), conn , nil )
528+ assert .NotNil (t , err , "expected error, got nil" )
529+ })
530+ })
531+ t .Run ("exhaustAllowed and moreToCome" , func (t * testing.T ) {
532+ // Test the interaction between exhaustAllowed and moreToCome on requests/responses when using the Execute
533+ // and ExecuteExhaust methods.
534+
535+ // Create a server response wire message that has moreToCome=false.
536+ serverResponseDoc := bsoncore .BuildDocumentFromElements (nil ,
537+ bsoncore .AppendInt32Element (nil , "ok" , 1 ),
538+ )
539+ nonStreamingResponse := createExhaustServerResponse (t , serverResponseDoc , false )
540+
541+ // Create a connection that reports that it cannot stream messages.
542+ conn := & mockConnection {
543+ rDesc : description.Server {
544+ WireVersion : & description.VersionRange {
545+ Max : 6 ,
546+ },
547+ },
548+ rReadWM : nonStreamingResponse ,
549+ rCanStream : false ,
550+ }
551+ op := Operation {
552+ CommandFn : func (dst []byte , desc description.SelectedServer ) ([]byte , error ) {
553+ return bsoncore .AppendInt32Element (dst , "isMaster" , 1 ), nil
554+ },
555+ Database : "admin" ,
556+ Deployment : SingleConnectionDeployment {conn },
557+ }
558+ err := op .Execute (context .TODO (), nil )
559+ assert .Nil (t , err , "Execute error: %v" , err )
560+
561+ // The wire message sent to the server should not have exhaustAllowed=true. After execution, the connection
562+ // should not be in a streaming state.
563+ assertExhaustAllowedSet (t , conn .pWriteWM , false )
564+ assert .False (t , conn .CurrentlyStreaming (), "expected CurrentlyStreaming to be false" )
565+
566+ // Modify the connection to report that it can stream and create a new server response with moreToCome=true.
567+ streamingResponse := createExhaustServerResponse (t , serverResponseDoc , true )
568+ conn .rReadWM = streamingResponse
569+ conn .rCanStream = true
570+ err = op .Execute (context .TODO (), nil )
571+ assert .Nil (t , err , "Execute error: %v" , err )
572+ assertExhaustAllowedSet (t , conn .pWriteWM , true )
573+ assert .True (t , conn .CurrentlyStreaming (), "expected CurrentlyStreaming to be true" )
574+
575+ // Reset the server response and go through ExecuteExhaust to mimic streaming the next response. After
576+ // execution, the connection should still be in a streaming state.
577+ conn .rReadWM = streamingResponse
578+ err = op .ExecuteExhaust (context .TODO (), conn , nil )
579+ assert .Nil (t , err , "ExecuteExhaust error: %v" , err )
580+ assert .True (t , conn .CurrentlyStreaming (), "expected CurrentlyStreaming to be true" )
581+ })
582+ }
583+
584+ func createExhaustServerResponse (t * testing.T , response bsoncore.Document , moreToCome bool ) []byte {
585+ idx , wm := wiremessage .AppendHeaderStart (nil , 0 , wiremessage .CurrentRequestID ()+ 1 , wiremessage .OpMsg )
586+ var flags wiremessage.MsgFlag
587+ if moreToCome {
588+ flags = wiremessage .MoreToCome
589+ }
590+ wm = wiremessage .AppendMsgFlags (wm , flags )
591+ wm = wiremessage .AppendMsgSectionType (wm , wiremessage .SingleDocument )
592+ wm = bsoncore .AppendDocument (wm , response )
593+ return bsoncore .UpdateLength (wm , idx , int32 (len (wm )))
594+ }
595+
596+ func assertExhaustAllowedSet (t * testing.T , wm []byte , expected bool ) {
597+ t .Helper ()
598+ _ , _ , _ , _ , wm , ok := wiremessage .ReadHeader (wm )
599+ if ! ok {
600+ t .Fatal ("could not read wm header" )
601+ }
602+ flags , wm , ok := wiremessage .ReadMsgFlags (wm )
603+ if ! ok {
604+ t .Fatal ("could not read wm flags" )
605+ }
606+
607+ actual := flags & wiremessage .ExhaustAllowed > 0
608+ assert .Equal (t , expected , actual , "expected exhaustAllowed set %v, got %v" , expected , actual )
521609}
522610
523611type mockDeployment struct {
@@ -554,20 +642,25 @@ type mockConnection struct {
554642 pReadDst []byte
555643
556644 // returns
557- rWriteErr error
558- rReadWM []byte
559- rReadErr error
560- rDesc description.Server
561- rCloseErr error
562- rID string
563- rAddr address.Address
645+ rWriteErr error
646+ rReadWM []byte
647+ rReadErr error
648+ rDesc description.Server
649+ rCloseErr error
650+ rID string
651+ rAddr address.Address
652+ rCanStream bool
653+ rStreaming bool
564654}
565655
566656func (m * mockConnection ) Description () description.Server { return m .rDesc }
567657func (m * mockConnection ) Close () error { return m .rCloseErr }
568658func (m * mockConnection ) ID () string { return m .rID }
569659func (m * mockConnection ) Address () address.Address { return m .rAddr }
570660func (m * mockConnection ) Stale () bool { return false }
661+ func (m * mockConnection ) SupportsStreaming () bool { return m .rCanStream }
662+ func (m * mockConnection ) CurrentlyStreaming () bool { return m .rStreaming }
663+ func (m * mockConnection ) SetStreaming (streaming bool ) { m .rStreaming = streaming }
571664
572665func (m * mockConnection ) WriteWireMessage (_ context.Context , wm []byte ) error {
573666 m .pWriteWM = wm
0 commit comments