diff --git a/controllers/providers/assistant/gslb.go b/controllers/providers/assistant/gslb.go index 1332a2afb8..de00eb0c50 100644 --- a/controllers/providers/assistant/gslb.go +++ b/controllers/providers/assistant/gslb.go @@ -104,7 +104,7 @@ func (r *Gslb) CoreDNSExposedIPs() ([]string, error) { func extractIPFromLB(lb corev1.LoadBalancerIngress, ns utils.DNSList) (ips []string, err error) { if lb.Hostname != "" { - IPs, err := utils.Dig(lb.Hostname, ns...) + IPs, err := utils.Dig(lb.Hostname, 8, ns...) if err != nil { log.Warn().Err(err). Str("loadBalancerHostname", lb.Hostname). diff --git a/controllers/refresolver/ingress/ingress.go b/controllers/refresolver/ingress/ingress.go index 8c4a4c5034..7fd8a5a421 100644 --- a/controllers/refresolver/ingress/ingress.go +++ b/controllers/refresolver/ingress/ingress.go @@ -166,7 +166,7 @@ func (rr *ReferenceResolver) GetGslbExposedIPs(edgeDNSServers utils.DNSList) ([] gslbIngressIPs = append(gslbIngressIPs, ip.IP) } if len(ip.Hostname) > 0 { - IPs, err := utils.Dig(ip.Hostname, edgeDNSServers...) + IPs, err := utils.Dig(ip.Hostname, 8, edgeDNSServers...) if err != nil { log.Warn().Err(err).Msg("Dig error") return nil, err diff --git a/controllers/refresolver/istiovirtualservice/istiovirtualservice.go b/controllers/refresolver/istiovirtualservice/istiovirtualservice.go index 93f4050f68..a2007b641a 100644 --- a/controllers/refresolver/istiovirtualservice/istiovirtualservice.go +++ b/controllers/refresolver/istiovirtualservice/istiovirtualservice.go @@ -199,7 +199,7 @@ func (rr *ReferenceResolver) GetGslbExposedIPs(edgeDNSServers utils.DNSList) ([] gslbIngressIPs = append(gslbIngressIPs, ip.IP) } if len(ip.Hostname) > 0 { - IPs, err := utils.Dig(ip.Hostname, edgeDNSServers...) + IPs, err := utils.Dig(ip.Hostname, 8, edgeDNSServers...) if err != nil { log.Warn().Err(err).Msg("Dig error") return nil, err diff --git a/controllers/utils/dns.go b/controllers/utils/dns.go index 583380d2e5..e318e25a75 100644 --- a/controllers/utils/dns.go +++ b/controllers/utils/dns.go @@ -47,12 +47,17 @@ func (l DNSList) String() string { // Dig returns a list of IP addresses for a given FQDN by using the dns servers from edgeDNSServers // dns servers are tried one by one from the edgeDNSServers and if there is a non-error response it is returned and the rest is not tried -func Dig(fqdn string, edgeDNSServers ...DNSServer) (ips []string, err error) { +func Dig(fqdn string, maxRecursion int, edgeDNSServers ...DNSServer) (ips []string, err error) { + if maxRecursion < 0 { + return nil, fmt.Errorf("maximum recursion limit reached") + } + maxRecursion-- + if len(edgeDNSServers) == 0 { return nil, fmt.Errorf("empty edgeDNSServers, provide at least one") } if len(fqdn) == 0 { - return + return ips, nil } if !strings.HasSuffix(fqdn, ".") { @@ -64,11 +69,45 @@ func Dig(fqdn string, edgeDNSServers ...DNSServer) (ips []string, err error) { if err != nil { return nil, fmt.Errorf("dig error: %s", err) } + aRecords := make([]*dns.A, 0) + cnameRecords := make([]*dns.CNAME, 0) for _, a := range ack.Answer { - ips = append(ips, a.(*dns.A).A.String()) + switch v := a.(type) { + case *dns.A: + ips = append(ips, v.A.String()) + aRecords = append(aRecords, v) + case *dns.CNAME: + cnameRecords = append(cnameRecords, v) + } + } + aResolved := func(c *dns.CNAME) bool { + for _, a := range aRecords { + if c.Target == a.Hdr.Name { + return true + } + } + return false + } + cResolved := func(c *dns.CNAME) bool { + for _, cname := range cnameRecords { + if c.Target == cname.Hdr.Name { + return true + } + } + return false + } + // Check for non-resolved CNAMEs + for _, cname := range cnameRecords { + if !aResolved(cname) && !cResolved(cname) { + cnameIPs, err := Dig(cname.Target, maxRecursion, edgeDNSServers...) + if err != nil { + return nil, err + } + ips = append(ips, cnameIPs...) + } } sort.Strings(ips) - return + return ips, nil } func Exchange(m *dns.Msg, edgeDNSServers []DNSServer) (msg *dns.Msg, err error) { diff --git a/controllers/utils/dns_test.go b/controllers/utils/dns_test.go index a1e12cf5b8..af57e84594 100644 --- a/controllers/utils/dns_test.go +++ b/controllers/utils/dns_test.go @@ -21,6 +21,7 @@ Generated by GoLic, for more details see: https://github.com/AbsaOSS/golic import ( "context" "fmt" + "net" "net/http" "strings" "testing" @@ -45,7 +46,7 @@ func TestValidDigFQDNWithDot(t *testing.T) { } fqdn := defaultFqdn // act - result, err := Dig(fqdn+".", defaultEdgeDNSServer) + result, err := Dig(fqdn+".", 8, defaultEdgeDNSServer) // assert if strings.Contains(fmt.Sprintf("%v", err), "timeout") { @@ -63,7 +64,7 @@ func TestValidDig(t *testing.T) { } fqdn := defaultFqdn // act - result, err := Dig(fqdn, defaultEdgeDNSServer) + result, err := Dig(fqdn, 8, defaultEdgeDNSServer) if err != nil && strings.HasSuffix(err.Error(), "->8.8.8.8:53: i/o timeout") { // udp 8.8.8.8:53 may be blocked on some local environments t.Skip() @@ -81,7 +82,7 @@ func TestEmptyFQDNButValidEdgeDNS(t *testing.T) { } fqdn := "" // act - result, err := Dig(fqdn, defaultEdgeDNSServer) + result, err := Dig(fqdn, 8, defaultEdgeDNSServer) // assert assert.NoError(t, err) assert.Nil(t, result) @@ -91,7 +92,7 @@ func TestEmptyEdgeDNS(t *testing.T) { // arrange fqdn := "whatever" // act - result, err := Dig(fqdn, DNSServer{Host: "", Port: 53}) + result, err := Dig(fqdn, 8, DNSServer{Host: "", Port: 53}) // assert assert.Error(t, err) assert.Nil(t, result) @@ -101,12 +102,30 @@ func TestEmptyDNSList(t *testing.T) { // arrange fqdn := "whatever" // act - result, err := Dig(fqdn, []DNSServer{}...) + result, err := Dig(fqdn, 8, []DNSServer{}...) // assert assert.Error(t, err) assert.Nil(t, result) } +func TestValidDigButMaxRecursion(t *testing.T) { + testServer := DNSServer{ + Host: server, + Port: port, + } + NewFakeDNS(testSettings). + AddCNAMERecord("foo.cloud.example.com.", "bar.cloud.example.com."). + AddCNAMERecord("bar.cloud.example.com.", "baz.cloud.example.com."). + AddARecord("baz.cloud.example.com.", net.IPv4(10, 1, 0, 3)). + Start(). + RunTestFunc(func() { + result, err := Dig("foo.cloud.example.com", 1, testServer) + assert.Error(t, err) + assert.Nil(t, result) + }) + +} + func TestOneValidEdgeDNSInTheList(t *testing.T) { if !connected() { t.Skipf("no connectivity, skipping") @@ -120,7 +139,7 @@ func TestOneValidEdgeDNSInTheList(t *testing.T) { } fqdn := defaultFqdn // act - result, err := Dig(fqdn, edgeDNSServers...) + result, err := Dig(fqdn, 8, edgeDNSServers...) // assert if err != nil && strings.HasSuffix(err.Error(), "->8.8.8.8:253: i/o timeout") { // udp 8.8.8.8:253 may be blocked on some local environments @@ -140,7 +159,7 @@ func TestNoValidEdgeDNSInTheList(t *testing.T) { } fqdn := defaultFqdn // act - result, err := Dig(fqdn, edgeDNSServers...) + result, err := Dig(fqdn, 8, edgeDNSServers...) // assert assert.Error(t, err) assert.Nil(t, result) @@ -155,7 +174,7 @@ func TestEmptyEdgeDNSInTheList(t *testing.T) { } fqdn := defaultFqdn // act - result, err := Dig(fqdn, edgeDNSServers...) + result, err := Dig(fqdn, 8, edgeDNSServers...) // assert assert.Error(t, err) assert.Nil(t, result) @@ -173,7 +192,7 @@ func TestMultipleValidEdgeDNSInTheList(t *testing.T) { } fqdn := defaultFqdn // act - result, err := Dig(fqdn, edgeDNSServers...) + result, err := Dig(fqdn, 8, edgeDNSServers...) // assert assert.NoError(t, err) assert.NotEmpty(t, result) @@ -185,12 +204,63 @@ func TestValidEdgeDNSButNonExistingFQDN(t *testing.T) { edgeDNSServer := "localhost" fqdn := "some-valid-ip-fqdn-123" // act - result, err := Dig(fqdn, DNSServer{Host: edgeDNSServer, Port: 53}) + result, err := Dig(fqdn, 8, DNSServer{Host: edgeDNSServer, Port: 53}) // assert assert.Error(t, err) assert.Nil(t, result) } +func TestDigCNAME(t *testing.T) { + testServer := DNSServer{ + Host: server, + Port: port, + } + testSettings := FakeDNSSettings{ + FakeDNSPort: port, + EdgeDNSZoneFQDN: "example.com.", + DNSZoneFQDN: "cloud.example.com.", + Dump: true, + } + NewFakeDNS(testSettings). + AddCNAMERecord("foo.cloud.example.com.", "bar.cloud.example.com."). + AddARecord("bar.cloud.example.com.", net.IPv4(10, 1, 0, 3)). + Start(). + RunTestFunc(func() { + result, err := Dig("foo.cloud.example.com", 8, testServer) + if err != nil { + panic(err) + } + assert.Equal(t, result, []string{"10.1.0.3"}) + }) + +} + +func TestDigCNAMERecursion(t *testing.T) { + testServer := DNSServer{ + Host: server, + Port: port, + } + testSettings := FakeDNSSettings{ + FakeDNSPort: port, + EdgeDNSZoneFQDN: "example.com.", + DNSZoneFQDN: "cloud.example.com.", + Dump: true, + } + NewFakeDNS(testSettings). + AddCNAMERecord("foo.cloud.example.com.", "bar.cloud.example.com."). + AddCNAMERecord("bar.cloud.example.com.", "baz.cloud.example.com."). + AddARecord("baz.cloud.example.com.", net.IPv4(10, 1, 0, 3)). + Start(). + RunTestFunc(func() { + result, err := Dig("foo.cloud.example.com", 8, testServer) + if err != nil { + panic(err) + } + assert.Equal(t, result, []string{"10.1.0.3"}) + }) + +} + func connected() (ok bool) { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) diff --git a/controllers/utils/fakedns.go b/controllers/utils/fakedns.go index 576a2dd794..52d3fd79de 100644 --- a/controllers/utils/fakedns.go +++ b/controllers/utils/fakedns.go @@ -36,6 +36,7 @@ type FakeDNSSettings struct { FakeDNSPort int EdgeDNSZoneFQDN string DNSZoneFQDN string + Dump bool } // DNSMock acts as DNS server but returns mock values @@ -125,6 +126,15 @@ func (m *DNSMock) AddAAAARecord(ip net.IP) *DNSMock { return m } +func (m *DNSMock) AddCNAMERecord(fqdn string, cname string) *DNSMock { + rr := &dns.CNAME{ + Hdr: dns.RR_Header{Name: fqdn, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 0}, + Target: cname, + } + m.records[dns.TypeA] = append(m.records[dns.TypeA], rr) + return m +} + func (m *DNSMock) listen() (err error) { dns.HandleFunc(m.settings.EdgeDNSZoneFQDN, m.handleReflect) for e := range m.serve() { @@ -190,7 +200,7 @@ func (m *DNSMock) handleReflect(w dns.ResponseWriter, r *dns.Msg) { if m.records[r.Question[0].Qtype] != nil { for _, rr := range m.records[r.Question[0].Qtype] { fqdn := strings.Split(rr.String(), "\t")[0] - if fqdn == r.Question[0].Name { + if fqdn == r.Question[0].Name || m.settings.Dump { msg.Answer = append(msg.Answer, rr) } }