Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass domains as NameServers and NameServer struct re-factor #435

Merged
merged 17 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
examples pass and code compiles
  • Loading branch information
phillip-stephens committed Sep 4, 2024
commit d21fa8b4f99501e418db85eadd476f4c42e3d8b7
4 changes: 2 additions & 2 deletions examples/multi_thread_lookup/multi_threaded.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ func initializeResolver(cache *zdns.Cache) *zdns.Resolver {
log.Fatal("Error getting local IP: ", err)
}
resolverConfig.LocalAddrsV4 = []net.IP{localAddr}
resolverConfig.ExternalNameServersV4 = []string{"1.1.1.1:53"}
resolverConfig.RootNameServersV4 = []string{"198.41.0.4:53"}
resolverConfig.ExternalNameServersV4 = []zdns.NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}}
resolverConfig.RootNameServersV4 = []zdns.NameServer{{IP: net.ParseIP("198.41.0.4"), Port: 53}}
resolverConfig.IPVersionMode = zdns.IPv4Only
// Set any desired options on the ResolverConfig object
resolverConfig.Cache = cache
Expand Down
6 changes: 3 additions & 3 deletions examples/single_lookup/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func main() {
dnsQuestion := &zdns.Question{Name: domain, Type: dns.TypeA, Class: dns.ClassINET}
resolver := initializeResolver()

result, _, status, err := resolver.ExternalLookup(dnsQuestion, "1.1.1.1:53")
result, _, status, err := resolver.ExternalLookup(dnsQuestion, &zdns.NameServer{IP: net.ParseIP("1.1.1.1"), Port: 53})
if err != nil {
log.Fatal("Error looking up domain: ", err)
}
Expand Down Expand Up @@ -67,8 +67,8 @@ func initializeResolver() *zdns.Resolver {
// Set any desired options on the ResolverConfig object
resolverConfig.LogLevel = log.InfoLevel
resolverConfig.LocalAddrsV4 = []net.IP{localAddr}
resolverConfig.ExternalNameServersV4 = []string{"1.1.1.1:53"}
resolverConfig.RootNameServersV4 = []string{"198.41.0.4:53"}
resolverConfig.ExternalNameServersV4 = []zdns.NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}}
resolverConfig.RootNameServersV4 = []zdns.NameServer{{IP: net.ParseIP("198.41.0.4"), Port: 53}}
resolverConfig.IPVersionMode = zdns.IPv4Only
// Create a new Resolver object with the ResolverConfig object, it will retain all settings set on the ResolverConfig object
resolver, err := zdns.InitResolver(resolverConfig)
Expand Down
2 changes: 1 addition & 1 deletion src/zdns/alookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (

// DoTargetedLookup performs a lookup of the given domain name against the given nameserver, looking up both IPv4 and IPv6 addresses
// Will follow CNAME records as well as A/AAAA records to get IP addresses
func (r *Resolver) DoTargetedLookup(name, nameServer string, isIterative, lookupA, lookupAAAA bool) (*IPResult, Trace, Status, error) {
func (r *Resolver) DoTargetedLookup(name string, nameServer *NameServer, isIterative, lookupA, lookupAAAA bool) (*IPResult, Trace, Status, error) {
name = strings.ToLower(name)
res := IPResult{}
singleQueryRes := &SingleQueryResult{}
Expand Down
104 changes: 49 additions & 55 deletions src/zdns/lookup.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/zdns/nslookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type NSResult struct {
}

// DoNSLookup performs a DNS NS lookup on the given name against the given name server.
func (r *Resolver) DoNSLookup(lookupName, nameServer string, isIterative, lookupA, lookupAAAA bool) (*NSResult, Trace, Status, error) {
func (r *Resolver) DoNSLookup(lookupName string, nameServer *NameServer, isIterative, lookupA, lookupAAAA bool) (*NSResult, Trace, Status, error) {
if len(lookupName) == 0 {
return nil, nil, "", errors.New("no name provided for NS lookup")
}
Expand Down
141 changes: 66 additions & 75 deletions src/zdns/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ type ResolverConfig struct {
IterativeTimeout time.Duration // applicable to iterative queries only, timeout for a single iteration step
Timeout time.Duration // timeout for the resolution of a single name
MaxDepth int
ExternalNameServersV4 []string // v4 name servers used for external lookups
ExternalNameServersV6 []string // v6 name servers used for external lookups
RootNameServersV4 []string // v4 root servers used for iterative lookups
RootNameServersV6 []string // v6 root servers used for iterative lookups
LookupAllNameServers bool // perform the lookup via all the nameservers for the domain
FollowCNAMEs bool // whether iterative lookups should follow CNAMEs/DNAMEs
DNSConfigFilePath string // path to the DNS config file, ex: /etc/resolv.conf
ExternalNameServersV4 []NameServer // v4 name servers used for external lookups
ExternalNameServersV6 []NameServer // v6 name servers used for external lookups
RootNameServersV4 []NameServer // v4 root servers used for iterative lookups
RootNameServersV6 []NameServer // v6 root servers used for iterative lookups
LookupAllNameServers bool // perform the lookup via all the nameservers for the domain
FollowCNAMEs bool // whether iterative lookups should follow CNAMEs/DNAMEs
DNSConfigFilePath string // path to the DNS config file, ex: /etc/resolv.conf

DNSSecEnabled bool
EdnsOptions []dns.EDNS0
Expand Down Expand Up @@ -110,13 +110,9 @@ func (rc *ResolverConfig) Validate() error {

// Validate all nameservers have ports and are valid IPs
for _, ns := range util.Concat(rc.ExternalNameServersV4, rc.ExternalNameServersV6) {
ipString, _, err := net.SplitHostPort(ns)
if err != nil {
return fmt.Errorf("could not parse external name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53 or [::1]:53", ns)
}
ip := net.ParseIP(ipString)
if ip == nil {
return fmt.Errorf("could not parse external name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53 or [::1]:53", ns)
ns.PopulateDefaultPort()
if isValid, reason := ns.IsValid(); !isValid {
return fmt.Errorf("invalid external name server: %s", reason)
}
}
// Root Nameservers
Expand All @@ -131,13 +127,9 @@ func (rc *ResolverConfig) Validate() error {

// Validate all nameservers have ports and are valid IPs
for _, ns := range util.Concat(rc.RootNameServersV4, rc.RootNameServersV6) {
ipString, _, err := net.SplitHostPort(ns)
if err != nil {
return fmt.Errorf("could not parse root name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53 or [::1]:53", ns)
}
ip := net.ParseIP(ipString)
if ip == nil {
return fmt.Errorf("could not parse root name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53 or [::1]:53", ns)
ns.PopulateDefaultPort()
if isValid, reason := ns.IsValid(); !isValid {
return fmt.Errorf("invalid root name server: %s", reason)
}
}

Expand Down Expand Up @@ -182,12 +174,8 @@ func (rc *ResolverConfig) Validate() error {

// Ensure no IPv6 link-local/multicast external/root nameservers are used
for _, ns := range util.Concat(rc.ExternalNameServersV6, rc.RootNameServersV6) {
ip, _, err := util.SplitHostPort(ns)
if err != nil {
return errors.Wrapf(err, "could not split host and port for nameserver: %s", ns)
}
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return fmt.Errorf("link-local IPv6 external/root nameservers are not supported: %v", ip)
if ns.IP.IsLinkLocalUnicast() || ns.IP.IsLinkLocalMulticast() {
return fmt.Errorf("link-local IPv6 external/root nameservers are not supported: %v", ns.IP)
}
}

Expand All @@ -208,11 +196,7 @@ func (rc *ResolverConfig) validateLoopbackConsistency() error {
allIPs := make([]net.IP, 0, allIPsLength)
allIPs = append(allIPs, allLocalAddrs...)
for _, ns := range util.Concat(allExternalNameServers, allRootNameServers) {
ip, _, err := util.SplitHostPort(ns)
if err != nil {
return errors.Wrapf(err, "could not split host and port for nameserver: %s", ns)
}
allIPs = append(allIPs, ip)
allIPs = append(allIPs, ns.IP)
}
allIPsLoopback := true
noneIPsLoopback := true
Expand All @@ -231,8 +215,18 @@ func (rc *ResolverConfig) validateLoopbackConsistency() error {

func (rc *ResolverConfig) PrintInfo() {
log.Infof("using local addresses: %v", util.Concat(rc.LocalAddrsV4, rc.LocalAddrsV6))
log.Infof("for non-iterative lookups, using external nameservers: %s", strings.Join(util.Concat(rc.ExternalNameServersV4, rc.ExternalNameServersV6), ", "))
log.Infof("for iterative lookups, using nameservers: %s", strings.Join(util.Concat(rc.RootNameServersV4, rc.RootNameServersV6), ", "))
externalNameServers := util.Concat(rc.ExternalNameServersV4, rc.ExternalNameServersV6)
rootNameServers := util.Concat(rc.RootNameServersV4, rc.RootNameServersV6)
externalNameServerStrings := make([]string, 0, len(externalNameServers))
rootNameServerStrings := make([]string, 0, len(rootNameServers))
for _, ns := range externalNameServers {
externalNameServerStrings = append(externalNameServerStrings, ns.String())
}
for _, ns := range rootNameServers {
rootNameServerStrings = append(rootNameServerStrings, ns.String())
}
log.Infof("for non-iterative lookups, using external nameservers: %s", strings.Join(externalNameServerStrings, ", "))
log.Infof("for iterative lookups, using nameservers: %s", strings.Join(rootNameServerStrings, ", "))
}

// NewResolverConfig creates a new ResolverConfig with default values.
Expand Down Expand Up @@ -294,8 +288,8 @@ type Resolver struct {
iterativeTimeout time.Duration
timeout time.Duration // timeout for the network conns
maxDepth int
externalNameServers []string // name servers used by external lookups (either OS or user specified)
rootNameServers []string // root servers used for iterative lookups
externalNameServers []NameServer // name servers used by external lookups (either OS or user specified)
rootNameServers []NameServer // root servers used for iterative lookups
lookupAllNameServers bool
followCNAMEs bool // whether iterative lookups should follow CNAMEs/DNAMEs

Expand Down Expand Up @@ -363,40 +357,45 @@ func InitResolver(config *ResolverConfig) (*Resolver, error) {
r.connInfoIPv6 = connInfo
}
// need to deep-copy here so we're not reliant on the state of the resolver config post-resolver creation
r.externalNameServers = make([]string, 0)
r.externalNameServers = make([]NameServer, 0, len(config.ExternalNameServersV4)+len(config.ExternalNameServersV6))
if config.IPVersionMode == IPv4Only || config.IPVersionMode == IPv4OrIPv6 {
ipv4Nameservers := make([]string, len(config.ExternalNameServersV4))
// copy over IPv4 nameservers
elemsCopied := copy(ipv4Nameservers, config.ExternalNameServersV4)
if elemsCopied != len(config.ExternalNameServersV4) {
log.Fatal("failed to copy entire IPv4 name servers list from config")
for _, ns := range config.ExternalNameServersV4 {
r.externalNameServers = append(r.externalNameServers, *ns.DeepCopy())
}
r.externalNameServers = append(r.externalNameServers, ipv4Nameservers...)
}
ipv6Nameservers := make([]string, len(config.ExternalNameServersV6))
if config.IPVersionMode == IPv6Only || config.IPVersionMode == IPv4OrIPv6 {
// copy over IPv6 nameservers
elemsCopied := copy(ipv6Nameservers, config.ExternalNameServersV6)
if elemsCopied != len(config.ExternalNameServersV6) {
log.Fatal("failed to copy entire IPv6 name servers list from config")
for _, ns := range config.ExternalNameServersV6 {
r.externalNameServers = append(r.externalNameServers, *ns.DeepCopy())
}
r.externalNameServers = append(r.externalNameServers, ipv6Nameservers...)
}
// deep copy external name servers from config to resolver
r.iterativeTimeout = config.IterativeTimeout
r.maxDepth = config.MaxDepth
r.rootNameServers = make([]string, 0, len(config.RootNameServersV4)+len(config.RootNameServersV6))
r.rootNameServers = make([]NameServer, 0, len(config.RootNameServersV4)+len(config.RootNameServersV6))
if r.ipVersionMode != IPv6Only && len(config.RootNameServersV4) == 0 {
// add IPv4 root servers
r.rootNameServers = append(r.rootNameServers, RootServersV4...)
for _, rootNS := range RootServersV4 {
ns := NameServer{IP: net.ParseIP(rootNS)}
ns.PopulateDefaultPort()
r.rootNameServers = append(r.rootNameServers, ns)
}
} else if r.ipVersionMode != IPv6Only {
r.rootNameServers = append(r.rootNameServers, config.RootNameServersV4...)
for _, ns := range config.RootNameServersV4 {
r.rootNameServers = append(r.rootNameServers, *ns.DeepCopy())
}
}
if r.ipVersionMode != IPv4Only && len(config.RootNameServersV6) == 0 {
// add IPv6 root servers
r.rootNameServers = append(r.rootNameServers, RootServersV6...)
// add IPv4 root servers
for _, rootNS := range RootServersV6 {
ns := NameServer{IP: net.ParseIP(rootNS)}
ns.PopulateDefaultPort()
r.rootNameServers = append(r.rootNameServers, ns)
}
} else if r.ipVersionMode != IPv4Only {
r.rootNameServers = append(r.rootNameServers, config.RootNameServersV6...)
for _, ns := range config.RootNameServersV6 {
r.rootNameServers = append(r.rootNameServers, *ns.DeepCopy())
}
}
return r, nil
}
Expand Down Expand Up @@ -444,39 +443,31 @@ func getConnectionInfo(localAddr []net.IP, transportMode transportMode, timeout
// multiple lookups concurrently, create a new Resolver object for each concurrent lookup.
// Returns the result of the lookup, the trace of the lookup (what each nameserver along the lookup returned), the
// status of the lookup, and any error that occurred.
func (r *Resolver) ExternalLookup(q *Question, dstServer string) (*SingleQueryResult, Trace, Status, error) {
func (r *Resolver) ExternalLookup(q *Question, dstServer *NameServer) (*SingleQueryResult, Trace, Status, error) {
if r.isClosed {
log.Fatal("resolver has been closed, cannot perform lookup")
}

if dstServer == "" {
if dstServer == nil {
dstServer = r.randomExternalNameServer()
log.Info("no name server provided for external lookup, using random external name server: ", dstServer)
}
dstServerWithPort, err := util.AddDefaultPortToDNSServerName(dstServer)
if err != nil {
return nil, nil, StatusIllegalInput, fmt.Errorf("could not parse name server (%s): %w. Correct format IPv4 1.1.1.1:53 or IPv6 [::1]:53", dstServer, err)
}
if dstServer != dstServerWithPort {
log.Info("no port provided for external lookup, using default port 53")
}
dstServerIP, _, err := util.SplitHostPort(dstServerWithPort)
if err != nil {
return nil, nil, StatusIllegalInput, fmt.Errorf("could not parse name server (%s): %w. Correct format IPv4 1.1.1.1:53 or IPv6 [::1]:53", dstServer, err)
dstServer.PopulateDefaultPort()
if isValid, reason := dstServer.IsValid(); !isValid {
return nil, nil, StatusIllegalInput, fmt.Errorf("could not parse name server (%s): %s", dstServer.String(), reason)
}
if util.IsIPv6(&dstServerIP) && r.connInfoIPv6 == nil {
if util.IsIPv6(&dstServer.IP) && r.connInfoIPv6 == nil {
return nil, nil, StatusIllegalInput, fmt.Errorf("IPv6 external lookup requested for domain %s but no IPv6 local addresses provided to resolver", q.Name)
} else if dstServerIP.To4() != nil && r.connInfoIPv4 == nil {
} else if dstServer.IP.To4() != nil && r.connInfoIPv4 == nil {
return nil, nil, StatusIllegalInput, fmt.Errorf("IPv4 external lookup requested for domain %s but no IPv4 local addresses provided to resolver", q.Name)
}
// check that local address and dstServer's don't have a loopback mismatch
if dstServerIP.To4() != nil && r.connInfoIPv4.localAddr.IsLoopback() != dstServerIP.IsLoopback() {
if dstServer.IP.To4() != nil && r.connInfoIPv4.localAddr.IsLoopback() != dstServer.IP.IsLoopback() {
return nil, nil, StatusIllegalInput, errors.New("cannot mix loopback and non-loopback addresses")
} else if util.IsIPv6(&dstServerIP) && r.connInfoIPv6.localAddr.IsLoopback() != dstServerIP.IsLoopback() {
} else if util.IsIPv6(&dstServer.IP) && r.connInfoIPv6.localAddr.IsLoopback() != dstServer.IP.IsLoopback() {
return nil, nil, StatusIllegalInput, errors.New("cannot mix loopback and non-loopback addresses")
}
// dstServer has been validated and has a port, continue with lookup
dstServer = dstServerWithPort
lookup, trace, status, err := r.lookupClient.DoSingleDstServerLookup(r, *q, dstServer, false)
return lookup, trace, status, err
}
Expand Down Expand Up @@ -509,20 +500,20 @@ func (r *Resolver) Close() {
}
}

func (r *Resolver) randomExternalNameServer() string {
func (r *Resolver) randomExternalNameServer() *NameServer {
l := len(r.externalNameServers)
if r.externalNameServers == nil || l == 0 {
log.Fatal("no external name servers specified")
}
return r.externalNameServers[rand.Intn(l)]
return &r.externalNameServers[rand.Intn(l)]
}

func (r *Resolver) randomRootNameServer() string {
func (r *Resolver) randomRootNameServer() *NameServer {
l := len(r.rootNameServers)
if r.rootNameServers == nil || l == 0 {
log.Fatal("no root name servers specified")
}
return r.rootNameServers[rand.Intn(l)]
return &r.rootNameServers[rand.Intn(l)]
}

func (r *Resolver) verboseLog(depth int, args ...interface{}) {
Expand Down
57 changes: 56 additions & 1 deletion src/zdns/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
*/
package zdns

import "fmt"
import (
"fmt"
"net"

"github.com/zmap/zdns/src/internal/util"
)

type transportMode int

Expand Down Expand Up @@ -92,3 +97,53 @@ func (iip IterationIPPreference) IsValid() (bool, string) {
}
return true, ""
}

type NameServer struct {
IP net.IP // ip address, required
Port uint16 // udp/tcp port
DomainName string // used for SNI with TLS, required if you want to validate server certs
}

func (ns *NameServer) String() string {
if ns == nil || ns.IP == nil {
return ""
}
if ns.IP.To4() != nil {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
} else if util.IsIPv6(&ns.IP) {
return fmt.Sprintf("[%s]:%d", ns.IP.String(), ns.Port)
}
return ""
}

func (ns *NameServer) PopulateDefaultPort() {
if ns.Port == 0 {
ns.Port = GetDefaultPort()
}
}

func (ns *NameServer) IsValid() (bool, string) {
if ns.IP == nil {
return false, "missing IP address"
}
if ns.IP != nil && ns.IP.To4() == nil && ns.IP.To16() == nil {
return false, "invalid IP address"
}
if ns.Port == 0 {
return false, "missing port"
}
return true, ""
}

func (ns *NameServer) DeepCopy() *NameServer {
if ns == nil {
return nil
}
ip := make(net.IP, len(ns.IP))
copy(ip, ns.IP)
return &NameServer{
IP: ip,
Port: ns.Port,
DomainName: ns.DomainName,
}
}
8 changes: 8 additions & 0 deletions src/zdns/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ import (
"github.com/zmap/dns"
)

const (
DefaultPort = 53
)

func dotName(name string) string {
return strings.Join([]string{name, "."}, "")
}
Expand Down Expand Up @@ -184,3 +188,7 @@ func handleStatus(status Status, err error) (Status, error) {
return s, nil
}
}

func GetDefaultPort() uint16 {
return DefaultPort
}