diff --git a/.github/workflows/ginkgo_test.yml b/.github/workflows/ginkgo_test.yml index 9bbdc6f35..0ad6aed04 100644 --- a/.github/workflows/ginkgo_test.yml +++ b/.github/workflows/ginkgo_test.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: # os: [ "ubuntu-latest", "windows-latest", "macos-latest" ] - go: [ "1.20.x", "1.21.0-rc.4" ] + go: [ "1.20.x", "1.21.x" ] runs-on: "ubuntu-latest" steps: - uses: actions/checkout@v3 @@ -23,20 +23,36 @@ jobs: go-version: ${{ matrix.go }} - run: go version - - name: Run tests + - name: Run unit tests env: TIMESCALE_FACTOR: 10 run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -cover -randomize-all -randomize-suites -trace -skip-package integrationtests - # - name: Run tests (32 bit) + # - name: Run unit tests (32 bit) # if: ${{ matrix.os != 'macos' }} # can't run 32 bit tests on OSX. # env: # TIMESCALE_FACTOR: 10 # GOARCH: 386 # run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -cover -coverprofile coverage.txt -output-dir . -randomize-all -randomize-suites -trace -skip-package integrationtests - # - name: Run tests with race detector + # - name: Run unit tests with race detector # if: ${{ matrix.os == 'ubuntu' }} # speed things up. Windows and OSX VMs are slow # env: # TIMESCALE_FACTOR: 20 # run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -race -randomize-all -randomize-suites -trace -skip-package integrationtests + + - name: Run other tests + run: | + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self,versionnegotiation integrationtests + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation + - name: Run self tests, using QUIC v1 + if: success() || failure() # run this step even if the previous one failed + run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 + - name: Run self tests, using QUIC v2 + if: success() || failure() # run this step even if the previous one failed + run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=2 + - name: Run self tests, with GSO disabled + if: ${{ matrix.os == 'ubuntu' && (success() || failure()) }} # run this step even if the previous one failed + env: + QUIC_GO_DISABLE_GSO: true + run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 \ No newline at end of file diff --git a/client.go b/client.go index 2357c1202..cc8c1bd10 100644 --- a/client.go +++ b/client.go @@ -57,11 +57,11 @@ func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Confi if err != nil { return nil, err } - dl, err := setupTransport(udpConn, tlsConf, true) + tr, err := setupTransport(udpConn, tlsConf, true) if err != nil { return nil, err } - return dl.Dial(ctx, udpAddr, tlsConf, conf) + return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false) } // DialAddrEarly establishes a new 0-RTT QUIC connection to a server. @@ -75,13 +75,13 @@ func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf * if err != nil { return nil, err } - dl, err := setupTransport(udpConn, tlsConf, true) + tr, err := setupTransport(udpConn, tlsConf, true) if err != nil { return nil, err } - conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf) + conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true) if err != nil { - dl.Close() + tr.Close() return nil, err } return conn, nil @@ -166,12 +166,6 @@ func dial( } func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { - if tlsConf == nil { - tlsConf = &tls.Config{} - } else { - tlsConf = tlsConf.Clone() - } - srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err diff --git a/client_test.go b/client_test.go index 9ccf27dee..b12fec931 100644 --- a/client_test.go +++ b/client_test.go @@ -13,10 +13,9 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/logging" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type nullMultiplexer struct{} diff --git a/connection.go b/connection.go index 955f8f8a7..270f6ec0f 100644 --- a/connection.go +++ b/connection.go @@ -244,7 +244,7 @@ var newConnection = func( handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, - oneRTTStream: newCryptoStream(), + oneRTTStream: newCryptoStream(true), perspective: protocol.PerspectiveServer, tracer: tracer, logger: logger, @@ -394,8 +394,7 @@ var newClientConnection = func( ) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream() - + oneRTTStream := newCryptoStream(true) params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -453,8 +452,8 @@ var newClientConnection = func( } func (s *connection) preSetup() { - s.initialStream = newCryptoStream() - s.handshakeStream = newCryptoStream() + s.initialStream = newCryptoStream(false) + s.handshakeStream = newCryptoStream(false) s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) diff --git a/connection_test.go b/connection_test.go index 95f1b3464..de0983eaa 100644 --- a/connection_test.go +++ b/connection_test.go @@ -26,10 +26,9 @@ import ( "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func areConnsRunning() bool { @@ -2703,8 +2702,9 @@ var _ = Describe("Client Connection", func() { Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128))) }) - It("it closes when no matching version is found", func() { + It("closes when no matching version is found", func() { errChan := make(chan error, 1) + packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -2712,7 +2712,6 @@ var _ = Describe("Client Connection", func() { errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) - packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) gomock.InOrder( tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { diff --git a/crypto_stream.go b/crypto_stream.go index 4ad097ce5..0c991089c 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -30,10 +30,17 @@ type cryptoStreamImpl struct { writeOffset protocol.ByteCount writeBuf []byte + + // Reassemble TLS handshake messages before returning them from GetCryptoData. + // This is only needed because crypto/tls doesn't correctly handle post-handshake messages. + onlyCompleteMsg bool } -func newCryptoStream() cryptoStream { - return &cryptoStreamImpl{queue: newFrameSorter()} +func newCryptoStream(onlyCompleteMsg bool) cryptoStream { + return &cryptoStreamImpl{ + queue: newFrameSorter(), + onlyCompleteMsg: onlyCompleteMsg, + } } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { @@ -71,6 +78,20 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { + if s.onlyCompleteMsg { + if len(s.msgBuf) < 4 { + return nil + } + msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) + if len(s.msgBuf) < msgLen { + return nil + } + msg := make([]byte, msgLen) + copy(msg, s.msgBuf[:msgLen]) + s.msgBuf = s.msgBuf[msgLen:] + return msg + } + b := s.msgBuf s.msgBuf = nil return b diff --git a/crypto_stream_test.go b/crypto_stream_test.go index c875bcb53..af6ad986f 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,6 +1,7 @@ package quic import ( + "crypto/rand" "fmt" "github.com/refraction-networking/uquic/internal/protocol" @@ -15,7 +16,7 @@ var _ = Describe("Crypto Stream", func() { var str cryptoStream BeforeEach(func() { - str = newCryptoStream() + str = newCryptoStream(false) }) Context("handling incoming data", func() { @@ -137,4 +138,23 @@ var _ = Describe("Crypto Stream", func() { Expect(f.Data).To(Equal([]byte("bar"))) }) }) + + It("reassembles data", func() { + str = newCryptoStream(true) + data := make([]byte, 1337) + l := len(data) - 4 + data[1] = uint8(l >> 16) + data[2] = uint8(l >> 8) + data[3] = uint8(l) + rand.Read(data[4:]) + + for i, b := range data { + Expect(str.GetCryptoData()).To(BeEmpty()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.ByteCount(i), + Data: []byte{b}, + })).To(Succeed()) + } + Expect(str.GetCryptoData()).To(Equal(data)) + }) }) diff --git a/framer_test.go b/framer_test.go index 9a6b07779..1c4944dad 100644 --- a/framer_test.go +++ b/framer_test.go @@ -8,10 +8,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Framer", func() { diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index f017e1f09..16db3a2f8 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -18,6 +18,7 @@ import ( "github.com/refraction-networking/uquic/fuzzing/internal/helper" "github.com/refraction-networking/uquic/internal/handshake" "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/qtls" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" ) @@ -85,33 +86,6 @@ func (m messageType) String() string { } } -func appendSuites(suites []uint16, rand uint8) []uint16 { - const ( - s1 = tls.TLS_AES_128_GCM_SHA256 - s2 = tls.TLS_AES_256_GCM_SHA384 - s3 = tls.TLS_CHACHA20_POLY1305_SHA256 - ) - switch rand % 4 { - default: - return suites - case 1: - return append(suites, s1) - case 2: - return append(suites, s2) - case 3: - return append(suites, s3) - } -} - -// consumes 2 bits -func getSuites(rand uint8) []uint16 { - suites := make([]uint16, 0, 3) - for i := 1; i <= 3; i++ { - suites = appendSuites(suites, rand>>i%4) - } - return suites -} - // consumes 3 bits func getClientAuth(rand uint8) tls.ClientAuthType { switch rand { @@ -148,6 +122,7 @@ func getTransportParameters(seed uint8) *wire.TransportParameters { const maxVarInt = math.MaxUint64 / 4 r := mrand.New(mrand.NewSource(int64(seed))) return &wire.TransportParameters{ + ActiveConnectionIDLimit: 2, InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)), InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)), InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)), @@ -207,14 +182,26 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. SessionTicketKey: sessionTicketKey, } + // This sets the cipher suite for both client and server. + // The way crypto/tls is designed doesn't allow us to set different cipher suites for client and server. + resetCipherSuite := func() {} + switch (runConfig[0] >> 6) % 4 { + case 0: + resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_128_GCM_SHA256) + case 1: + resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_256_GCM_SHA384) + case 3: + resetCipherSuite = qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) + default: + } + defer resetCipherSuite() + enable0RTTClient := helper.NthBit(runConfig[0], 0) enable0RTTServer := helper.NthBit(runConfig[0], 1) sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3) sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4) sendSessionTicket := helper.NthBit(runConfig[0], 5) - clientConf.CipherSuites = getSuites(runConfig[0] >> 6) serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111) - serverConf.CipherSuites = getSuites(runConfig[1] >> 6) serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3) if helper.NthBit(runConfig[2], 0) { clientConf.RootCAs = x509.NewCertPool() @@ -303,6 +290,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. if err := client.StartHandshake(); err != nil { log.Fatal(err) } + defer client.Close() server := handshake.NewCryptoSetupServer( protocol.ConnectionID{}, @@ -319,12 +307,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. if err := server.StartHandshake(); err != nil { log.Fatal(err) } + defer server.Close() var clientHandshakeComplete, serverHandshakeComplete bool for { + var processedEvent bool clientLoop: for { - var processedEvent bool ev := client.NextEvent() //nolint:exhaustive // only need to process a few events switch ev.Kind { @@ -335,11 +324,16 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. break clientLoop case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: msg := ev.Data + encLevel := protocol.EncryptionInitial + if ev.Kind == handshake.EventWriteHandshakeData { + encLevel = protocol.EncryptionHandshake + } if msg[0] == messageToReplace { fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) msg = data + encLevel = messageToReplaceEncLevel } - if err := server.HandleMessage(msg, messageToReplaceEncLevel); err != nil { + if err := server.HandleMessage(msg, encLevel); err != nil { return 1 } case handshake.EventHandshakeComplete: @@ -348,9 +342,9 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. processedEvent = true } + processedEvent = false serverLoop: for { - var processedEvent bool ev := server.NextEvent() //nolint:exhaustive // only need to process a few events switch ev.Kind { @@ -360,12 +354,17 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. } break serverLoop case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: + encLevel := protocol.EncryptionInitial + if ev.Kind == handshake.EventWriteHandshakeData { + encLevel = protocol.EncryptionHandshake + } msg := ev.Data if msg[0] == messageToReplace { fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) msg = data + encLevel = messageToReplaceEncLevel } - if err := client.HandleMessage(msg, messageToReplaceEncLevel); err != nil { + if err := client.HandleMessage(msg, encLevel); err != nil { return 1 } case handshake.EventHandshakeComplete: @@ -410,10 +409,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. } client.HandleMessage(ticket, protocol.Encryption1RTT) } + if sendPostHandshakeMessageToClient { + fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel) client.HandleMessage(data, messageToReplaceEncLevel) } if sendPostHandshakeMessageToServer { + fmt.Println("sending post handshake message to the server at", messageToReplaceEncLevel) server.HandleMessage(data, messageToReplaceEncLevel) } diff --git a/go.mod b/go.mod index 6587caedf..5b855acb4 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,13 @@ go 1.20 require ( github.com/francoispqt/gojay v1.2.13 github.com/gaukas/clienthellod v0.4.0 - github.com/golang/mock v1.6.0 - github.com/onsi/ginkgo/v2 v2.11.0 + github.com/onsi/ginkgo/v2 v2.12.0 github.com/onsi/gomega v1.27.10 github.com/quic-go/qpack v0.4.0 - github.com/refraction-networking/utls v1.4.3 + github.com/refraction-networking/utls v1.5.2 + go.uber.org/mock v0.2.0 golang.org/x/crypto v0.12.0 - golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/net v0.14.0 golang.org/x/sync v0.3.0 golang.org/x/sys v0.11.0 @@ -19,16 +19,17 @@ require ( require ( github.com/andybalholm/brotli v1.0.5 // indirect + github.com/cloudflare/circl v1.3.3 // indirect github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-logr/logr v1.2.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/gopacket v1.1.19 // indirect - github.com/google/pprof v0.0.0-20230808223545-4887780b67fb // indirect + github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f // indirect github.com/klauspost/compress v1.16.7 // indirect - github.com/quic-go/quic-go v0.37.4 // indirect + github.com/quic-go/quic-go v0.38.1 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/text v0.12.0 // indirect - golang.org/x/tools v0.12.0 // indirect + golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 151f2f571..870557e2d 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= +github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -40,8 +42,6 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= @@ -57,6 +57,8 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20230808223545-4887780b67fb h1:oqpb3Cwpc7EOml5PVGMYbSGmwNui2R7i8IW83gs4W0c= github.com/google/pprof v0.0.0-20230808223545-4887780b67fb/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= +github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f h1:pDhu5sgp8yJlEF/g6osliIIpF9K4F5jvkULXa4daRDQ= +github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -84,6 +86,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU= github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= +github.com/onsi/ginkgo/v2 v2.12.0 h1:UIVDowFPwpg6yMUpPjGkYvf06K3RAiJXUhCxEwQVHRI= +github.com/onsi/ginkgo/v2 v2.12.0/go.mod h1:ZNEzXISYlqpb8S36iN71ifqLi3vVD1rVJGvWRCJOUpQ= github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= @@ -98,8 +102,10 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/quic-go v0.37.4 h1:ke8B73yMCWGq9MfrCCAw0Uzdm7GaViC3i39dsIdDlH4= github.com/quic-go/quic-go v0.37.4/go.mod h1:YsbH1r4mSHPJcLF4k4zruUkLBqctEMBDR6VPvcYjIsU= -github.com/refraction-networking/utls v1.4.3 h1:BdWS3BSzCwWCFfMIXP3mjLAyQkdmog7diaD/OqFbAzM= -github.com/refraction-networking/utls v1.4.3/go.mod h1:4u9V/awOSBrRw6+federGmVJQfPtemEqLBXkML1b0bo= +github.com/quic-go/quic-go v0.38.1 h1:M36YWA5dEhEeT+slOu/SwMEucbYd0YFidxG3KlGPZaE= +github.com/quic-go/quic-go v0.38.1/go.mod h1:ijnZM7JsFIkp4cRyjxJNIzdSfCLmUMg9wdyhGmg+SN4= +github.com/refraction-networking/utls v1.5.2 h1:l6diiLbEoRqdQ+/osPDO0z0lTc8O8VZV+p82N+Hi+ws= +github.com/refraction-networking/utls v1.5.2/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -133,8 +139,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= +go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -146,12 +153,13 @@ golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98y golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb h1:mIKbk8weKhSeLH2GmUTrvx8CjkyJmnU1wFmg59CUjFA= golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -164,7 +172,6 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -177,7 +184,6 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -186,32 +192,24 @@ golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E= +golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= diff --git a/http3/body_test.go b/http3/body_test.go index 0f64bc762..c5f9bb87b 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -6,9 +6,9 @@ import ( quic "github.com/refraction-networking/uquic" mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Response Body", func() { diff --git a/http3/client_test.go b/http3/client_test.go index 5b56a1fbc..28504f30c 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -19,11 +19,11 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/quicvarint" - "github.com/golang/mock/gomock" "github.com/quic-go/qpack" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Client", func() { diff --git a/http3/error_codes.go b/http3/error_codes.go index 515c675e3..86e27ff76 100644 --- a/http3/error_codes.go +++ b/http3/error_codes.go @@ -26,7 +26,7 @@ const ( ErrCodeMessageError ErrCode = 0x10e ErrCodeConnectError ErrCode = 0x10f ErrCodeVersionFallback ErrCode = 0x110 - ErrCodeDatagramError ErrCode = 0x4a1268 + ErrCodeDatagramError ErrCode = 0x33 ) func (e ErrCode) String() string { diff --git a/http3/frames.go b/http3/frames.go index 7897b86f1..7f6d0fe8d 100644 --- a/http3/frames.go +++ b/http3/frames.go @@ -88,7 +88,7 @@ func (f *headersFrame) Append(b []byte) []byte { return quicvarint.Append(b, f.Length) } -const settingDatagram = 0xffd277 +const settingDatagram = 0x33 type settingsFrame struct { Datagram bool diff --git a/http3/http3_suite_test.go b/http3/http3_suite_test.go index 56b2108f9..b34003b92 100644 --- a/http3/http3_suite_test.go +++ b/http3/http3_suite_test.go @@ -6,10 +6,9 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestHttp3(t *testing.T) { diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go index 4f87ee550..5e0ff514b 100644 --- a/http3/http_stream_test.go +++ b/http3/http_stream_test.go @@ -7,9 +7,9 @@ import ( quic "github.com/refraction-networking/uquic" mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func getDataFrame(data []byte) []byte { diff --git a/http3/mock_quic_early_listener_test.go b/http3/mock_quic_early_listener_test.go index 0e7cf6859..b43b3550d 100644 --- a/http3/mock_quic_early_listener_test.go +++ b/http3/mock_quic_early_listener_test.go @@ -9,7 +9,7 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" quic "github.com/refraction-networking/uquic" ) diff --git a/http3/mock_roundtripcloser_test.go b/http3/mock_roundtripcloser_test.go index f9a82130f..6550b2554 100644 --- a/http3/mock_roundtripcloser_test.go +++ b/http3/mock_roundtripcloser_test.go @@ -8,7 +8,7 @@ import ( http "net/http" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockRoundTripCloser is a mock of RoundTripCloser interface. diff --git a/http3/mockgen.go b/http3/mockgen.go index a8185a97f..cf4b917a6 100644 --- a/http3/mockgen.go +++ b/http3/mockgen.go @@ -2,7 +2,7 @@ package http3 -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/refraction-networking/uquic/http3 RoundTripCloser" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/refraction-networking/uquic/http3 RoundTripCloser" type RoundTripCloser = roundTripCloser -//go:generate sh -c "go run github.com/golang/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/refraction-networking/uquic/http3 QUICEarlyListener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/refraction-networking/uquic/http3 QUICEarlyListener" diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index 9de46f90a..e590efa12 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -8,8 +8,8 @@ import ( mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" "github.com/refraction-networking/uquic/internal/utils" - "github.com/golang/mock/gomock" "github.com/quic-go/qpack" + "go.uber.org/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" diff --git a/http3/response_writer.go b/http3/response_writer.go index 2262f16c7..8eb592a99 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -15,19 +15,61 @@ import ( "github.com/quic-go/qpack" ) +// The maximum length of an encoded HTTP/3 frame header is 16: +// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length) +const frameHeaderLen = 16 + +// headerWriter wraps the stream, so that the first Write call flushes the header to the stream +type headerWriter struct { + str quic.Stream + header http.Header + status int // status code passed to WriteHeader + written bool + + logger utils.Logger +} + +// writeHeader encodes and flush header to the stream +func (hw *headerWriter) writeHeader() error { + var headers bytes.Buffer + enc := qpack.NewEncoder(&headers) + enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}) + + for k, v := range hw.header { + for index := range v { + enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) + } + } + + buf := make([]byte, 0, frameHeaderLen+headers.Len()) + buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) + hw.logger.Infof("Responding with %d", hw.status) + buf = append(buf, headers.Bytes()...) + + _, err := hw.str.Write(buf) + return err +} + +// first Write will trigger flushing header +func (hw *headerWriter) Write(p []byte) (int, error) { + if !hw.written { + if err := hw.writeHeader(); err != nil { + return 0, err + } + hw.written = true + } + return hw.str.Write(p) +} + type responseWriter struct { + *headerWriter conn quic.Connection - str quic.Stream bufferedStr *bufio.Writer buf []byte - header http.Header - status int // status code passed to WriteHeader headerWritten bool contentLen int64 // if handler set valid Content-Length header numWritten int64 // bytes written - - logger utils.Logger } var ( @@ -37,13 +79,16 @@ var ( ) func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { + hw := &headerWriter{ + str: str, + header: http.Header{}, + logger: logger, + } return &responseWriter{ - header: http.Header{}, - buf: make([]byte, 16), - conn: conn, - str: str, - bufferedStr: bufio.NewWriter(str), - logger: logger, + headerWriter: hw, + buf: make([]byte, frameHeaderLen), + conn: conn, + bufferedStr: bufio.NewWriter(hw), } } @@ -83,27 +128,8 @@ func (w *responseWriter) WriteHeader(status int) { } w.status = status - var headers bytes.Buffer - enc := qpack.NewEncoder(&headers) - enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - - for k, v := range w.header { - for index := range v { - enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) - } - } - - w.buf = w.buf[:0] - w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf) - w.logger.Infof("Responding with %d", status) - if _, err := w.bufferedStr.Write(w.buf); err != nil { - w.logger.Errorf("could not write headers frame: %s", err.Error()) - } - if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil { - w.logger.Errorf("could not write header frame payload: %s", err.Error()) - } if !w.headerWritten { - w.Flush() + w.writeHeader() } } @@ -146,6 +172,15 @@ func (w *responseWriter) Write(p []byte) (int, error) { } func (w *responseWriter) FlushError() error { + if !w.headerWritten { + w.WriteHeader(http.StatusOK) + } + if !w.written { + if err := w.writeHeader(); err != nil { + return err + } + w.written = true + } return w.bufferedStr.Flush() } diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 8563ae8ab..88c18a0a2 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -9,11 +9,11 @@ import ( mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" "github.com/refraction-networking/uquic/internal/utils" - "github.com/golang/mock/gomock" "github.com/quic-go/qpack" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Response Writer", func() { diff --git a/http3/roundtrip.go b/http3/roundtrip.go index df8bc4094..3ef5ddd2a 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -52,7 +52,7 @@ type RoundTripper struct { // Enable support for HTTP/3 datagrams. // If set to true, QuicConfig.EnableDatagram will be set. - // See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html. + // See https://datatracker.ietf.org/doc/html/rfc9297. EnableDatagrams bool // Additional HTTP/3 settings. diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 279308fdd..e685aa618 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -14,9 +14,9 @@ import ( "github.com/refraction-networking/uquic/internal/qerr" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type mockBody struct { diff --git a/http3/server.go b/http3/server.go index b32e80edc..20a0323a9 100644 --- a/http3/server.go +++ b/http3/server.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "runtime" + "strconv" "strings" "sync" "time" @@ -33,12 +34,8 @@ var ( } ) -const ( - // NextProtoH3Draft29 is the ALPN protocol negotiated during the TLS handshake, for QUIC draft 29. - NextProtoH3Draft29 = "h3-29" - // NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. - NextProtoH3 = "h3" -) +// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. +const NextProtoH3 = "h3" // StreamType is the stream type of a unidirectional stream. type StreamType uint64 @@ -178,7 +175,7 @@ type Server struct { // EnableDatagrams enables support for HTTP/3 datagrams. // If set to true, QuicConfig.EnableDatagram will be set. - // See https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-07. + // See https://datatracker.ietf.org/doc/html/rfc9297. EnableDatagrams bool // MaxHeaderBytes controls the maximum number of bytes the server will @@ -651,7 +648,12 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q // only write response when there is no panic if !panicked { - r.WriteHeader(http.StatusOK) + // response not written to the client yet, set Content-Length + if !r.written { + if _, haveCL := r.header["Content-Length"]; !haveCL { + r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10)) + } + } r.Flush() } // If the EOF was read by the handler, CancelRead() is a no-op. diff --git a/http3/server_test.go b/http3/server_test.go index b33c0f59d..5ab313dea 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -21,8 +21,8 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/quicvarint" - "github.com/golang/mock/gomock" "github.com/quic-go/qpack" + "go.uber.org/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -181,6 +181,47 @@ var _ = Describe("Server", func() { Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) }) + It("sets Content-Length when the handler doesn't flush to the client", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + }) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) + // status, content-length, date, content-type + Expect(hfs).To(HaveLen(4)) + }) + + It("not sets Content-Length when the handler flushes to the client", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + // force flush + w.(http.Flusher).Flush() + }) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + // status, date, content-type + Expect(hfs).To(HaveLen(3)) + }) + It("handles a aborting handler", func() { s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic(http.ErrAbortHandler) diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 1336cd7c4..e24232c0b 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -39,8 +39,6 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -136,8 +134,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk= -github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= +github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -173,10 +171,11 @@ github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cb github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= +go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -195,7 +194,7 @@ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= @@ -217,7 +216,6 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= @@ -261,9 +259,7 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -309,7 +305,7 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 9def9361a..6ab814a82 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -201,6 +201,8 @@ var _ = Describe("Handshake tests", func() { Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) + var certErr *tls.CertificateVerificationError + Expect(errors.As(transportErr, &certErr)).To(BeTrue()) }) It("fails the handshake if the client fails to provide the requested client cert", func() { @@ -452,7 +454,7 @@ var _ = Describe("Handshake tests", func() { It("rejects invalid Retry token with the INVALID_TOKEN error", func() { serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - serverConfig.MaxRetryTokenAge = time.Nanosecond + serverConfig.MaxRetryTokenAge = -time.Second server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index d0474ad3f..698f3160c 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -42,7 +42,7 @@ var _ = Describe("HTTP tests", func() { rt *http3.RoundTripper server *http3.Server stoppedServing chan struct{} - port string + port int ) BeforeEach(func() { @@ -93,7 +93,7 @@ var _ = Describe("HTTP tests", func() { Expect(err).NotTo(HaveOccurred()) conn, err := net.ListenUDP("udp", addr) Expect(err).NotTo(HaveOccurred()) - port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port) + port = conn.LocalAddr().(*net.UDPAddr).Port stoppedServing = make(chan struct{}) @@ -120,7 +120,7 @@ var _ = Describe("HTTP tests", func() { }) It("downloads a hello", func() { - resp, err := client.Get("https://localhost:" + port + "/hello") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) @@ -128,13 +128,25 @@ var _ = Describe("HTTP tests", func() { Expect(string(body)).To(Equal("Hello, World!\n")) }) + It("sets content-length for small response", func() { + mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.Write([]byte("foobar")) + }) + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar")))) + }) + It("requests to different servers with the same udpconn", func() { - resp, err := client.Get("https://localhost:" + port + "/remoteAddr") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remoteAddr", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) addr1 := resp.Header.Get("X-RemoteAddr") Expect(addr1).ToNot(Equal("")) - resp, err = client.Get("https://127.0.0.1:" + port + "/remoteAddr") + resp, err = client.Get(fmt.Sprintf("https://127.0.0.1:%d/remoteAddr", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) addr2 := resp.Header.Get("X-RemoteAddr") @@ -146,7 +158,7 @@ var _ = Describe("HTTP tests", func() { group, ctx := errgroup.WithContext(context.Background()) for i := 0; i < 2; i++ { group.Go(func() error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://localhost:"+port+"/hello", nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/hello", port), nil) Expect(err).ToNot(HaveOccurred()) resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) @@ -172,7 +184,7 @@ var _ = Describe("HTTP tests", func() { close(handlerCalled) }) - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/headers/request", nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/headers/request", port), nil) Expect(err).ToNot(HaveOccurred()) req.Header.Set("foo", "bar") req.Header.Set("lorem", "ipsum") @@ -189,7 +201,7 @@ var _ = Describe("HTTP tests", func() { w.Header().Set("lorem", "ipsum") }) - resp, err := client.Get("https://localhost:" + port + "/headers/response") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/headers/response", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Header.Get("foo")).To(Equal("bar")) @@ -197,7 +209,7 @@ var _ = Describe("HTTP tests", func() { }) It("downloads a small file", func() { - resp, err := client.Get("https://localhost:" + port + "/prdata") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdata", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) @@ -206,7 +218,7 @@ var _ = Describe("HTTP tests", func() { }) It("downloads a large file", func() { - resp, err := client.Get("https://localhost:" + port + "/prdatalong") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdatalong", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second)) @@ -218,7 +230,7 @@ var _ = Describe("HTTP tests", func() { const num = 150 for i := 0; i < num; i++ { - resp, err := client.Get("https://localhost:" + port + "/hello") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) @@ -231,7 +243,7 @@ var _ = Describe("HTTP tests", func() { const num = 150 for i := 0; i < num; i++ { - resp, err := client.Get("https://localhost:" + port + "/prdata") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdata", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Body.Close()).To(Succeed()) @@ -240,7 +252,7 @@ var _ = Describe("HTTP tests", func() { It("posts a small message", func() { resp, err := client.Post( - "https://localhost:"+port+"/echo", + fmt.Sprintf("https://localhost:%d/echo", port), "text/plain", bytes.NewReader([]byte("Hello, world!")), ) @@ -253,7 +265,7 @@ var _ = Describe("HTTP tests", func() { It("uploads a file", func() { resp, err := client.Post( - "https://localhost:"+port+"/echo", + fmt.Sprintf("https://localhost:%d/echo", port), "text/plain", bytes.NewReader(PRData), ) @@ -277,7 +289,7 @@ var _ = Describe("HTTP tests", func() { }) client.Transport.(*http3.RoundTripper).DisableCompression = false - resp, err := client.Get("https://localhost:" + port + "/gzipped/hello") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/gzipped/hello", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Uncompressed).To(BeTrue()) @@ -303,7 +315,7 @@ var _ = Describe("HTTP tests", func() { } }) - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/cancel", nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel", port), nil) Expect(err).ToNot(HaveOccurred()) ctx, cancel := context.WithCancel(context.Background()) req = req.WithContext(ctx) @@ -336,7 +348,7 @@ var _ = Describe("HTTP tests", func() { }) r, w := io.Pipe() - req, err := http.NewRequest("PUT", "https://localhost:"+port+"/echoline", r) + req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("https://localhost:%d/echoline", port), r) Expect(err).ToNot(HaveOccurred()) rsp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) @@ -373,7 +385,7 @@ var _ = Describe("HTTP tests", func() { }() }) - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/httpstreamer", nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/httpstreamer", port), nil) Expect(err).ToNot(HaveOccurred()) rsp, err := client.Transport.(*http3.RoundTripper).RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true}) Expect(err).ToNot(HaveOccurred()) @@ -431,7 +443,11 @@ var _ = Describe("HTTP tests", func() { }) expectedEnd := time.Now().Add(deadlineDelay) - resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) + resp, err := client.Post( + fmt.Sprintf("https://localhost:%d/read-deadline", port), + "text/plain", + neverEnding('a'), + ) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) @@ -453,7 +469,7 @@ var _ = Describe("HTTP tests", func() { expectedEnd := time.Now().Add(deadlineDelay) - resp, err := client.Get("https://localhost:" + port + "/write-deadline") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/write-deadline", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 92a230257..b508f9aa7 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -2,9 +2,11 @@ package self_test import ( "context" + "crypto/rand" "io" "net" "runtime" + "sync/atomic" "time" . "github.com/onsi/ginkgo/v2" @@ -209,4 +211,67 @@ var _ = Describe("Multiplexing", func() { }) } }) + + It("sends and receives non-QUIC packets", func() { + addr1, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn1, err := net.ListenUDP("udp", addr1) + Expect(err).ToNot(HaveOccurred()) + defer conn1.Close() + tr1 := &quic.Transport{Conn: conn1} + + addr2, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn2, err := net.ListenUDP("udp", addr2) + Expect(err).ToNot(HaveOccurred()) + defer conn2.Close() + tr2 := &quic.Transport{Conn: conn2} + + server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + runServer(server) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var sentPackets, rcvdPackets atomic.Int64 + const packetLen = 128 + // send a non-QUIC packet every 100µs + go func() { + defer GinkgoRecover() + ticker := time.NewTicker(time.Millisecond / 10) + defer ticker.Stop() + for { + select { + case <-ticker.C: + case <-ctx.Done(): + return + } + b := make([]byte, packetLen) + rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet + _, err := tr1.WriteTo(b, tr2.Conn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + sentPackets.Add(1) + } + }() + + // receive and count non-QUIC packets + go func() { + defer GinkgoRecover() + for { + b := make([]byte, 1024) + n, addr, err := tr2.ReadNonQUICPacket(ctx, b) + if err != nil { + Expect(err).To(MatchError(context.Canceled)) + return + } + Expect(addr).To(Equal(tr1.Conn.LocalAddr())) + Expect(n).To(Equal(packetLen)) + rcvdPackets.Add(1) + } + }() + dial(tr2, server.Addr()) + Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10)) + Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5)) + }) }) diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index 3e827cb8c..d38ae9a74 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -56,7 +56,7 @@ var _ = Describe("TLS session resumption", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, - nil, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) var sessionKey string @@ -71,7 +71,7 @@ var _ = Describe("TLS session resumption", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, - nil, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive(Equal(sessionKey))) @@ -85,7 +85,7 @@ var _ = Describe("TLS session resumption", func() { It("doesn't use session resumption, if the config disables it", func() { sConf := getTLSConfig() sConf.SessionTicketsDisabled = true - server, err := quic.ListenAddr("localhost:0", sConf, nil) + server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) defer server.Close() @@ -98,7 +98,7 @@ var _ = Describe("TLS session resumption", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, - nil, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) Consistently(puts).ShouldNot(Receive()) @@ -114,7 +114,55 @@ var _ = Describe("TLS session resumption", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, - nil, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + + serverConn, err = server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) + }) + + It("doesn't use session resumption, if the config returned by GetConfigForClient disables it", func() { + sConf := &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + conf := getTLSConfig() + conf.SessionTicketsDisabled = true + return conf, nil + }, + } + + server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer server.Close() + + gets := make(chan string, 100) + puts := make(chan string, 100) + cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) + tlsConf := getTLSClientConfig() + tlsConf.ClientSessionCache = cache + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Consistently(puts).ShouldNot(Receive()) + Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + serverConn, err := server.Accept(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) + + conn, err = quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go index 919e99471..101c1ffdf 100644 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -802,4 +802,116 @@ var _ = Describe("0-RTT", func() { Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) + + It("sends 0-RTT datagrams", func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + EnableDatagrams: true, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + sentMessage := GeneratePRData(100) + var receivedMessage []byte + received := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(received) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + receivedMessage, err = conn.ReceiveMessage(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(sentMessage)).To(Succeed()) + <-conn.HandshakeComplete() + <-received + + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(receivedMessage).To(Equal(sentMessage)) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(zeroRTTPackets).To(HaveLen(1)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + EnableDatagrams: true, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: false, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + _, err = conn.ReceiveMessage(context.Background()) + Expect(err.Error()).To(Equal("datagram support disabled")) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + // the client can temporarily send datagrams but the server doesn't process them. + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(make([]byte, 100))).To(Succeed()) + <-conn.HandshakeComplete() + + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 5021e43f6..7fe819308 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -939,4 +939,119 @@ var _ = Describe("0-RTT", func() { ) Expect(restored).To(BeTrue()) }) + + It("sends 0-RTT datagrams", func() { + tlsConf := getTLSConfig() + clientTLSConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), clientTLSConf) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + sentMessage := GeneratePRData(100) + var receivedMessage []byte + received := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(received) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + receivedMessage, err = conn.ReceiveMessage(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(sentMessage)).To(Succeed()) + <-conn.HandshakeComplete() + <-received + + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(receivedMessage).To(Equal(sentMessage)) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(zeroRTTPackets).To(HaveLen(1)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { + tlsConf := getTLSConfig() + clientTLSConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), clientTLSConf) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: false, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + _, err = conn.ReceiveMessage(context.Background()) + Expect(err.Error()).To(Equal("datagram support disabled")) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + // the client can temporarily send datagrams but the server doesn't process them. + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(make([]byte, 100))).To(Succeed()) + <-conn.HandshakeComplete() + + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) }) diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index b89caabd3..bef42f289 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net" + "runtime" "runtime/pprof" "strconv" "strings" @@ -68,7 +69,11 @@ var _ = Describe("QUIC Proxy", func() { addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) _, err = net.ListenUDP("udp", addr) - Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort()))) + if runtime.GOOS == "windows" { + Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: Only one usage of each socket address (protocol/network address/port) is normally permitted.", proxy.LocalPort()))) + } else { + Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort()))) + } Expect(proxy.Close()).To(Succeed()) // stopping is tested in the next test }) diff --git a/internal/ackhandler/ackhandler_suite_test.go b/internal/ackhandler/ackhandler_suite_test.go index 069aaa791..a0cf3ee17 100644 --- a/internal/ackhandler/ackhandler_suite_test.go +++ b/internal/ackhandler/ackhandler_suite_test.go @@ -3,9 +3,9 @@ package ackhandler import ( "testing" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestCrypto(t *testing.T) { diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go index fecb32721..cd34cfdbc 100644 --- a/internal/ackhandler/mock_sent_packet_tracker_test.go +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -7,7 +7,7 @@ package ackhandler import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/ackhandler/mockgen.go b/internal/ackhandler/mockgen.go index b432a0a07..3d2f10823 100644 --- a/internal/ackhandler/mockgen.go +++ b/internal/ackhandler/mockgen.go @@ -2,5 +2,5 @@ package ackhandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketTracker" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketTracker" type SentPacketTracker = sentPacketTracker diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index d0377200b..a537399b4 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -3,14 +3,13 @@ package ackhandler import ( "time" - "github.com/golang/mock/gomock" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Received Packet Handler", func() { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 9b21d1fe9..15a09fa2c 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -4,8 +4,6 @@ import ( "fmt" "time" - "github.com/golang/mock/gomock" - "github.com/refraction-networking/uquic/internal/mocks" "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" @@ -14,6 +12,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type customFrameHandler struct { diff --git a/internal/flowcontrol/flowcontrol_suite_test.go b/internal/flowcontrol/flowcontrol_suite_test.go index 8831296d3..6cfe981d7 100644 --- a/internal/flowcontrol/flowcontrol_suite_test.go +++ b/internal/flowcontrol/flowcontrol_suite_test.go @@ -3,9 +3,9 @@ package flowcontrol import ( "testing" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestFlowControl(t *testing.T) { diff --git a/internal/handshake/cipher_suite.go b/internal/handshake/cipher_suite.go index 2ed922a88..368a95d48 100644 --- a/internal/handshake/cipher_suite.go +++ b/internal/handshake/cipher_suite.go @@ -31,7 +31,7 @@ func getCipherSuite(id uint16) *cipherSuite { case tls.TLS_CHACHA20_POLY1305_SHA256: return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305} case tls.TLS_AES_256_GCM_SHA384: - return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadAESGCMTLS13} + return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA384, KeyLen: 32, AEAD: aeadAESGCMTLS13} default: panic(fmt.Sprintf("unknown cypher suite: %d", id)) } diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 5630b8808..f2e6b67dd 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "strings" "sync" "sync/atomic" "time" @@ -358,10 +359,15 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte { // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - if h.tlsConf.SessionTicketsDisabled { - return nil, nil - } - if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{h.allow0RTT}); err != nil { + if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil { + // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. + // We can't check h.tlsConfig here, since the actual config might have been obtained from + // the GetConfigForClient callback. + // See https://github.com/golang/go/issues/62032. + // Once that issue is resolved, this error assertion can be removed. + if strings.Contains(err.Error(), "session ticket keys unavailable") { + return nil, nil + } return nil, err } ev := h.conn.NextEvent() @@ -660,8 +666,9 @@ func (h *cryptoSetup) ConnectionState() ConnectionState { } func wrapError(err error) error { + // alert 80 is an internal error if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { - return qerr.NewLocalCryptoError(uint8(alertErr), err.Error()) + return qerr.NewLocalCryptoError(uint8(alertErr), err) } return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()} } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index f29517cda..fbc82512d 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -18,10 +18,9 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) const ( diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index 45ff1c3d6..7d82421c0 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -7,10 +7,9 @@ import ( tls "github.com/refraction-networking/utls" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestHandshake(t *testing.T) { diff --git a/internal/handshake/session_ticket_test.go b/internal/handshake/session_ticket_test.go index a113f0245..83c64a7d8 100644 --- a/internal/handshake/session_ticket_test.go +++ b/internal/handshake/session_ticket_test.go @@ -17,6 +17,7 @@ var _ = Describe("Session Ticket", func() { InitialMaxStreamDataBidiLocal: 1, InitialMaxStreamDataBidiRemote: 2, ActiveConnectionIDLimit: 10, + MaxDatagramFrameSize: 20, }, RTT: 1337 * time.Microsecond, } @@ -25,6 +26,7 @@ var _ = Describe("Session Ticket", func() { Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10)) + Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20)) Expect(t.RTT).To(Equal(1337 * time.Microsecond)) }) diff --git a/internal/handshake/u_crypto_setup.go b/internal/handshake/u_crypto_setup.go index a05aa3388..92074ca8b 100644 --- a/internal/handshake/u_crypto_setup.go +++ b/internal/handshake/u_crypto_setup.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "strings" "sync" "sync/atomic" "time" @@ -285,10 +286,15 @@ func (h *uCryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transp // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *uCryptoSetup) GetSessionTicket() ([]byte, error) { - if h.tlsConf.SessionTicketsDisabled { - return nil, nil - } - if err := h.conn.SendSessionTicket(h.allow0RTT); err != nil { + if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil { + // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. + // We can't check h.tlsConfig here, since the actual config might have been obtained from + // the GetConfigForClient callback. + // See https://github.com/golang/go/issues/62032. + // Once that issue is resolved, this error assertion can be removed. + if strings.Contains(err.Error(), "session ticket keys unavailable") { + return nil, nil + } return nil, err } ev := h.conn.NextEvent() diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 4072d6dd6..4bc1fa86b 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -13,9 +13,9 @@ import ( "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/utils" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Updatable AEAD", func() { diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index a368b3648..569b690ef 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -8,9 +8,9 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" + gomock "go.uber.org/mock/gomock" ) // MockReceivedPacketHandler is a mock of ReceivedPacketHandler interface. diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 76db968ca..37fae0e6a 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go index c81b7177b..a2ba0aabc 100644 --- a/internal/mocks/congestion.go +++ b/internal/mocks/congestion.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go index 02d3c3a17..60d87d39c 100644 --- a/internal/mocks/connection_flow_controller.go +++ b/internal/mocks/connection_flow_controller.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 61a019a6e..5558f9730 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" handshake "github.com/refraction-networking/uquic/internal/handshake" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index e03a58976..4a9d1fd70 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -9,11 +9,12 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" utils "github.com/refraction-networking/uquic/internal/utils" wire "github.com/refraction-networking/uquic/internal/wire" logging "github.com/refraction-networking/uquic/logging" + + gomock "go.uber.org/mock/gomock" ) // MockConnectionTracer is a mock of ConnectionTracer interface. diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index 42f81660e..b479c66c1 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -8,10 +8,10 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" logging "github.com/refraction-networking/uquic/logging" + gomock "go.uber.org/mock/gomock" ) // MockTracer is a mock of Tracer interface. diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go index 61d55f270..9552aff7e 100644 --- a/internal/mocks/long_header_opener.go +++ b/internal/mocks/long_header_opener.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 6f68d26c7..58ed55786 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,19 +1,19 @@ package mocks -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocklogging -destination logging/tracer.go github.com/refraction-networking/uquic/logging Tracer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocklogging -destination logging/connection_tracer.go github.com/refraction-networking/uquic/logging ConnectionTracer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/refraction-networking/uquic/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/tracer.go github.com/refraction-networking/uquic/logging Tracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/connection_tracer.go github.com/refraction-networking/uquic/logging ConnectionTracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/refraction-networking/uquic/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler" // The following command produces a warning message on OSX, however, it still generates the correct mock file. // See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index 816ca5a35..428c0eb17 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -9,9 +9,9 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" quic "github.com/refraction-networking/uquic" qerr "github.com/refraction-networking/uquic/internal/qerr" + gomock "go.uber.org/mock/gomock" ) // MockEarlyConnection is a mock of EarlyConnection interface. diff --git a/internal/mocks/quic/stream.go b/internal/mocks/quic/stream.go index c548883bb..97c52c526 100644 --- a/internal/mocks/quic/stream.go +++ b/internal/mocks/quic/stream.go @@ -9,9 +9,9 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" + gomock "go.uber.org/mock/gomock" ) // MockStream is a mock of Stream interface. diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go index c4ab76deb..4f609b937 100644 --- a/internal/mocks/short_header_opener.go +++ b/internal/mocks/short_header_opener.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go index 90f4b80f0..dfb654bfe 100644 --- a/internal/mocks/short_header_sealer.go +++ b/internal/mocks/short_header_sealer.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go index ab6dcade1..3a16e9375 100644 --- a/internal/mocks/stream_flow_controller.go +++ b/internal/mocks/stream_flow_controller.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/tls/client_session_cache.go b/internal/mocks/tls/client_session_cache.go index 5178c62ab..9a7c34440 100644 --- a/internal/mocks/tls/client_session_cache.go +++ b/internal/mocks/tls/client_session_cache.go @@ -8,7 +8,7 @@ import ( tls "github.com/refraction-networking/utls" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockClientSessionCache is a mock of ClientSessionCache interface. diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 98bc9ffbc..b8104882b 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -43,6 +43,21 @@ const ( ECNCE // 11 ) +func (e ECN) String() string { + switch e { + case ECNNon: + return "Not-ECT" + case ECT1: + return "ECT(1)" + case ECT0: + return "ECT(0)" + case ECNCE: + return "CE" + default: + return fmt.Sprintf("invalid ECN value: %d", e) + } +} + // A ByteCount in QUIC type ByteCount int64 diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index d9048f010..e672d31e1 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -22,4 +22,12 @@ var _ = Describe("Protocol", func() { Expect(ECN(0b00000001)).To(Equal(ECT1)) Expect(ECN(0b00000011)).To(Equal(ECNCE)) }) + + It("has a string representation for ECN", func() { + Expect(ECNNon.String()).To(Equal("Not-ECT")) + Expect(ECT0.String()).To(Equal("ECT(0)")) + Expect(ECT1.String()).To(Equal("ECT(1)")) + Expect(ECNCE.String()).To(Equal("CE")) + Expect(ECN(42).String()).To(Equal("invalid ECN value: 42")) + }) }) diff --git a/internal/qerr/errors.go b/internal/qerr/errors.go index 4a053ae0d..c3d8465b9 100644 --- a/internal/qerr/errors.go +++ b/internal/qerr/errors.go @@ -17,15 +17,16 @@ type TransportError struct { FrameType uint64 ErrorCode TransportErrorCode ErrorMessage string + error error // only set for local errors, sometimes } var _ error = &TransportError{} // NewLocalCryptoError create a new TransportError instance for a crypto error -func NewLocalCryptoError(tlsAlert uint8, errorMessage string) *TransportError { +func NewLocalCryptoError(tlsAlert uint8, err error) *TransportError { return &TransportError{ - ErrorCode: 0x100 + TransportErrorCode(tlsAlert), - ErrorMessage: errorMessage, + ErrorCode: 0x100 + TransportErrorCode(tlsAlert), + error: err, } } @@ -35,6 +36,9 @@ func (e *TransportError) Error() string { str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) } msg := e.ErrorMessage + if len(msg) == 0 && e.error != nil { + msg = e.error.Error() + } if len(msg) == 0 { msg = e.ErrorCode.Message() } @@ -48,6 +52,10 @@ func (e *TransportError) Is(target error) bool { return target == net.ErrClosed } +func (e *TransportError) Unwrap() error { + return e.error +} + // An ApplicationErrorCode is an application-defined error code. type ApplicationErrorCode uint64 diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go index cfad61566..7bfbc7345 100644 --- a/internal/qerr/errors_test.go +++ b/internal/qerr/errors_test.go @@ -2,6 +2,7 @@ package qerr import ( "errors" + "fmt" "net" "github.com/refraction-networking/uquic/internal/protocol" @@ -10,6 +11,12 @@ import ( . "github.com/onsi/gomega" ) +type myError int + +var _ error = myError(0) + +func (e myError) Error() string { return fmt.Sprintf("my error %d", e) } + var _ = Describe("QUIC Errors", func() { Context("Transport Errors", func() { It("has a string representation", func() { @@ -41,12 +48,20 @@ var _ = Describe("QUIC Errors", func() { Context("crypto errors", func() { It("has a string representation for errors with a message", func() { - err := NewLocalCryptoError(0x42, "foobar") - Expect(err.Error()).To(Equal("CRYPTO_ERROR 0x142 (local): foobar")) + myErr := myError(1337) + err := NewLocalCryptoError(0x42, myErr) + Expect(err.Error()).To(Equal("CRYPTO_ERROR 0x142 (local): my error 1337")) + }) + + It("unwraps errors", func() { + var myErr myError + err := NewLocalCryptoError(0x42, myError(1337)) + Expect(errors.As(err, &myErr)).To(BeTrue()) + Expect(myErr).To(BeEquivalentTo(1337)) }) It("has a string representation for errors without a message", func() { - err := NewLocalCryptoError(0x2a, "") + err := NewLocalCryptoError(0x2a, nil) Expect(err.Error()).To(Equal("CRYPTO_ERROR 0x12a (local): tls: bad certificate")) }) }) diff --git a/internal/qtls/cipher_suite.go b/internal/qtls/cipher_suite.go new file mode 100644 index 000000000..07891c141 --- /dev/null +++ b/internal/qtls/cipher_suite.go @@ -0,0 +1,65 @@ +package qtls + +import ( + "crypto" + "crypto/cipher" + "fmt" + "unsafe" + + tls "github.com/refraction-networking/utls" +) + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13 +var cipherSuitesTLS13 []unsafe.Pointer + +//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13 +var defaultCipherSuitesTLS13 []uint16 + +//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES +var defaultCipherSuitesTLS13NoAES []uint16 + +var cipherSuitesModified bool + +// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls +// such that it only contains the cipher suite with the chosen id. +// The reset function returned resets them back to the original value. +func SetCipherSuite(id uint16) (reset func()) { + if cipherSuitesModified { + panic("cipher suites modified multiple times without resetting") + } + cipherSuitesModified = true + + origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) + origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) + origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) + // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. + switch id { + case tls.TLS_AES_128_GCM_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[:1] + case tls.TLS_CHACHA20_POLY1305_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[1:2] + case tls.TLS_AES_256_GCM_SHA384: + cipherSuitesTLS13 = cipherSuitesTLS13[2:] + default: + panic(fmt.Sprintf("unexpected cipher suite: %d", id)) + } + defaultCipherSuitesTLS13 = []uint16{id} + defaultCipherSuitesTLS13NoAES = []uint16{id} + + return func() { + cipherSuitesTLS13 = origCipherSuitesTLS13 + defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 + defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES + cipherSuitesModified = false + } +} diff --git a/internal/qtls/qtls_suite_test.go b/internal/qtls/qtls_suite_test.go index e8ce652a1..bde81e6ce 100644 --- a/internal/qtls/qtls_suite_test.go +++ b/internal/qtls/qtls_suite_test.go @@ -3,10 +3,9 @@ package qtls import ( "testing" - gomock "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestQTLS(t *testing.T) { diff --git a/internal/qtls/utls.go b/internal/qtls/utls.go index a056c7417..149c6cc3c 100644 --- a/internal/qtls/utls.go +++ b/internal/qtls/utls.go @@ -10,13 +10,14 @@ import ( ) type ( - QUICConn = tls.QUICConn - UQUICConn = tls.UQUICConn // [UQUIC] - QUICConfig = tls.QUICConfig - QUICEvent = tls.QUICEvent - QUICEventKind = tls.QUICEventKind - QUICEncryptionLevel = tls.QUICEncryptionLevel - AlertError = tls.AlertError + QUICConn = tls.QUICConn + UQUICConn = tls.UQUICConn // [UQUIC] + QUICConfig = tls.QUICConfig + QUICEvent = tls.QUICEvent + QUICEventKind = tls.QUICEventKind + QUICEncryptionLevel = tls.QUICEncryptionLevel + QUICSessionTicketOptions = tls.QUICSessionTicketOptions + AlertError = tls.AlertError ) const ( @@ -166,3 +167,14 @@ func findExtraData(extras [][]byte) []byte { } return nil } + +type QUICConnOrUQUICConn interface { + *QUICConn | *UQUICConn + SendSessionTicket(opts QUICSessionTicketOptions) error +} + +func SendSessionTicket[C QUICConnOrUQUICConn](c C, allow0RTT bool) error { + return c.SendSessionTicket(tls.QUICSessionTicketOptions{ + EarlyData: allow0RTT, + }) +} diff --git a/internal/wire/header.go b/internal/wire/header.go index 4060d5987..f6025019f 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -74,6 +74,10 @@ func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.Arbitra return destConnID, srcConnID, nil } +func IsPotentialQUICPacket(firstByte byte) bool { + return firstByte&0x40 > 0 +} + // IsLongHeaderPacket says if this is a Long Header packet func IsLongHeaderPacket(firstByte byte) bool { return firstByte&0x80 > 0 diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 48ae09234..faa6eaf37 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -503,6 +503,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), ActiveConnectionIDLimit: 2 + getRandomValueUpTo(math.MaxInt64-2), + MaxDatagramFrameSize: protocol.ByteCount(getRandomValueUpTo(int64(protocol.MaxDatagramFrameSize))), } Expect(params.ValidFor0RTT(params)).To(BeTrue()) b := params.MarshalForSessionTicket(nil) @@ -515,6 +516,7 @@ var _ = Describe("Transport Parameters", func() { Expect(tp.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) Expect(tp.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) Expect(tp.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) + Expect(tp.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize)) }) It("rejects the parameters if it can't parse them", func() { @@ -540,6 +542,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 5, MaxUniStreamNum: 6, ActiveConnectionIDLimit: 7, + MaxDatagramFrameSize: 1000, } BeforeEach(func() { @@ -611,6 +614,16 @@ var _ = Describe("Transport Parameters", func() { p.ActiveConnectionIDLimit = 0 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) + + It("accepts the parameters if the MaxDatagramFrameSize was increased", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the MaxDatagramFrameSize reduced", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) }) Context("client checks the parameters after successfully sending 0-RTT data", func() { @@ -623,6 +636,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 5, MaxUniStreamNum: 6, ActiveConnectionIDLimit: 7, + MaxDatagramFrameSize: 1000, } BeforeEach(func() { @@ -699,6 +713,16 @@ var _ = Describe("Transport Parameters", func() { p.ActiveConnectionIDLimit = saved.ActiveConnectionIDLimit + 1 Expect(p.ValidForUpdate(saved)).To(BeTrue()) }) + + It("rejects the parameters if the MaxDatagramFrameSize reduced", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize - 1 + Expect(p.ValidForUpdate(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the MaxDatagramFrameSize increased", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize + 1 + Expect(p.ValidForUpdate(saved)).To(BeTrue()) + }) }) }) }) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 7a24e2cec..f1562b36d 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -450,6 +450,12 @@ func (p *TransportParameters) marshalVarintParam(b []byte, id transportParameter // Since the session ticket is encrypted, the serialization format is defined by the server. // For convenience, we use the same format that we also use for sending the transport parameters. func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte { + // [UQUIC] + if p.ClientOverride != nil { + return p.ClientOverride.Marshal() // TODO: does this work as expected? test needed + } + // [/UQUIC] + b = quicvarint.Append(b, transportParameterMarshalingVersion) // initial_max_stream_data_bidi_local @@ -464,6 +470,10 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte { b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // max_datagram_frame_size + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + } // active_connection_id_limit return p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) } @@ -482,6 +492,9 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error // ValidFor0RTT checks if the transport parameters match those saved in the session ticket. func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { + if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) { + return false + } return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && @@ -494,6 +507,9 @@ func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { // ValidForUpdate checks that the new transport parameters don't reduce limits after resuming a 0-RTT connection. // It is only used on the client side. func (p *TransportParameters) ValidForUpdate(saved *TransportParameters) bool { + if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) { + return false + } return p.ActiveConnectionIDLimit >= saved.ActiveConnectionIDLimit && p.InitialMaxData >= saved.InitialMaxData && p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index cbaa2cb52..3659d1031 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -40,7 +40,10 @@ func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnec buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) r := make([]byte, 1) _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. - buf.WriteByte(r[0] | 0x80) + // Setting the "QUIC bit" (0x40) is not required by the RFC, + // but it allows clients to demultiplex QUIC with a long list of other protocols. + // See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details. + buf.WriteByte(r[0] | 0xc0) utils.BigEndian.WriteUint32(buf, 0) // version 0 buf.WriteByte(uint8(destConnID.Len())) buf.Write(destConnID.Bytes()) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 529a3234f..e47d2c83f 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -64,6 +64,7 @@ var _ = Describe("Version Negotiation Packets", func() { versions := []protocol.VersionNumber{1001, 1003} data := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(IsLongHeaderPacket(data[0])).To(BeTrue()) + Expect(data[0] & 0x40).ToNot(BeZero()) v, err := ParseVersion(data) Expect(err).ToNot(HaveOccurred()) Expect(v).To(BeZero()) diff --git a/logging/logging_suite_test.go b/logging/logging_suite_test.go index d37ada480..d808adfec 100644 --- a/logging/logging_suite_test.go +++ b/logging/logging_suite_test.go @@ -3,10 +3,9 @@ package logging import ( "testing" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestLogging(t *testing.T) { diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go index 5fad29cee..e5dd38c2b 100644 --- a/logging/mock_connection_tracer_test.go +++ b/logging/mock_connection_tracer_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" utils "github.com/refraction-networking/uquic/internal/utils" wire "github.com/refraction-networking/uquic/internal/wire" diff --git a/logging/mock_tracer_test.go b/logging/mock_tracer_test.go index cb13831dd..d295c96d6 100644 --- a/logging/mock_tracer_test.go +++ b/logging/mock_tracer_test.go @@ -8,7 +8,7 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/logging/mockgen.go b/logging/mockgen.go index 8b7fab64a..66f712bcf 100644 --- a/logging/mockgen.go +++ b/logging/mockgen.go @@ -1,4 +1,4 @@ package logging -//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_connection_tracer_test.go github.com/refraction-networking/uquic/logging ConnectionTracer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_tracer_test.go github.com/refraction-networking/uquic/logging Tracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_connection_tracer_test.go github.com/refraction-networking/uquic/logging ConnectionTracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_tracer_test.go github.com/refraction-networking/uquic/logging Tracer" diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go index ab70a5227..4ebc2d425 100644 --- a/mock_ack_frame_source_test.go +++ b/mock_ack_frame_source_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mock_batch_conn_test.go b/mock_batch_conn_test.go index fcb23e34c..9621e7b4e 100644 --- a/mock_batch_conn_test.go +++ b/mock_batch_conn_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ipv4 "golang.org/x/net/ipv4" ) diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index ea9e8f566..d857fd580 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_crypto_data_handler_test.go b/mock_crypto_data_handler_test.go index 9bc0e86f4..eb2e90fa1 100644 --- a/mock_crypto_data_handler_test.go +++ b/mock_crypto_data_handler_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" handshake "github.com/refraction-networking/uquic/internal/handshake" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 737bc08e7..002b1d17c 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index 622cb78b1..7319739cf 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_mtu_discoverer_test.go b/mock_mtu_discoverer_test.go index 3ebe8f4ec..a1c15fb1c 100644 --- a/mock_mtu_discoverer_test.go +++ b/mock_mtu_discoverer_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_packer_test.go b/mock_packer_test.go index fe966fb9d..409a7a8ca 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 7b5c9a96b..b9e28eebe 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index 823d23585..8402d5e46 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_packetconn_test.go b/mock_packetconn_test.go index d6731e4ac..c8e20bf28 100644 --- a/mock_packetconn_test.go +++ b/mock_packetconn_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockPacketConn is a mock of PacketConn interface. diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index 30a291998..7ce24333d 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -9,7 +9,7 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" ) diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go new file mode 100644 index 000000000..0a1a0f3ac --- /dev/null +++ b/mock_raw_conn_test.go @@ -0,0 +1,122 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go (interfaces: RawConn) + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "go.uber.org/mock/gomock" +) + +// MockRawConn is a mock of RawConn interface. +type MockRawConn struct { + ctrl *gomock.Controller + recorder *MockRawConnMockRecorder +} + +// MockRawConnMockRecorder is the mock recorder for MockRawConn. +type MockRawConnMockRecorder struct { + mock *MockRawConn +} + +// NewMockRawConn creates a new mock instance. +func NewMockRawConn(ctrl *gomock.Controller) *MockRawConn { + mock := &MockRawConn{ctrl: ctrl} + mock.recorder = &MockRawConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRawConn) EXPECT() *MockRawConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRawConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRawConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRawConn)(nil).Close)) +} + +// LocalAddr mocks base method. +func (m *MockRawConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockRawConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockRawConn)(nil).LocalAddr)) +} + +// ReadPacket mocks base method. +func (m *MockRawConn) ReadPacket() (receivedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadPacket") + ret0, _ := ret[0].(receivedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadPacket indicates an expected call of ReadPacket. +func (mr *MockRawConnMockRecorder) ReadPacket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadPacket", reflect.TypeOf((*MockRawConn)(nil).ReadPacket)) +} + +// SetReadDeadline mocks base method. +func (m *MockRawConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockRawConn)(nil).SetReadDeadline), arg0) +} + +// WritePacket mocks base method. +func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WritePacket indicates an expected call of WritePacket. +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2) +} + +// capabilities mocks base method. +func (m *MockRawConn) capabilities() connCapabilities { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "capabilities") + ret0, _ := ret[0].(connCapabilities) + return ret0 +} + +// capabilities indicates an expected call of capabilities. +func (mr *MockRawConnMockRecorder) capabilities() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockRawConn)(nil).capabilities)) +} diff --git a/mock_receive_stream_internal_test.go b/mock_receive_stream_internal_test.go index 72f9757e4..6237feef5 100644 --- a/mock_receive_stream_internal_test.go +++ b/mock_receive_stream_internal_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" wire "github.com/refraction-networking/uquic/internal/wire" diff --git a/mock_sealing_manager_test.go b/mock_sealing_manager_test.go index a87509059..1b9c9214c 100644 --- a/mock_sealing_manager_test.go +++ b/mock_sealing_manager_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" handshake "github.com/refraction-networking/uquic/internal/handshake" ) diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index e63df48f8..f55feaeb2 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -8,7 +8,7 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go index 2332e5efd..46bf75659 100644 --- a/mock_send_stream_internal_test.go +++ b/mock_send_stream_internal_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" diff --git a/mock_sender_test.go b/mock_sender_test.go index bb3fa128b..4c9899374 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_stream_getter_test.go b/mock_stream_getter_test.go index 6f8590573..3ed365151 100644 --- a/mock_stream_getter_test.go +++ b/mock_stream_getter_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 7acb402cb..369f43be1 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 53294be42..20e69a837 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -8,7 +8,7 @@ import ( context "context" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index c23ad3a9e..2c0608568 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mock_token_store_test.go b/mock_token_store_test.go index 72d85962d..c5d286431 100644 --- a/mock_token_store_test.go +++ b/mock_token_store_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockTokenStore is a mock of TokenStore interface. diff --git a/mock_unknown_packet_handler_test.go b/mock_unknown_packet_handler_test.go index 09c115696..fa702529c 100644 --- a/mock_unknown_packet_handler_test.go +++ b/mock_unknown_packet_handler_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockUnknownPacketHandler is a mock of UnknownPacketHandler interface. diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index a0d3dc604..89dbf9c87 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mockgen.go b/mockgen.go index de20437ad..5adfbf591 100644 --- a/mockgen.go +++ b/mockgen.go @@ -2,73 +2,76 @@ package quic -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_conn_test.go github.com/refraction-networking/uquic SendConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_conn_test.go github.com/refraction-networking/uquic SendConn" type SendConn = sendConn -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sender_test.go github.com/refraction-networking/uquic Sender" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_raw_conn_test.go github.com/refraction-networking/uquic RawConn" +type RawConn = rawConn + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sender_test.go github.com/refraction-networking/uquic Sender" type Sender = sender -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_internal_test.go github.com/refraction-networking/uquic StreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_internal_test.go github.com/refraction-networking/uquic StreamI" type StreamI = streamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_stream_test.go github.com/refraction-networking/uquic CryptoStream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_stream_test.go github.com/refraction-networking/uquic CryptoStream" type CryptoStream = cryptoStream -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_receive_stream_internal_test.go github.com/refraction-networking/uquic ReceiveStreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_receive_stream_internal_test.go github.com/refraction-networking/uquic ReceiveStreamI" type ReceiveStreamI = receiveStreamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_stream_internal_test.go github.com/refraction-networking/uquic SendStreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_stream_internal_test.go github.com/refraction-networking/uquic SendStreamI" type SendStreamI = sendStreamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_getter_test.go github.com/refraction-networking/uquic StreamGetter" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_getter_test.go github.com/refraction-networking/uquic StreamGetter" type StreamGetter = streamGetter -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_sender_test.go github.com/refraction-networking/uquic StreamSender" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_sender_test.go github.com/refraction-networking/uquic StreamSender" type StreamSender = streamSender -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_data_handler_test.go github.com/refraction-networking/uquic CryptoDataHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_data_handler_test.go github.com/refraction-networking/uquic CryptoDataHandler" type CryptoDataHandler = cryptoDataHandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_frame_source_test.go github.com/refraction-networking/uquic FrameSource" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_frame_source_test.go github.com/refraction-networking/uquic FrameSource" type FrameSource = frameSource -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_ack_frame_source_test.go github.com/refraction-networking/uquic AckFrameSource" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_ack_frame_source_test.go github.com/refraction-networking/uquic AckFrameSource" type AckFrameSource = ackFrameSource -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_manager_test.go github.com/refraction-networking/uquic StreamManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_manager_test.go github.com/refraction-networking/uquic StreamManager" type StreamManager = streamManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sealing_manager_test.go github.com/refraction-networking/uquic SealingManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sealing_manager_test.go github.com/refraction-networking/uquic SealingManager" type SealingManager = sealingManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unpacker_test.go github.com/refraction-networking/uquic Unpacker" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unpacker_test.go github.com/refraction-networking/uquic Unpacker" type Unpacker = unpacker -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packer_test.go github.com/refraction-networking/uquic Packer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packer_test.go github.com/refraction-networking/uquic Packer" type Packer = packer -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_mtu_discoverer_test.go github.com/refraction-networking/uquic MTUDiscoverer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_mtu_discoverer_test.go github.com/refraction-networking/uquic MTUDiscoverer" type MTUDiscoverer = mtuDiscoverer -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_conn_runner_test.go github.com/refraction-networking/uquic ConnRunner" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_conn_runner_test.go github.com/refraction-networking/uquic ConnRunner" type ConnRunner = connRunner -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_quic_conn_test.go github.com/refraction-networking/uquic QUICConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_quic_conn_test.go github.com/refraction-networking/uquic QUICConn" type QUICConn = quicConn -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_test.go github.com/refraction-networking/uquic PacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_test.go github.com/refraction-networking/uquic PacketHandler" type PacketHandler = packetHandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unknown_packet_handler_test.go github.com/refraction-networking/uquic UnknownPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unknown_packet_handler_test.go github.com/refraction-networking/uquic UnknownPacketHandler" type UnknownPacketHandler = unknownPacketHandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_manager_test.go github.com/refraction-networking/uquic PacketHandlerManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_manager_test.go github.com/refraction-networking/uquic PacketHandlerManager" type PacketHandlerManager = packetHandlerManager // Need to use source mode for the batchConn, since reflect mode follows type aliases. // See https://github.com/golang/mock/issues/244 for details. // -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_token_store_test.go github.com/refraction-networking/uquic TokenStore" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_packetconn_test.go net PacketConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_token_store_test.go github.com/refraction-networking/uquic TokenStore" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_packetconn_test.go net PacketConn" diff --git a/packet_handler_map.go b/packet_handler_map.go index a99a8562b..41287c2aa 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -26,9 +26,9 @@ type connCapabilities struct { // rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (receivedPacket, error) - // The size parameter is used for GSO. - // If GSO is not support, len(b) must be equal to size. - WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error) + // WritePacket writes a packet on the wire. + // If GSO is enabled, it's the caller's responsibility to set the correct control message. + WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer diff --git a/packet_packer_test.go b/packet_packer_test.go index abf59f32a..74e05a67a 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "fmt" "net" "time" @@ -17,10 +18,9 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Packet packer", func() { @@ -334,7 +334,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - quicErr := qerr.NewLocalCryptoError(0x42, "crypto error") + quicErr := qerr.NewLocalCryptoError(0x42, errors.New("crypto error")) quicErr.FrameType = 0x1234 p, err := packer.PackConnectionClose(quicErr, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 8ca1c8b40..9584df4c6 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -10,9 +10,9 @@ import ( "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Packet Unpacker", func() { diff --git a/quic_suite_test.go b/quic_suite_test.go index d979d81bc..954ca60b9 100644 --- a/quic_suite_test.go +++ b/quic_suite_test.go @@ -9,9 +9,9 @@ import ( "sync" "testing" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestQuicGo(t *testing.T) { diff --git a/receive_stream_test.go b/receive_stream_test.go index 9cdbbd72f..23396e1f3 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -12,10 +12,10 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/onsi/gomega/gbytes" + "go.uber.org/mock/gomock" ) var _ = Describe("Receive Stream", func() { diff --git a/send_conn.go b/send_conn.go index 63636e786..a3feaf628 100644 --- a/send_conn.go +++ b/send_conn.go @@ -1,10 +1,12 @@ package quic import ( + "fmt" "math" "net" "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/utils" ) // A sendConn allows sending using a simple Write() on a non-connected packet conn. @@ -20,61 +22,86 @@ type sendConn interface { type sconn struct { rawConn + localAddr net.Addr remoteAddr net.Addr - info packetInfo - oob []byte + + logger utils.Logger + + info packetInfo + oob []byte + // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. + gotGSOError bool } var _ sendConn = &sconn{} -func newSendConn(c rawConn, remote net.Addr) *sconn { - sc := &sconn{ - rawConn: c, - remoteAddr: remote, - } - if c.capabilities().GSO { - // add 32 bytes, so we can add the UDP_SEGMENT msg - sc.oob = make([]byte, 0, 32) +func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logger) *sconn { + localAddr := c.LocalAddr() + if info.addr.IsValid() { + if udpAddr, ok := localAddr.(*net.UDPAddr); ok { + addrCopy := *udpAddr + addrCopy.IP = info.addr.AsSlice() + localAddr = &addrCopy + } } - return sc -} -func newSendConnWithPacketInfo(c rawConn, remote net.Addr, info packetInfo) *sconn { oob := info.OOB() - if c.capabilities().GSO { - // add 32 bytes, so we can add the UDP_SEGMENT msg - l := len(oob) - oob = append(oob, make([]byte, 32)...) - oob = oob[:l] - } + // add 32 bytes, so we can add the UDP_SEGMENT msg + l := len(oob) + oob = append(oob, make([]byte, 32)...) + oob = oob[:l] return &sconn{ rawConn: c, + localAddr: localAddr, remoteAddr: remote, info: info, oob: oob, + logger: logger, } } func (c *sconn) Write(p []byte, size protocol.ByteCount) error { + if !c.capabilities().GSO { + if protocol.ByteCount(len(p)) != size { + panic(fmt.Sprintf("inconsistent packet size (%d vs %d)", len(p), size)) + } + _, err := c.WritePacket(p, c.remoteAddr, c.oob) + return err + } + // GSO is supported. Append the control message and send. if size > math.MaxUint16 { panic("size overflow") } - _, err := c.WritePacket(p, uint16(size), c.remoteAddr, c.oob) + _, err := c.WritePacket(p, c.remoteAddr, appendUDPSegmentSizeMsg(c.oob, uint16(size))) + if err != nil && isGSOError(err) { + // disable GSO for future calls + c.gotGSOError = true + if c.logger.Debug() { + c.logger.Debugf("GSO failed when sending to %s", c.remoteAddr) + } + // send out the packets one by one + for len(p) > 0 { + l := len(p) + if l > int(size) { + l = int(size) + } + if _, err := c.WritePacket(p[:l], c.remoteAddr, c.oob); err != nil { + return err + } + p = p[l:] + } + return nil + } return err } -func (c *sconn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *sconn) LocalAddr() net.Addr { - addr := c.rawConn.LocalAddr() - if c.info.addr.IsValid() { - if udpAddr, ok := addr.(*net.UDPAddr); ok { - addrCopy := *udpAddr - addrCopy.IP = c.info.addr.AsSlice() - addr = &addrCopy - } +func (c *sconn) capabilities() connCapabilities { + capabilities := c.rawConn.capabilities() + if capabilities.GSO { + capabilities.GSO = !c.gotGSOError } - return addr + return capabilities } + +func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *sconn) LocalAddr() net.Addr { return c.localAddr } diff --git a/send_conn_test.go b/send_conn_test.go index 56fe92366..382c0ccaf 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -2,46 +2,81 @@ package quic import ( "net" + "net/netip" + + "github.com/refraction-networking/uquic/internal/utils" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) -var _ = Describe("Connection (for sending packets)", func() { - var ( - c sendConn - packetConn *MockPacketConn - addr net.Addr - ) - - BeforeEach(func() { - addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = NewMockPacketConn(mockCtrl) - rawConn, err := wrapConn(packetConn) - Expect(err).ToNot(HaveOccurred()) - c = newSendConnWithPacketInfo(rawConn, addr, packetInfo{}) - }) +// Only if appendUDPSegmentSizeMsg actually appends a message (and isn't only a stub implementation), +// GSO is actually supported on this platform. +var platformSupportsGSO = len(appendUDPSegmentSizeMsg([]byte{}, 1337)) > 0 - It("writes", func() { - packetConn.EXPECT().WriteTo([]byte("foobar"), addr) - Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) - }) +var _ = Describe("Connection (for sending packets)", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - It("gets the remote address", func() { + It("gets the local and remote addresses", func() { + localAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1234} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr().Return(localAddr) + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + Expect(c.LocalAddr().String()).To(Equal("192.168.0.1:1234")) Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337")) }) - It("gets the local address", func() { - addr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 0, 1), - Port: 1234, - } - packetConn.EXPECT().LocalAddr().Return(addr) - Expect(c.LocalAddr()).To(Equal(addr)) + It("uses the local address from the packet info", func() { + localAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1234} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr().Return(localAddr) + c := newSendConn(rawConn, remoteAddr, packetInfo{addr: netip.AddrFrom4([4]byte{127, 0, 0, 42})}, utils.DefaultLogger) + Expect(c.LocalAddr().String()).To(Equal("127.0.0.42:1234")) }) - It("closes", func() { - packetConn.EXPECT().Close() - Expect(c.Close()).To(Succeed()) - }) + if platformSupportsGSO { + It("writes with GSO", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).Do(func(_ []byte, _ net.Addr, oob []byte) { + msg := appendUDPSegmentSizeMsg([]byte{}, 3) + Expect(oob).To(Equal(msg)) + }) + Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + }) + + It("disables GSO if writing fails", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + Expect(c.capabilities().GSO).To(BeTrue()) + gomock.InOrder( + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).DoAndReturn(func(_ []byte, _ net.Addr, oob []byte) (int, error) { + msg := appendUDPSegmentSizeMsg([]byte{}, 3) + Expect(oob).To(Equal(msg)) + return 0, errGSO + }), + rawConn.EXPECT().WritePacket([]byte("foo"), remoteAddr, gomock.Len(0)).Return(3, nil), + rawConn.EXPECT().WritePacket([]byte("bar"), remoteAddr, gomock.Len(0)).Return(3, nil), + ) + Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + Expect(c.capabilities().GSO).To(BeFalse()) // GSO support is now disabled + // make sure we actually enforce that + Expect(func() { c.Write([]byte("foobar"), 3) }).To(PanicWith("inconsistent packet size (6 vs 3)")) + }) + } else { + It("writes without GSO", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Len(0)) + Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) + }) + } }) diff --git a/send_queue_test.go b/send_queue_test.go index 1bc260b16..6abe56123 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -5,9 +5,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Send Queue", func() { diff --git a/send_stream_test.go b/send_stream_test.go index 0a9f74996..ab5589f92 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -11,7 +11,6 @@ import ( "golang.org/x/exp/rand" - "github.com/golang/mock/gomock" "github.com/refraction-networking/uquic/internal/ackhandler" "github.com/refraction-networking/uquic/internal/mocks" "github.com/refraction-networking/uquic/internal/protocol" @@ -20,6 +19,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/onsi/gomega/gbytes" + "go.uber.org/mock/gomock" ) var _ = Describe("Send Stream", func() { diff --git a/server.go b/server.go index 1815be3fc..5b80ec29c 100644 --- a/server.go +++ b/server.go @@ -633,7 +633,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) } conn = s.newConn( - newSendConnWithPacketInfo(s.conn, p.remoteAddr, p.info), + newSendConn(s.conn, p.remoteAddr, p.info, s.logger), s.connHandler, origDestConnID, retrySrcConnID, @@ -743,7 +743,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, uint16(len(buf.Data)), remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) return err } @@ -842,7 +842,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, uint16(len(b.Data)), remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) return err } @@ -880,7 +880,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/server_test.go b/server_test.go index c3064016e..d1563ad30 100644 --- a/server_test.go +++ b/server_test.go @@ -21,9 +21,9 @@ import ( "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Server", func() { diff --git a/streams_map_incoming_test.go b/streams_map_incoming_test.go index 3928997bb..897e0883d 100644 --- a/streams_map_incoming_test.go +++ b/streams_map_incoming_test.go @@ -10,9 +10,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type mockGenericStream struct { diff --git a/streams_map_outgoing_test.go b/streams_map_outgoing_test.go index 4ab03581c..43d9d3db2 100644 --- a/streams_map_outgoing_test.go +++ b/streams_map_outgoing_test.go @@ -13,9 +13,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Streams Map (outgoing)", func() { diff --git a/streams_map_test.go b/streams_map_test.go index 1afc1db9f..0fceffd66 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -6,8 +6,6 @@ import ( "fmt" "net" - "github.com/golang/mock/gomock" - "github.com/refraction-networking/uquic/internal/flowcontrol" "github.com/refraction-networking/uquic/internal/mocks" "github.com/refraction-networking/uquic/internal/protocol" @@ -16,6 +14,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func (e streamError) TestError() error { diff --git a/sys_conn.go b/sys_conn.go index d1ba2bbd4..88e098fd6 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -1,7 +1,6 @@ package quic import ( - "fmt" "log" "net" "os" @@ -105,10 +104,7 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, _ []byte) (n int, err error) { - if uint16(len(b)) != packetSize { - panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) - } +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_df_linux.go b/sys_conn_df_linux.go index 4a988b337..5c8da7d02 100644 --- a/sys_conn_df_linux.go +++ b/sys_conn_df_linux.go @@ -4,11 +4,7 @@ package quic import ( "errors" - "log" - "os" - "strconv" "syscall" - "unsafe" "golang.org/x/sys/unix" @@ -38,43 +34,9 @@ func setDF(rawConn syscall.RawConn) (bool, error) { return true, nil } -func maybeSetGSO(rawConn syscall.RawConn) bool { - enable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_ENABLE_GSO")) - if !enable { - return false - } - - var setErr error - if err := rawConn.Control(func(fd uintptr) { - setErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_UDP, unix.UDP_SEGMENT, 1) - }); err != nil { - setErr = err - } - if setErr != nil { - log.Println("failed to enable GSO") - return false - } - return true -} - func isSendMsgSizeErr(err error) bool { // https://man7.org/linux/man-pages/man7/udp.7.html return errors.Is(err, unix.EMSGSIZE) } -func isRecvMsgSizeErr(err error) bool { return false } - -func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { - startLen := len(b) - const dataLen = 2 // payload is a uint16 - b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) - h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) - h.Level = syscall.IPPROTO_UDP - h.Type = unix.UDP_SEGMENT - h.SetLen(unix.CmsgLen(dataLen)) - - // UnixRights uses the private `data` method, but I *think* this achieves the same goal. - offset := startLen + unix.CmsgSpace(0) - *(*uint16)(unsafe.Pointer(&b[offset])) = size - return b -} +func isRecvMsgSizeErr(error) bool { return false } diff --git a/sys_conn_helper_darwin.go b/sys_conn_helper_darwin.go index bf735f0f5..758cf7788 100644 --- a/sys_conn_helper_darwin.go +++ b/sys_conn_helper_darwin.go @@ -5,6 +5,7 @@ package quic import ( "encoding/binary" "net/netip" + "syscall" "golang.org/x/sys/unix" ) @@ -29,3 +30,5 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) { } return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true } + +func isGSOSupported(syscall.RawConn) bool { return false } diff --git a/sys_conn_helper_freebsd.go b/sys_conn_helper_freebsd.go index fe5a7c20e..a2baae3b3 100644 --- a/sys_conn_helper_freebsd.go +++ b/sys_conn_helper_freebsd.go @@ -4,6 +4,7 @@ package quic import ( "net/netip" + "syscall" "golang.org/x/sys/unix" ) @@ -24,3 +25,5 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) { } return netip.AddrFrom4(*(*[4]byte)(body)), 0, true } + +func isGSOSupported(syscall.RawConn) bool { return false } diff --git a/sys_conn_helper_linux.go b/sys_conn_helper_linux.go index 61224eaaf..6a049241b 100644 --- a/sys_conn_helper_linux.go +++ b/sys_conn_helper_linux.go @@ -4,8 +4,12 @@ package quic import ( "encoding/binary" + "errors" "net/netip" + "os" + "strconv" "syscall" + "unsafe" "golang.org/x/sys/unix" ) @@ -48,3 +52,46 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) { } return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true } + +// isGSOSupported tests if the kernel supports GSO. +// Sending with GSO might still fail later on, if the interface doesn't support it (see isGSOError). +func isGSOSupported(conn syscall.RawConn) bool { + disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_GSO")) + if err == nil && disabled { + return false + } + var serr error + if err := conn.Control(func(fd uintptr) { + _, serr = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + }); err != nil { + return false + } + return serr == nil +} + +func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { + startLen := len(b) + const dataLen = 2 // payload is a uint16 + b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_UDP + h.Type = unix.UDP_SEGMENT + h.SetLen(unix.CmsgLen(dataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + *(*uint16)(unsafe.Pointer(&b[offset])) = size + return b +} + +func isGSOError(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have tx checksums enabled, + // which is a hard requirement of UDP_SEGMENT. See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/sys_conn_helper_linux_test.go b/sys_conn_helper_linux_test.go index fa39c5233..4cf59abe2 100644 --- a/sys_conn_helper_linux_test.go +++ b/sys_conn_helper_linux_test.go @@ -1,22 +1,24 @@ -// We need root permissions to use RCVBUFFORCE. -// This test is therefore only compiled when the root build flag is set. -// It can only succeed if the tests are then also run with root permissions. -//go:build linux && root +//go:build linux package quic import ( + "errors" "net" "os" + "golang.org/x/sys/unix" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +var errGSO = &os.SyscallError{Err: unix.EIO} + var _ = Describe("forcing a change of send and receive buffer sizes", func() { It("forces a change of the receive buffer size", func() { if os.Getuid() != 0 { - Fail("Must be root to force change the receive buffer size") + Skip("Must be root to force change the receive buffer size") } c, err := net.ListenPacket("udp", "127.0.0.1:0") @@ -43,7 +45,7 @@ var _ = Describe("forcing a change of send and receive buffer sizes", func() { It("forces a change of the send buffer size", func() { if os.Getuid() != 0 { - Fail("Must be root to force change the send buffer size") + Skip("Must be root to force change the send buffer size") } c, err := net.ListenPacket("udp", "127.0.0.1:0") @@ -67,4 +69,10 @@ var _ = Describe("forcing a change of send and receive buffer sizes", func() { // The kernel doubles this value (to allow space for bookkeeping overhead) Expect(size).To(Equal(2 * large)) }) + + It("detects GSO errors", func() { + Expect(isGSOError(errGSO)).To(BeTrue()) + Expect(isGSOError(nil)).To(BeFalse()) + Expect(isGSOError(errors.New("test"))).To(BeFalse()) + }) }) diff --git a/sys_conn_helper_nonlinux.go b/sys_conn_helper_nonlinux.go index 80b795c33..cace82d5d 100644 --- a/sys_conn_helper_nonlinux.go +++ b/sys_conn_helper_nonlinux.go @@ -4,3 +4,6 @@ package quic func forceSetReceiveBuffer(c any, bytes int) error { return nil } func forceSetSendBuffer(c any, bytes int) error { return nil } + +func appendUDPSegmentSizeMsg([]byte, uint16) []byte { return nil } +func isGSOError(error) bool { return false } diff --git a/sys_conn_helper_nonlinux_test.go b/sys_conn_helper_nonlinux_test.go new file mode 100644 index 000000000..29d42ad33 --- /dev/null +++ b/sys_conn_helper_nonlinux_test.go @@ -0,0 +1,7 @@ +//go:build !linux + +package quic + +import "errors" + +var errGSO = errors.New("fake GSO error") diff --git a/sys_conn_no_gso.go b/sys_conn_no_gso.go deleted file mode 100644 index 6f6a8c917..000000000 --- a/sys_conn_no_gso.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build darwin || freebsd - -package quic - -import "syscall" - -func maybeSetGSO(_ syscall.RawConn) bool { return false } -func appendUDPSegmentSizeMsg(_ []byte, _ uint16) []byte { return nil } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 140bb2f47..66d5ce67c 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -5,7 +5,6 @@ package quic import ( "encoding/binary" "errors" - "fmt" "log" "net" "net/netip" @@ -128,10 +127,6 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { bc = ipv4.NewPacketConn(c) } - // Try enabling GSO. - // This will only succeed on Linux, and only for kernels > 4.18. - supportsGSO := maybeSetGSO(rawConn) - msgs := make([]ipv4.Message, batchSize) for i := range msgs { // preallocate the [][]byte @@ -142,9 +137,11 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { batchConn: bc, messages: msgs, readPos: batchSize, + cap: connCapabilities{ + DF: supportsDF, + GSO: isGSOSupported(rawConn), + }, } - oobConn.cap.DF = supportsDF - oobConn.cap.GSO = supportsGSO for i := 0; i < batchSize; i++ { oobConn.messages[i].OOB = make([]byte, oobBufferSize) } @@ -231,17 +228,9 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -// If the connection supports GSO (and we activated GSO support before), -// it appends the UDP_SEGMENT size message to oob. -// Callers are advised to make sure that oob has a sufficient capacity, -// such that appending the UDP_SEGMENT size message doesn't cause an allocation. -func (c *oobConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, oob []byte) (n int, err error) { - if c.cap.GSO { - oob = appendUDPSegmentSizeMsg(oob, packetSize) - } else if uint16(len(b)) != packetSize { - panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) - } - n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) +// If the connection supports GSO, it's the caller's responsibility to append the right control mesage. +func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) { + n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index dce9d0b5f..0a4efb947 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -13,9 +13,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("OOB Conn Test", func() { diff --git a/sys_conn_test.go b/sys_conn_test.go index 41269a0fb..0af911da8 100644 --- a/sys_conn_test.go +++ b/sys_conn_test.go @@ -6,10 +6,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Basic Conn Test", func() { diff --git a/tools.go b/tools.go index e848317f1..d00ce748d 100644 --- a/tools.go +++ b/tools.go @@ -3,6 +3,6 @@ package quic import ( - _ "github.com/golang/mock/mockgen" _ "github.com/onsi/ginkgo/v2/ginkgo" + _ "go.uber.org/mock/mockgen" ) diff --git a/transport.go b/transport.go index ff4cf1442..a0d0784a3 100644 --- a/transport.go +++ b/transport.go @@ -6,14 +6,14 @@ import ( "errors" "net" "sync" + "sync/atomic" "time" tls "github.com/refraction-networking/utls" - "github.com/refraction-networking/uquic/internal/wire" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" ) @@ -86,6 +86,9 @@ type Transport struct { createdConn bool isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial + readingNonQUICPackets atomic.Bool + nonQUICPackets chan receivedPacket + logger utils.Logger } @@ -149,26 +152,15 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen // Dial dials a new connection to a remote host (not using 0-RTT). func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { - if err := validateConfig(conf); err != nil { - return nil, err - } - conf = populateConfig(conf) - - if err := t.init(t.isSingleUse); err != nil { - return nil, err - } - var onClose func() - if t.isSingleUse { - onClose = func() { t.Close() } - } - tlsConf = tlsConf.Clone() - tlsConf.MinVersion = tls.VersionTLS13 - - return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) + return t.dial(ctx, addr, "", tlsConf, conf, false) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + return t.dial(ctx, addr, "", tlsConf, conf, true) +} + +func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { if err := validateConfig(conf); err != nil { return nil, err } @@ -183,8 +175,8 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C } tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - - return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) + setTLSConfigServerName(tlsConf, addr, host) + return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) } func (t *Transport) init(allowZeroLengthConnIDs bool) error { @@ -200,7 +192,6 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { return } } - t.conn = conn t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn @@ -234,7 +225,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, uint16(len(b)), addr, nil) + return t.conn.WritePacket(b, addr, nil) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -252,7 +243,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB()) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -347,6 +338,13 @@ func (t *Transport) listen(conn rawConn) { } func (t *Transport) handlePacket(p receivedPacket) { + if len(p.data) == 0 { + return + } + if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) { + t.handleNonQUICPacket(p) + return + } connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) @@ -413,7 +411,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } } @@ -435,3 +433,61 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool { } return false } + +func (t *Transport) handleNonQUICPacket(p receivedPacket) { + // Strictly speaking, this is racy, + // but we only care about receiving packets at some point after ReadNonQUICPacket has been called. + if !t.readingNonQUICPackets.Load() { + return + } + select { + case t.nonQUICPackets <- p: + default: + if t.Tracer != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + } + } +} + +const maxQueuedNonQUICPackets = 32 + +// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection. +// The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0. +// Note that this is stricter than the detection logic defined in RFC 9443. +func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) { + if err := t.init(false); err != nil { + return 0, nil, err + } + if !t.readingNonQUICPackets.Load() { + t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets) + t.readingNonQUICPackets.Store(true) + } + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case p := <-t.nonQUICPackets: + n := copy(b, p.data) + return n, p.remoteAddr, nil + case <-t.listening: + return 0, nil, errors.New("closed") + } +} + +func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) { + // If no ServerName is set, infer the ServerName from the host we're connecting to. + if tlsConf.ServerName != "" { + return + } + if host == "" { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + tlsConf.ServerName = udpAddr.IP.String() + return + } + } + h, _, err := net.SplitHostPort(host) + if err != nil { // This happens if the host doesn't contain a port number. + tlsConf.ServerName = host + return + } + tlsConf.ServerName = h +} diff --git a/transport_test.go b/transport_test.go index 025e53228..588562184 100644 --- a/transport_test.go +++ b/transport_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "context" "crypto/rand" "errors" "net" @@ -15,9 +16,9 @@ import ( "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Transport", func() { @@ -123,7 +124,7 @@ var _ = Describe("Transport", func() { tr.Close() }) - It("drops unparseable packets", func() { + It("drops unparseable QUIC packets", func() { addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) tracer := mocklogging.NewMockTracer(mockCtrl) @@ -137,7 +138,7 @@ var _ = Describe("Transport", func() { tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) packetChan <- packetToRead{ addr: addr, - data: []byte{0, 1, 2, 3}, + data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3}, } Eventually(dropped).Should(BeClosed()) @@ -324,6 +325,90 @@ var _ = Describe("Transport", func() { conns := getMultiplexer().(*connMultiplexer).conns Expect(len(conns)).To(BeZero()) }) + + It("allows receiving non-QUIC packets", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + packetChan := make(chan packetToRead) + tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: 10, + Tracer: tracer, + } + tr.init(true) + receivedPacketChan := make(chan []byte) + go func() { + defer GinkgoRecover() + b := make([]byte, 100) + n, addr, err := tr.ReadNonQUICPacket(context.Background(), b) + Expect(err).ToNot(HaveOccurred()) + Expect(addr).To(Equal(remoteAddr)) + receivedPacketChan <- b[:n] + }() + // Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called. + // Give the Go routine some time to spin up. + time.Sleep(scaleDuration(50 * time.Millisecond)) + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + + Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3}))) + + // shutdown + close(packetChan) + tr.Close() + }) + + It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + packetChan := make(chan packetToRead) + tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: 10, + Tracer: tracer, + } + tr.init(true) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10)) + Expect(err).To(MatchError(context.Canceled)) + + for i := 0; i < maxQueuedNonQUICPackets; i++ { + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + } + + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(done) + }) + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + Eventually(done).Should(BeClosed()) + + // shutdown + close(packetChan) + tr.Close() + }) + + remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234} + DescribeTable("setting the tls.Config.ServerName", + func(expected string, conf *tls.Config, addr net.Addr, host string) { + setTLSConfigServerName(conf, addr, host) + Expect(conf.ServerName).To(Equal(expected)) + }, + Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"), + Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"), + Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"), + Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""), + ) }) type mockSyscallConn struct { diff --git a/u_connection.go b/u_connection.go index 37871f158..dc118bc6c 100644 --- a/u_connection.go +++ b/u_connection.go @@ -82,7 +82,7 @@ var newUClientConnection = func( } s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream() + oneRTTStream := newCryptoStream(true) var params *wire.TransportParameters diff --git a/u_parrot.go b/u_parrot.go index 0727bbe75..a658ebfa6 100644 --- a/u_parrot.go +++ b/u_parrot.go @@ -75,7 +75,7 @@ func QUICID2Spec(id QUICID) (QUICSpec, error) { CompressionMethods: []uint8{ 0x0, // no compression }, - Extensions: ShuffleTLSExtensions([]tls.TLSExtension{ + Extensions: tls.ShuffleChromeTLSExtensions([]tls.TLSExtension{ ShuffleQUICTransportParameters(&tls.QUICTransportParametersExtension{ // Order of QTPs are always shuffled TransportParameters: tls.TransportParameters{ tls.InitialMaxStreamsUni(103), @@ -192,7 +192,7 @@ func QUICID2Spec(id QUICID) (QUICSpec, error) { CompressionMethods: []uint8{ 0x0, }, - Extensions: ShuffleTLSExtensions([]tls.TLSExtension{ + Extensions: tls.ShuffleChromeTLSExtensions([]tls.TLSExtension{ ShuffleQUICTransportParameters(&tls.QUICTransportParametersExtension{ // Order of QTPs are always shuffled TransportParameters: tls.TransportParameters{ tls.InitialMaxStreamsUni(103), diff --git a/u_transport.go b/u_transport.go index c1ee3228c..7c140109a 100644 --- a/u_transport.go +++ b/u_transport.go @@ -5,6 +5,7 @@ import ( "net" "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/utils" tls "github.com/refraction-networking/utls" ) @@ -16,37 +17,15 @@ type UTransport struct { // Dial dials a new connection to a remote host (not using 0-RTT). func (t *UTransport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { - if err := validateConfig(conf); err != nil { - return nil, err - } - conf = populateConfig(conf) - - // [UQUIC] - // Override the default connection ID generator if the user has specified a length in QUICSpec. - if t.QUICSpec != nil { - if t.QUICSpec.InitialPacketSpec.SrcConnIDLength != 0 { - t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.QUICSpec.InitialPacketSpec.SrcConnIDLength} - } else { - t.ConnectionIDGenerator = &protocol.ExpEmptyConnectionIDGenerator{} - } - } - // [/UQUIC] - - if err := t.init(t.isSingleUse); err != nil { - return nil, err - } - var onClose func() - if t.isSingleUse { - onClose = func() { t.Close() } - } - tlsConf = tlsConf.Clone() - tlsConf.MinVersion = tls.VersionTLS13 - - return udial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.QUICSpec) + return t.dial(ctx, addr, "", tlsConf, conf, false) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. func (t *UTransport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + return t.dial(ctx, addr, "", tlsConf, conf, true) +} + +func (t *UTransport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { if err := validateConfig(conf); err != nil { return nil, err } @@ -72,8 +51,8 @@ func (t *UTransport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls. } tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - - return udial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true, t.QUICSpec) + setTLSConfigServerName(tlsConf, addr, host) + return udial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT, t.QUICSpec) } func (ut *UTransport) MakeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *Config) (EarlyConnection, error) {