diff --git a/cmd/ip-monitor/main.go b/cmd/ip-monitor/main.go index be2ff2e..07523c8 100644 --- a/cmd/ip-monitor/main.go +++ b/cmd/ip-monitor/main.go @@ -14,7 +14,7 @@ import ( func main() { cli_args := cli.Arguments() config := cfg.LoadConfiguration(cli_args.ConfigPath) - sigs := make(chan os.Signal) + sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM) logger.Infoln("Starting IP-Monitor Daemon") diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 728c48f..acf8846 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -20,19 +20,21 @@ var ( func Start(config cfg.Configuration) { go func() { for { - select { - case <-Exit: - state = make(map[int]net.IP) - return - default: - Orchestrate(config) - time.Sleep(time.Duration(config.Interval) * time.Second) + Run(config) + for i := 1; i <= config.Interval; i++ { + select { + case <-Exit: + state = make(map[int]net.IP) + return + default: + time.Sleep(1 * time.Second) + } } } }() } -func Orchestrate(config cfg.Configuration) { +func Run(config cfg.Configuration) { var wg = &sync.WaitGroup{} for i := range config.Monitors { wg.Add(1) diff --git a/pkg/dns_resolver/dns_resolver.go b/pkg/dns_resolver/dns_resolver.go index ffd15ec..a98b21e 100644 --- a/pkg/dns_resolver/dns_resolver.go +++ b/pkg/dns_resolver/dns_resolver.go @@ -12,6 +12,10 @@ import ( "github.com/d0m84/ip-monitor/pkg/logger" ) +var ( + timeout int = 10 +) + func CheckIfCNAME(domain string) (string, bool, error) { target, err := net.LookupCNAME(domain) if err != nil { @@ -29,10 +33,10 @@ func FindFinalTarget(domain string) (string, error) { var err error var target string = domain var is_cname bool - for i := 0; i < 50; i++ { + for i := 0; i < 2; i++ { target, is_cname, err = CheckIfCNAME(target) if err != nil { - logger.Errorf("Error checking if %s is a CNAME: %s", target, err) + logger.Errorf("Error checking if %s is a CNAME: %s", domain, err) return "", err } if !is_cname { @@ -40,13 +44,14 @@ func FindFinalTarget(domain string) (string, error) { return target, nil } } - return "", errors.New("dns cname check limit reached") + logger.Errorf("Maximum CNAME lookup limit reached for %s", domain) + return "", errors.New("dns cname lookup limit reached") } func FindNameServers(domain string) ([]*net.NS, error) { - domainParts := strings.Split(domain, ".") - for i := range domainParts { - t := domainParts[i:len(domainParts):len(domainParts)] + domain_parts := strings.Split(domain, ".") + for i := range domain_parts { + t := domain_parts[i:len(domain_parts):len(domain_parts)] d := strings.Join(t, ".") nameservers, err := net.LookupNS(d) @@ -83,7 +88,7 @@ func LookupAuthorative(domain string, ip_version string) ([]net.IP, error) { PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { d := net.Dialer{ - Timeout: time.Millisecond * time.Duration(10000), + Timeout: time.Second * time.Duration(timeout), } return d.DialContext(ctx, network, nameserver) }, @@ -106,8 +111,7 @@ func Resolve(domain string, ip_version string) (net.IP, error) { target, err := FindFinalTarget(domain) if err != nil { - logger.Errorf("Error checking for final target of %s: %s", domain, err) - return nil, errors.New("dns cname check error") + return nil, errors.New("dns cname lookup error") } ips, err := LookupAuthorative(target, ip_version) diff --git a/pkg/http_resolver/http_resolver.go b/pkg/http_resolver/http_resolver.go index 733f05c..07c74c9 100644 --- a/pkg/http_resolver/http_resolver.go +++ b/pkg/http_resolver/http_resolver.go @@ -11,33 +11,37 @@ import ( "github.com/d0m84/ip-monitor/pkg/logger" ) +var ( + timeout int = 10 +) + func Resolve(provider string, ip_version string) (net.IP, error) { - var zeroDialer net.Dialer - var httpClient = &http.Client{Timeout: 10 * time.Second} - var tcpVersion string = ip_version + var dialer net.Dialer + var client = &http.Client{Timeout: time.Second * time.Duration(timeout)} + var tcp_version string = "tcp" if ip_version == "ip4" { - tcpVersion = "tcp4" + tcp_version = "tcp4" } else if ip_version == "ip6" { - tcpVersion = "tcp6" + tcp_version = "tcp6" } transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return zeroDialer.DialContext(ctx, tcpVersion, addr) + return dialer.DialContext(ctx, tcp_version, addr) } - httpClient.Transport = transport + client.Transport = transport - resp, err := httpClient.Get(provider) + resp, err := client.Get(provider) if err != nil { logger.Errorf("Error connecting to HTTP IP provider: %s", err) return nil, errors.New("http error") } defer resp.Body.Close() - statusOK := resp.StatusCode >= 200 && resp.StatusCode < 300 - if !statusOK { + status_ok := resp.StatusCode >= 200 && resp.StatusCode < 300 + if !status_ok { logger.Errorf("HTTP status error from IP provider: %s", provider) return nil, errors.New("status error") }