Skip to content

Commit

Permalink
Follow CNAMES on all lookups by default (#397)
Browse files Browse the repository at this point in the history
* CNAME following works for test lookups, integration tests failing tho

* fixed return type so integration tests pass

* add comments and clean up

* add flag --follow-cnames

* add DNAME integration test

* add DNAME comments

* add verbosity to test GH runner

* remove verbosity flag for debugging

* sigh...didn't add iterative flag to test

* mistakenly misnamed the zdns exe in the test

* add cname loop handling

* lint

* follow CNAMEs on external lookups too

* fixed bugs

* handle dnames, untested

* fixed bugs with dname handling

* PR comments

* removed guard code that prevented following cnames with external lookups

* return last good status when following CNAMEs, like dig

* remove unneeded guard

* return last good result/status if following cnames, fixes tests

* return last good status if we hit error traversing cnames

---------

Co-authored-by: Zakir Durumeric <zakird@gmail.com>
  • Loading branch information
phillip-stephens and zakird authored Jul 31, 2024
1 parent c58d8ca commit 22c38f1
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 57 deletions.
2 changes: 2 additions & 0 deletions src/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ type CLIConf struct {
NamePrefix string
NameOverride string
NameServerMode bool
FollowCNAMEs bool

Module string
Class uint16
Expand Down Expand Up @@ -159,6 +160,7 @@ func init() {
rootCmd.PersistentFlags().BoolVar(&GC.CheckingDisabled, "checking-disabled", false, "Sends DNS packets with the CD bit set")
rootCmd.PersistentFlags().BoolVar(&GC.RecycleSockets, "recycle-sockets", true, "Create long-lived unbound UDP socket for each thread at launch and reuse for all (UDP) queries")
rootCmd.PersistentFlags().BoolVar(&GC.NameServerMode, "name-server-mode", false, "Treats input as nameservers to query with a static query rather than queries to send to a static name server")
rootCmd.PersistentFlags().BoolVar(&GC.FollowCNAMEs, "follow-cnames", true, "Follow CNAMEs/DNAMEs in the lookup process")

rootCmd.PersistentFlags().StringVar(&GC.NameServersString, "name-servers", "", "List of DNS servers to use. Can be passed as comma-delimited string or via @/path/to/file. If no port is specified, defaults to 53.")
rootCmd.PersistentFlags().StringVar(&GC.LocalAddrString, "local-addr", "", "comma-delimited list of local addresses to use, serve as the source IP for outbound queries")
Expand Down
1 change: 1 addition & 0 deletions src/cli/worker_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func populateResolverConfig(gc *CLIConf, flags *pflag.FlagSet) *zdns.ResolverCon
// copy nameservers to resolver config
config.ExternalNameServers = gc.NameServers
config.LookupAllNameServers = gc.LookupAllNameServers
config.FollowCNAMEs = gc.FollowCNAMEs

if gc.UseNSID {
config.EdnsOptions = append(config.EdnsOptions, new(dns.EDNS0_NSID))
Expand Down
38 changes: 8 additions & 30 deletions src/zdns/alookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod
res := IPResult{}
candidateSet := map[string][]Answer{}
cnameSet := map[string][]Answer{}
dnameSet := map[string][]Answer{}
var ipv4 []string
var ipv6 []string
var ipv4Trace Trace
Expand All @@ -37,7 +38,7 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod
var ipv6status Status

if lookupIPv4 {
ipv4, ipv4Trace, ipv4status, _ = recursiveIPLookup(r, name, nameServer, dns.TypeA, candidateSet, cnameSet, name, 0, isIterative)
ipv4, ipv4Trace, ipv4status, _ = recursiveIPLookup(r, name, nameServer, dns.TypeA, candidateSet, cnameSet, dnameSet, name, 0, isIterative)
if len(ipv4) > 0 {
ipv4 = Unique(ipv4)
res.IPv4Addresses = make([]string, len(ipv4))
Expand All @@ -46,8 +47,9 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod
}
candidateSet = map[string][]Answer{}
cnameSet = map[string][]Answer{}
dnameSet = map[string][]Answer{}
if lookupIPv6 {
ipv6, ipv6Trace, ipv6status, _ = recursiveIPLookup(r, name, nameServer, dns.TypeAAAA, candidateSet, cnameSet, name, 0, isIterative)
ipv6, ipv6Trace, ipv6status, _ = recursiveIPLookup(r, name, nameServer, dns.TypeAAAA, candidateSet, cnameSet, dnameSet, name, 0, isIterative)
if len(ipv6) > 0 {
ipv6 = Unique(ipv6)
res.IPv6Addresses = make([]string, len(ipv6))
Expand All @@ -73,7 +75,7 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod

// recursiveIPLookup helper fn that recursively follows both A/AAAA records and CNAME records to find IP addresses
// returns an array of IP addresses, a trace of the lookups, a status, and an error
func recursiveIPLookup(r *Resolver, name, nameServer string, dnsType uint16, candidateSet map[string][]Answer, cnameSet map[string][]Answer, origName string, depth int, isIterative bool) ([]string, Trace, Status, error) {
func recursiveIPLookup(r *Resolver, name, nameServer string, dnsType uint16, candidateSet map[string][]Answer, cnameSet map[string][]Answer, dnameSet map[string][]Answer, origName string, depth int, isIterative bool) ([]string, Trace, Status, error) {
// avoid infinite loops
if name == origName && depth != 0 {
return nil, make(Trace, 0), StatusError, errors.New("infinite redirection loop")
Expand All @@ -99,8 +101,8 @@ func recursiveIPLookup(r *Resolver, name, nameServer string, dnsType uint16, can
return nil, trace, status, err
}

populateResults(result.Answers, dnsType, candidateSet, cnameSet, garbage)
populateResults(result.Additional, dnsType, candidateSet, cnameSet, garbage)
populateResults(result.Answers, dnsType, candidateSet, cnameSet, dnameSet, garbage)
populateResults(result.Additional, dnsType, candidateSet, cnameSet, dnameSet, garbage)
}
// our cache should now have any data that exists about the current name
if res, ok := candidateSet[name]; ok && len(res) > 0 {
Expand All @@ -113,7 +115,7 @@ func recursiveIPLookup(r *Resolver, name, nameServer string, dnsType uint16, can
} else if res, ok = cnameSet[name]; ok && len(res) > 0 {
// we have a CNAME and need to further recurse to find IPs
shortName := strings.ToLower(strings.TrimSuffix(res[0].Answer, "."))
res, secondTrace, status, err := recursiveIPLookup(r, shortName, nameServer, dnsType, candidateSet, cnameSet, origName, depth+1, isIterative)
res, secondTrace, status, err := recursiveIPLookup(r, shortName, nameServer, dnsType, candidateSet, cnameSet, dnameSet, origName, depth+1, isIterative)
trace = append(trace, secondTrace...)
return res, trace, status, err
} else if res, ok = garbage[name]; ok && len(res) > 0 {
Expand All @@ -124,27 +126,3 @@ func recursiveIPLookup(r *Resolver, name, nameServer string, dnsType uint16, can
return ips, trace, StatusNoError, nil
}
}

// populateResults is a helper function to populate the candidateSet, cnameSet, and garbage maps as recursiveIPLookup
// follows CNAME and A/AAAA records to get all IPs for a given domain
func populateResults(records []interface{}, dnsType uint16, candidateSet map[string][]Answer, cnameSet map[string][]Answer, garbage map[string][]Answer) {
for _, a := range records {
// filter only valid answers of requested type or CNAME (#163)
if ans, ok := a.(Answer); ok {
lowerCaseName := strings.ToLower(strings.TrimSuffix(ans.Name, "."))
// Verify that the answer type matches requested type
if VerifyAddress(ans.Type, ans.Answer) {
ansType := dns.StringToType[ans.Type]
if dnsType == ansType {
candidateSet[lowerCaseName] = append(candidateSet[lowerCaseName], ans)
} else if ok && dns.TypeCNAME == ansType {
cnameSet[lowerCaseName] = append(cnameSet[lowerCaseName], ans)
} else {
garbage[lowerCaseName] = append(garbage[lowerCaseName], ans)
}
} else {
garbage[lowerCaseName] = append(garbage[lowerCaseName], ans)
}
}
}
}
224 changes: 200 additions & 24 deletions src/zdns/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,30 +100,172 @@ func (r *Resolver) doSingleDstServerLookup(q Question, nameServer string, isIter
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
if isIterative {
r.verboseLog(0, "MIEKG-IN: iterative lookup for ", q.Name, " (", q.Type, ")")
result, trace, status, lookupErr := r.iterativeLookup(ctx, q, nameServer, 1, ".", make(Trace, 0))
r.verboseLog(0, "MIEKG-OUT: iterative lookup for ", q.Name, " (", q.Type, "): status: ", status, " , err: ", lookupErr)
return &result, trace, status, lookupErr
if r.followCNAMEs {
return r.followingLookup(ctx, q, nameServer, isIterative)
}
res, status, try, err := r.retryingLookup(ctx, q, nameServer, true)
res, trace, status, err := r.lookup(ctx, q, nameServer, isIterative)
if err != nil {
return &res, nil, status, fmt.Errorf("could not perform retrying lookup for name %v: %w", q.Name, err)
}
var t TraceStep
t.Result = res
t.DNSType = q.Type
t.DNSClass = q.Class
t.Name = q.Name
t.NameServer = nameServer
t.Layer = q.Name
t.Depth = 1
t.Cached = false
t.Try = try
trace := Trace{t}
return &res, trace, status, err
}

// lookup performs a DNS lookup for a given question and nameserver taking care of iterative and external lookups
func (r *Resolver) lookup(ctx context.Context, q Question, nameServer string, isIterative bool) (SingleQueryResult, Trace, Status, error) {
var res SingleQueryResult
var trace Trace
var status Status
var err error
if util.HasCtxExpired(&ctx) {
return res, trace, StatusTimeout, nil
}
if isIterative {
r.verboseLog(1, "MIEKG-IN: following iterative lookup for ", q.Name, " (", q.Type, ")")
res, trace, status, err = r.iterativeLookup(ctx, q, nameServer, 1, ".", trace)
r.verboseLog(1, "MIEKG-OUT: following iterative lookup for ", q.Name, " (", q.Type, "): status: ", status, " , err: ", err)
} else {
tries := 0
// external lookup
r.verboseLog(1, "MIEKG-IN: following external lookup for ", q.Name, " (", q.Type, ")")
res, status, tries, err = r.retryingLookup(ctx, q, nameServer, true)
r.verboseLog(1, "MIEKG-OUT: following external lookup for ", q.Name, " (", q.Type, ") with ", tries, " attempts: status: ", status, " , err: ", err)
var t TraceStep
t.Result = res
t.DNSType = q.Type
t.DNSClass = q.Class
t.Name = q.Name
t.NameServer = nameServer
t.Layer = q.Name
t.Depth = 1
t.Cached = false
t.Try = tries
trace = Trace{t}
}
return res, trace, status, err
}

// followingLoopup follows CNAMEs and DNAMEs in a DNS lookup for either an iterative or external lookup
// If an error occurs during the lookup, the last good result/status is returned along with the error and a full trace
// If an error occurs on the first lookup, the bad result/status is returned along with the error and a full trace
func (r *Resolver) followingLookup(ctx context.Context, q Question, nameServer string, isIterative bool) (*SingleQueryResult, Trace, Status, error) {
var res SingleQueryResult
var trace Trace
var status Status

candidateSet := make(map[string][]Answer)
cnameSet := make(map[string][]Answer)
garbage := make(map[string][]Answer)
allAnswerSet := make([]interface{}, 0)
dnameSet := make(map[string][]Answer)

originalName := q.Name // in case this is a CNAME, this keeps track of the original name while we change the question
currName := q.Name // this is the current name we are looking up
r.verboseLog(0, "MIEKG-IN: starting a C/DNAME following lookup for ", originalName, " (", q.Type, ")")
for i := 0; i < r.maxDepth; i++ {
q.Name = currName // update the question with the current name, this allows following CNAMEs
iterRes, iterTrace, iterStatus, lookupErr := r.lookup(ctx, q, nameServer, isIterative)
// append iterTrace to the global trace so we can return full trace
if iterTrace != nil {
trace = append(trace, iterTrace...)
}
if iterStatus != StatusNoError || lookupErr != nil {
if i == 0 {
// only have 1 result to return
return &iterRes, trace, iterStatus, lookupErr
}
// return the last good result/status if we're traversing CNAMEs
return &res, trace, status, errors.Wrapf(lookupErr, "iterative lookup failed for name %v at depth %d", q.Name, i)
}
// update the result with the latest iteration since there's no error
// We'll return the latest good result if we're traversing CNAMEs
res = iterRes
status = iterStatus

if q.Type == dns.TypeMX {
// MX records have a special lookup format, so we won't attempt to follow CNAMES here
return &res, trace, status, nil
}

// populateResults will parse the Answers and update the candidateSet, cnameSet, and garbage caching maps
populateResults(res.Answers, q.Type, candidateSet, cnameSet, dnameSet, garbage)
for _, ans := range res.Answers {
answer, ok := ans.(Answer)
if !ok {
continue
}
allAnswerSet = append(allAnswerSet, answer)
}

if isLookupComplete(originalName, candidateSet, cnameSet, dnameSet) {
return &SingleQueryResult{
Answers: allAnswerSet,
Additional: res.Additional,
Protocol: res.Protocol,
Resolver: res.Resolver,
Flags: res.Flags,
}, trace, StatusNoError, nil
}

if candidates, ok := cnameSet[currName]; ok && len(candidates) > 0 {
// we have a CNAME and need to further recurse to find IPs
currName = strings.ToLower(strings.TrimSuffix(candidates[0].Answer, "."))
continue
} else if candidates, ok = garbage[currName]; ok && len(candidates) > 0 {
return nil, trace, StatusError, errors.New("unexpected record type received")
}
// for each key in DNAMESet, check if the current name has a substring that matches the key.
// if so, replace that substring
foundDNameMatch := false
for k, v := range dnameSet {
if strings.Contains(currName, k) {
currName = strings.Replace(currName, k, strings.TrimSuffix(v[0].Answer, "."), 1)
foundDNameMatch = true
break
}
}
if foundDNameMatch {
continue
} else {
// we have no data whatsoever about this name. return an empty recordset to the user
return &iterRes, trace, StatusNoError, nil
}
}
log.Debugf("MIEKG-IN: max recursion depth reached for %s lookup", originalName)
return nil, trace, StatusServFail, nil
}

// isLookupComplete checks if there's a valid answer using the originalName and following CNAMES
// An illustrative example of why this fn is needed, say we're doing an A lookup for foo.com. There exists a CNAME from
// foo.com -> bar.com. Therefore, the candidate set will contain an A record for bar.com, and we need to ensure there's
// a complete path from foo.com -> bar.com -> bar.com's A record following the maps. This fn checks that path.
func isLookupComplete(originalName string, candidateSet map[string][]Answer, cNameSet map[string][]Answer, dNameSet map[string][]Answer) bool {
maxDepth := len(cNameSet) + len(dNameSet) + 1
currName := originalName
for i := 0; i < maxDepth; i++ {
if currName == originalName && i != 0 {
// we're in a loop
return true
}
if candidates, ok := candidateSet[currName]; ok && len(candidates) > 0 {
return true
}
if candidates, ok := cNameSet[currName]; ok && len(candidates) > 0 {
// CNAME found, update currName
currName = strings.ToLower(strings.TrimSuffix(candidates[0].Answer, "."))
continue
}
// for each key in DNAMESet, check if the current name has a substring that matches the key.
// if so, replace that substring
for k, v := range dNameSet {
if strings.Contains(currName, k) {
currName = strings.Replace(currName, k, strings.TrimSuffix(v[0].Answer, "."), 1)
break
}
}
}
return false
}

// TODO - This is incomplete. We only lookup all nameservers for the initial name server lookup, then just send the DNS query to this set.
// If we want to iteratively lookup all nameservers at each level of the query, we need to fix this.
// Issue - https://github.com/zmap/zdns/issues/362
Expand Down Expand Up @@ -207,7 +349,7 @@ func (r *Resolver) iterativeLookup(ctx context.Context, q Question, nameServer s
r.verboseLog(depth+2, "ITERATIVE_TIMEOUT ", q, ", Layer: ", layer, ", Nameserver: ", nameServer)
status = StatusIterTimeout
}
if status != StatusNoError {
if status != StatusNoError || err != nil {
r.verboseLog((depth + 1), "-> error occurred during lookup")
return result, trace, status, err
} else if len(result.Answers) != 0 || result.Flags.Authoritative {
Expand Down Expand Up @@ -408,7 +550,7 @@ func (r *Resolver) iterateOnAuthorities(ctx context.Context, q Question, depth i
}
for i, elem := range result.Authorities {
r.verboseLog(depth+1, "Trying Authority: ", elem)
ns, nsStatus, newLayer, newTrace := r.extractAuthority(ctx, elem, layer, depth, result, trace)
ns, nsStatus, newLayer, newTrace := r.extractAuthority(ctx, elem, layer, depth, &result, trace)
r.verboseLog((depth + 1), "Output from extract authorities: ", ns)
if nsStatus == StatusIterTimeout {
r.verboseLog((depth + 2), "--> Hit iterative timeout: ")
Expand All @@ -417,9 +559,9 @@ func (r *Resolver) iterateOnAuthorities(ctx context.Context, q Question, depth i
}
if nsStatus != StatusNoError {
var err error
newStatus, err := handleStatus(&nsStatus, err)
newStatus, err := handleStatus(nsStatus, err)
// default case we continue
if newStatus == nil && err == nil {
if err == nil {
if i+1 == len(result.Authorities) {
r.verboseLog((depth + 2), "--> Auth find Failed. Unknown error. No more authorities to try, terminating: ", nsStatus)
var r SingleQueryResult
Expand All @@ -434,7 +576,7 @@ func (r *Resolver) iterateOnAuthorities(ctx context.Context, q Question, depth i
if i+1 == len(result.Authorities) {
// We don't allow the continue fall through in order to report the last auth falure code, not STATUS_EROR
r.verboseLog((depth + 2), "--> Final auth find non-success. Last auth. Terminating: ", nsStatus)
return localResult, newTrace, *newStatus, err
return localResult, newTrace, newStatus, err
} else {
r.verboseLog((depth + 2), "--> Auth find non-success. Trying next: ", nsStatus)
continue
Expand All @@ -457,7 +599,7 @@ func (r *Resolver) iterateOnAuthorities(ctx context.Context, q Question, depth i
panic("should not be able to reach here")
}

func (r *Resolver) extractAuthority(ctx context.Context, authority interface{}, layer string, depth int, result SingleQueryResult, trace Trace) (string, Status, string, Trace) {
func (r *Resolver) extractAuthority(ctx context.Context, authority interface{}, layer string, depth int, result *SingleQueryResult, trace Trace) (string, Status, string, Trace) {
// Is it an answer
ans, ok := authority.(Answer)
if !ok {
Expand All @@ -475,7 +617,7 @@ func (r *Resolver) extractAuthority(ctx context.Context, authority interface{},
// Short circuit a lookup from the glue
// Normally this would be handled by caching, but we want to support following glue
// that would normally be cache poison. Because it's "ok" and quite common
res, status := checkGlue(server, result)
res, status := checkGlue(server, *result)
if status != StatusNoError {
// Fall through to normal query
var q Question
Expand Down Expand Up @@ -527,3 +669,37 @@ func FindTxtRecord(res *SingleQueryResult, regex *regexp.Regexp) (string, error)
}
return "", errors.New("no such TXT record found")
}

// populateResults is a helper function to populate the candidateSet, cnameSet, and garbage maps to follow CNAMES
// These maps are keyed by the domain name and contain the relevant answers for that domain
// candidateSet is a map of Answers that have a type matching the requested type.
// cnameSet is a map of Answers that are CNAME records
// dnameSet is a map of Answers that are DNAME records
// garbage is a map of Answers that are not of the requested type or CNAME records
// follows CNAME/DNAME and A/AAAA records to get all IPs for a given domain
func populateResults(records []interface{}, dnsType uint16, candidateSet map[string][]Answer, cnameSet map[string][]Answer, dnameSet map[string][]Answer, garbage map[string][]Answer) {
var ans Answer
var ok bool
for _, a := range records {
// filter only valid answers of requested type or CNAME (#163)
if ans, ok = a.(Answer); !ok {
continue
}
lowerCaseName := strings.ToLower(strings.TrimSuffix(ans.Name, "."))
// Verify that the answer type matches requested type
if VerifyAddress(ans.Type, ans.Answer) {
ansType := dns.StringToType[ans.Type]
if dnsType == ansType {
candidateSet[lowerCaseName] = append(candidateSet[lowerCaseName], ans)
} else if dns.TypeCNAME == ansType {
cnameSet[lowerCaseName] = append(cnameSet[lowerCaseName], ans)
} else if dns.TypeDNAME == ansType {
dnameSet[lowerCaseName] = append(dnameSet[lowerCaseName], ans)
} else {
garbage[lowerCaseName] = append(garbage[lowerCaseName], ans)
}
} else {
garbage[lowerCaseName] = append(garbage[lowerCaseName], ans)
}
}
}
Loading

0 comments on commit 22c38f1

Please sign in to comment.