From 22a9a34564410af581670ba0f15149cd6eafa0ac Mon Sep 17 00:00:00 2001 From: Andrew Ayer Date: Tue, 13 Oct 2020 17:37:43 -0400 Subject: [PATCH] Check for mismatched query IDs when using TCP --- client.go | 17 ++++++++++++----- client_test.go | 21 +++++++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 4630e3ced8..b7f125779c 100644 --- a/client.go +++ b/client.go @@ -185,12 +185,19 @@ func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err erro } co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout()))) - for { + if _, ok := co.Conn.(net.PacketConn); ok { + for { + r, err = co.ReadMsg() + // Ignore replies with mismatched IDs because they might be + // responses to earlier queries that timed out. + if err != nil || r.Id == m.Id { + break + } + } + } else { r, err = co.ReadMsg() - // Ignore replies with mismatched IDs because they might be - // responses to earlier queries that timed out. - if err != nil || r.Id == m.Id { - break + if err == nil && r.Id != m.Id { + err = ErrId } } rtt = time.Since(t) diff --git a/client_test.go b/client_test.go index f6c71b66b2..13168e5750 100644 --- a/client_test.go +++ b/client_test.go @@ -225,6 +225,27 @@ func TestClientSyncBadThenGoodID(t *testing.T) { } } +func TestClientSyncTCPBadID(t *testing.T) { + HandleFunc("miek.nl.", HelloServerBadID) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalTCPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c := &Client{ + Net: "tcp", + } + if _, _, err := c.Exchange(m, addrstr); err != ErrId { + t.Errorf("did not find a bad Id") + } +} + func TestClientEDNS0(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.")