From 5ca531eee55b6ebbe583abd34b7495d04d6f255f Mon Sep 17 00:00:00 2001 From: Alexandre Fiori Date: Mon, 16 Nov 2015 15:00:26 -0500 Subject: [PATCH] Move freegeoip daemon code to its own package Moving contents from cmd/freegeoip/main.go to apiserver package for better test coverage. This change updates the -addr command line flag and its behavior, and is backwards incomplatible. People using -addr must switch over to using -http now. In order to enable HTTPS, one must use -https and the server might listen on both HTTP and HTTPS. The -pprof flag changed to -internal-server and serves not only pprof but also metrics for prometheus (http://prometheus.io). These are under /debug/pprof (https://golang.org/pkg/net/http/pprof/) and /metrics accordingly. Bringing back the -read-timeout and -write-timeout command line flags for server tuning. Fixed a race condition bug in the redis quota algorithm, at the exchange of 1 redis incr per request following advice from pattern #2 from http://redis.io/commands/incr. Also added rate limit response headers for all HTTP and HTTPS requests, inspired by GitHub's API: X-RateLimit-Limit: number of requests allowed per interval (def. 1h) X-RateLimit-Remaining: number of requests remaining, per user X-RateLimit-Reset: time in seconds before resetting the limit Added the -logtostdout command line flag to close #146. Minor fix to the background database download back off algorithm, added -api-prefix and -cors-origin command line flags, and tests. --- apiserver/cmd.go | 135 +++++++++++++++++++ apiserver/cmd_test.go | 26 ++++ apiserver/cors.go | 44 ++++++ apiserver/cors_test.go | 89 ++++++++++++ apiserver/db.go | 40 ++++++ apiserver/doc.go | 7 + apiserver/http.go | 82 +++++++++++ apiserver/http_test.go | 62 +++++++++ apiserver/metrics.go | 37 +++++ apiserver/ratelimit.go | 96 +++++++++++++ apiserver/ratelimit_test.go | 68 ++++++++++ cmd/freegeoip/main.go | 261 +----------------------------------- db.go | 28 ++-- encoder.go | 1 - 14 files changed, 699 insertions(+), 277 deletions(-) create mode 100644 apiserver/cmd.go create mode 100644 apiserver/cmd_test.go create mode 100644 apiserver/cors.go create mode 100644 apiserver/cors_test.go create mode 100644 apiserver/db.go create mode 100644 apiserver/doc.go create mode 100644 apiserver/http.go create mode 100644 apiserver/http_test.go create mode 100644 apiserver/metrics.go create mode 100644 apiserver/ratelimit.go create mode 100644 apiserver/ratelimit_test.go diff --git a/apiserver/cmd.go b/apiserver/cmd.go new file mode 100644 index 0000000..4729c72 --- /dev/null +++ b/apiserver/cmd.go @@ -0,0 +1,135 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "flag" + "fmt" + "log" + "net/http" + "os" + "strings" + "time" + + // embed pprof server. + _ "net/http/pprof" + + "github.com/fiorix/freegeoip" + "github.com/fiorix/go-redis/redis" + gorilla "github.com/gorilla/handlers" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/net/http2" +) + +// Version tag. +var Version = "3.0.7" + +var maxmindDB = "http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz" + +var ( + flAPIPrefix = flag.String("api-prefix", "/", "Prefix for API endpoints") + flCORSOrigin = flag.String("cors-origin", "*", "CORS origin API endpoints") + flHTTPAddr = flag.String("http", ":8080", "Address in form of ip:port to listen on for HTTP") + flHTTPSAddr = flag.String("https", "", "Address in form of ip:port to listen on for HTTPS") + flCertFile = flag.String("cert", "cert.pem", "X.509 certificate file") + flKeyFile = flag.String("key", "key.pem", "X.509 key file") + flReadTimeout = flag.Duration("read-timeout", 30*time.Second, "Read timeout for HTTP and HTTPS client conns") + flWriteTimeout = flag.Duration("write-timeout", 15*time.Second, "Write timeout for HTTP and HTTPS client conns") + flPublicDir = flag.String("public", "", "Public directory to serve at the {prefix}/ endpoint") + flDB = flag.String("db", maxmindDB, "IP database file or URL") + flUpdateIntvl = flag.Duration("update", 24*time.Hour, "Database update check interval") + flRetryIntvl = flag.Duration("retry", time.Hour, "Max time to wait before retrying to download database") + flUseXFF = flag.Bool("use-x-forwarded-for", false, "Use the X-Forwarded-For header when available (e.g. when running behind proxies)") + flSilent = flag.Bool("silent", false, "Do not log HTTP or HTTPS requests to stderr") + flLogToStdout = flag.Bool("logtostdout", false, "Log to stdout instead of stderr") + flRedisAddr = flag.String("redis", "127.0.0.1:6379", "Redis address in form of ip:port[,ip:port] for quota") + flRedisTimeout = flag.Duration("redis-timeout", 500*time.Millisecond, "Redis read/write timeout") + flQuotaMax = flag.Int("quota-max", 0, "Max requests per source IP per interval; set 0 to turn off") + flQuotaIntvl = flag.Duration("quota-interval", time.Hour, "Quota expiration interval per source IP querying the API") + flVersion = flag.Bool("version", false, "Show version and exit") + flInternalServer = flag.String("internal-server", "", "Address in form of ip:port to listen on for /metrics and /debug/pprof") +) + +// Run is the entrypoint for the freegeoip daemon tool. +func Run() error { + flag.Parse() + + if *flVersion { + fmt.Printf("freegeoip v%s\n", Version) + return nil + } + + if *flLogToStdout { + log.SetOutput(os.Stdout) + } + + log.SetPrefix("[freegeoip] ") + + addrs := strings.Split(*flRedisAddr, ",") + rc, err := redis.Dial(addrs...) + if err != nil { + return err + } + rc.Timeout = *flRedisTimeout + + db, err := openDB(*flDB, *flUpdateIntvl, *flRetryIntvl) + if err != nil { + return err + } + go watchEvents(db) + + ah := NewHandler(&HandlerConfig{ + Prefix: *flAPIPrefix, + Origin: *flCORSOrigin, + PublicDir: *flPublicDir, + DB: db, + RateLimiter: RateLimiter{ + Redis: rc, + Max: *flQuotaMax, + Interval: *flQuotaIntvl, + }, + }) + + if !*flSilent { + ah = gorilla.CombinedLoggingHandler(os.Stderr, ah) + } + + if *flUseXFF { + ah = freegeoip.ProxyHandler(ah) + } + + if len(*flInternalServer) > 0 { + http.Handle("/metrics", prometheus.Handler()) + log.Println("freegeoip internal server starting on", *flInternalServer) + go func() { log.Fatal(http.ListenAndServe(*flInternalServer, nil)) }() + } + + if *flHTTPAddr != "" { + log.Println("freegeoip http server starting on", *flHTTPAddr) + srv := &http.Server{ + Addr: *flHTTPAddr, + Handler: ah, + ReadTimeout: *flReadTimeout, + WriteTimeout: *flWriteTimeout, + ConnState: ConnStateMetrics(httpConnsGauge), + } + go func() { log.Fatal(srv.ListenAndServe()) }() + } + + if *flHTTPSAddr != "" { + log.Println("freegeoip https server starting on", *flHTTPSAddr) + srv := &http.Server{ + Addr: *flHTTPSAddr, + Handler: ah, + ReadTimeout: *flReadTimeout, + WriteTimeout: *flWriteTimeout, + ConnState: ConnStateMetrics(httpsConnsGauge), + } + http2.ConfigureServer(srv, nil) + go func() { log.Fatal(srv.ListenAndServeTLS(*flCertFile, *flKeyFile)) }() + } + + select {} +} diff --git a/apiserver/cmd_test.go b/apiserver/cmd_test.go new file mode 100644 index 0000000..c5c43f0 --- /dev/null +++ b/apiserver/cmd_test.go @@ -0,0 +1,26 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "flag" + "testing" + "time" +) + +func TestCmd(t *testing.T) { + flag.Set("http", ":0") + flag.Set("db", "../testdata/db.gz") + flag.Set("silent", "true") + errc := make(chan error) + go func() { + errc <- Run() + }() + select { + case err := <-errc: + t.Fatal(err) + case <-time.After(time.Second): + } +} diff --git a/apiserver/cors.go b/apiserver/cors.go new file mode 100644 index 0000000..0af9ace --- /dev/null +++ b/apiserver/cors.go @@ -0,0 +1,44 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "net/http" + "strings" +) + +// cors is an HTTP handler for managing cross-origin resource sharing. +// Ref: https://en.wikipedia.org/wiki/Cross-origin_resource_sharing. +func cors(f http.Handler, origin string, methods ...string) http.Handler { + ms := strings.Join(methods, ", ") + ", OPTIONS" + md := make(map[string]struct{}) + for _, method := range methods { + md[method] = struct{}{} + } + cf := func(w http.ResponseWriter, r *http.Request) { + orig := origin + if orig == "*" { + if ro := r.Header.Get("Origin"); ro != "" { + orig = ro + } + } + w.Header().Set("Access-Control-Allow-Origin", orig) + w.Header().Set("Access-Control-Allow-Methods", ms) + w.Header().Set("Access-Control-Allow-Credentials", "true") + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + if _, exists := md[r.Method]; exists { + f.ServeHTTP(w, r) + return + } + w.Header().Set("Allow", ms) + http.Error(w, + http.StatusText(http.StatusMethodNotAllowed), + http.StatusMethodNotAllowed) + } + return http.HandlerFunc(cf) +} diff --git a/apiserver/cors_test.go b/apiserver/cors_test.go new file mode 100644 index 0000000..5dcfe02 --- /dev/null +++ b/apiserver/cors_test.go @@ -0,0 +1,89 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestCORS(t *testing.T) { + // set up the test server + handler := func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "hello world") + } + mux := http.NewServeMux() + mux.Handle("/", cors(http.HandlerFunc(handler), "*", "GET")) + ts := httptest.NewServer(mux) + defer ts.Close() + // create and issue an OPTIONS request and + // validate response status and headers. + req, err := http.NewRequest("OPTIONS", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("Origin", ts.URL) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("Unexpected response status: %s", resp.Status) + } + if resp.ContentLength != 0 { + t.Fatalf("Unexpected Content-Length. Want 0, have %d", + resp.ContentLength) + } + want := []struct { + Name string + Value string + }{ + {"Access-Control-Allow-Origin", ts.URL}, + {"Access-Control-Allow-Methods", "GET, OPTIONS"}, + {"Access-Control-Allow-Credentials", "true"}, + } + for _, th := range want { + if v := resp.Header.Get(th.Name); v != th.Value { + t.Fatalf("Unexpected value for %q. Want %q, have %q", + th.Name, th.Value, v) + } + } + // issue a GET request and validate response headers and body + resp, err = http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + want[0].Value = "*" // Origin + for _, th := range want { + if v := resp.Header.Get(th.Name); v != th.Value { + t.Fatalf("Unexpected value for %q. Want %q, have %q", + th.Name, th.Value, v) + } + } + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + wb := []byte("hello world") + if !bytes.Equal(b, wb) { + t.Fatalf("Unexpected response body. Want %q, have %q", b, wb) + } + // issue a POST request and validate response status + resp, err = http.PostForm(ts.URL, url.Values{}) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("Unexpected response status: %s", resp.Status) + } +} diff --git a/apiserver/db.go b/apiserver/db.go new file mode 100644 index 0000000..b546f6e --- /dev/null +++ b/apiserver/db.go @@ -0,0 +1,40 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "log" + "net/url" + "time" + + "github.com/fiorix/freegeoip" +) + +// openDB opens and returns the IP database. +func openDB(dsn string, updateIntvl, maxRetryIntvl time.Duration) (db *freegeoip.DB, err error) { + u, err := url.Parse(dsn) + if err != nil || len(u.Scheme) == 0 { + db, err = freegeoip.Open(dsn) + } else { + db, err = freegeoip.OpenURL(dsn, updateIntvl, maxRetryIntvl) + } + return +} + +// watchEvents logs and collect metrics of database events. +func watchEvents(db *freegeoip.DB) { + for { + select { + case file := <-db.NotifyOpen(): + log.Println("database loaded:", file) + dbEventCounter.WithLabelValues("loaded", file).Inc() + case err := <-db.NotifyError(): + log.Println("database error:", err) + dbEventCounter.WithLabelValues("failed", err.Error()).Inc() + case <-db.NotifyClose(): + return + } + } +} diff --git a/apiserver/doc.go b/apiserver/doc.go new file mode 100644 index 0000000..0434eea --- /dev/null +++ b/apiserver/doc.go @@ -0,0 +1,7 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package apiserver provides the freegeoip web server API, used by +// the freegeoip daemon tool. +package apiserver diff --git a/apiserver/http.go b/apiserver/http.go new file mode 100644 index 0000000..75657fd --- /dev/null +++ b/apiserver/http.go @@ -0,0 +1,82 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "net" + "net/http" + "path/filepath" + + "github.com/fiorix/freegeoip" + "github.com/prometheus/client_golang/prometheus" +) + +// HandlerConfig holds configuration for freegeoip http handlers. +type HandlerConfig struct { + Prefix string + Origin string + PublicDir string + DB *freegeoip.DB + RateLimiter RateLimiter +} + +// NewHandler creates a freegeoip http handler. +func NewHandler(conf *HandlerConfig) http.Handler { + ah := &apiHandler{conf} + mux := http.NewServeMux() + ah.RegisterPublicDir(mux) + ah.RegisterEncoder(mux, "csv", &freegeoip.CSVEncoder{UseCRLF: true}) + ah.RegisterEncoder(mux, "xml", &freegeoip.XMLEncoder{Indent: true}) + ah.RegisterEncoder(mux, "json", &freegeoip.JSONEncoder{}) + return mux +} + +type ConnStateFunc func(c net.Conn, s http.ConnState) + +func ConnStateMetrics(g prometheus.Gauge) ConnStateFunc { + return func(c net.Conn, s http.ConnState) { + switch s { + case http.StateNew: + g.Inc() + case http.StateClosed: + g.Dec() + } + } +} + +type apiHandler struct { + conf *HandlerConfig +} + +func (ah *apiHandler) prefix(path string) string { + p := filepath.Clean(filepath.Join("/", ah.conf.Prefix, path)) + if p[len(p)-1] != '/' { + p += "/" + } + return p +} + +func (ah *apiHandler) RegisterPublicDir(mux *http.ServeMux) { + fs := http.FileServer(http.Dir(ah.conf.PublicDir)) + fs = prometheus.InstrumentHandler("frontend", fs) + prefix := ah.prefix("") + mux.Handle(prefix, http.StripPrefix(prefix, fs)) +} + +func (ah *apiHandler) RegisterEncoder(mux *http.ServeMux, path string, enc freegeoip.Encoder) { + f := http.Handler(freegeoip.NewHandler(ah.conf.DB, enc)) + if ah.conf.RateLimiter.Max > 0 { + rl := ah.conf.RateLimiter + rl.Handler = f + f = &rl + } + origin := ah.conf.Origin + if origin == "" { + origin = "*" + } + f = cors(f, origin, "GET", "HEAD") + f = prometheus.InstrumentHandler(path, f) + mux.Handle(ah.prefix(path), f) +} diff --git a/apiserver/http_test.go b/apiserver/http_test.go new file mode 100644 index 0000000..1323924 --- /dev/null +++ b/apiserver/http_test.go @@ -0,0 +1,62 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/fiorix/freegeoip" + "github.com/fiorix/go-redis/redis" +) + +func newTestHandler(db *freegeoip.DB) http.Handler { + return NewHandler(&HandlerConfig{ + Prefix: "/api", + PublicDir: ".", + DB: db, + RateLimiter: RateLimiter{ + Redis: redis.New(), + Max: 5, + Interval: time.Second, + KeyMaker: KeyMakerFunc(func(r *http.Request) string { + return "handler-test" + }), + }, + }) +} + +func TestHandler(t *testing.T) { + db, err := freegeoip.Open("../testdata/db.gz") + if err != nil { + t.Fatal(err) + } + defer db.Close() + s := httptest.NewServer(newTestHandler(db)) + defer s.Close() + // query some known location... + resp, err := http.Get(s.URL + "/api/json/200.1.2.3") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatal(resp.Status) + } + m := struct { + Country string `json:"country_name"` + City string `json:"city"` + }{} + if err = json.NewDecoder(resp.Body).Decode(&m); err != nil { + t.Fatal(err) + } + if m.Country != "Venezuela" && m.City != "Caracas" { + t.Fatal("Query data does not match: want Caracas,Venezuela, have %q,%q", + m.City, m.Country) + } +} diff --git a/apiserver/metrics.go b/apiserver/metrics.go new file mode 100644 index 0000000..4694e82 --- /dev/null +++ b/apiserver/metrics.go @@ -0,0 +1,37 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import "github.com/prometheus/client_golang/prometheus" + +// Experimental metrics for Prometheus, might change in the future. + +var dbEventCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "db_event_counter", + Help: "Counter per DB event", + }, + []string{"event", "data"}, +) + +var httpConnsGauge = prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "current_http_conns", + Help: "Current number of HTTP connections", + }, +) + +var httpsConnsGauge = prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "current_https_conns", + Help: "Current number of HTTPS connections", + }, +) + +func init() { + prometheus.MustRegister(dbEventCounter) + prometheus.MustRegister(httpConnsGauge) + prometheus.MustRegister(httpsConnsGauge) +} diff --git a/apiserver/ratelimit.go b/apiserver/ratelimit.go new file mode 100644 index 0000000..be7d8cc --- /dev/null +++ b/apiserver/ratelimit.go @@ -0,0 +1,96 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "errors" + "net" + "net/http" + "strconv" + "sync" + "time" + + "github.com/fiorix/go-redis/redis" +) + +var ( + errQuotaExceeded = errors.New("Quota exceeded") + errRedisUnavailable = errors.New("Try again later") +) + +// A KeyMaker makes keys from the http.Request object to the RateLimiter. +type KeyMaker interface { + KeyFor(r *http.Request) string +} + +// KeyMakerFunc is an adapter function for KeyMaker. +type KeyMakerFunc func(r *http.Request) string + +// KeyFor implements the KeyMaker interface. +func (f KeyMakerFunc) KeyFor(r *http.Request) string { + return f(r) +} + +// DefaultKeyMaker is a KeyMaker that returns the client IP +// address from http.Request.RemoteAddr. +var DefaultKeyMaker = KeyMakerFunc(func(r *http.Request) string { + addr, _, _ := net.SplitHostPort(r.RemoteAddr) + return addr +}) + +// A RateLimiter is an http.Handler that wraps another handler, +// and calls it up to a certain limit, max per interval. +type RateLimiter struct { + Redis *redis.Client + Max int + Interval time.Duration + KeyMaker KeyMaker + Handler http.Handler + + secInterval int + once sync.Once +} + +// ServeHTTP implements the http.Handler interface. +func (rl *RateLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + rl.once.Do(func() { + rl.secInterval = int(rl.Interval.Seconds()) + if rl.KeyMaker == nil { + rl.KeyMaker = DefaultKeyMaker + } + }) + status, err := rl.do(w, r) + if err != nil { + http.Error(w, err.Error(), status) + return + } +} + +func (rl *RateLimiter) do(w http.ResponseWriter, r *http.Request) (int, error) { + k := rl.KeyMaker.KeyFor(r) + nreq, err := rl.Redis.Incr(k) + if err != nil { + return http.StatusServiceUnavailable, errRedisUnavailable + } + var ttl = 0 + if nreq == 1 { + if _, err = rl.Redis.Expire(k, rl.secInterval); err != nil { + return http.StatusServiceUnavailable, errRedisUnavailable + } + ttl = rl.secInterval + } else if ttl, err = rl.Redis.TTL(k); err != nil { + return http.StatusServiceUnavailable, errRedisUnavailable + } + rem := rl.Max - nreq + w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.Max)) + w.Header().Set("X-RateLimit-Reset", strconv.Itoa(ttl)) + if rem < 0 { + w.Header().Set("X-RateLimit-Remaining", "0") + return http.StatusForbidden, errQuotaExceeded + } + w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(rem)) + rl.Handler.ServeHTTP(w, r) + return http.StatusOK, nil +} diff --git a/apiserver/ratelimit_test.go b/apiserver/ratelimit_test.go new file mode 100644 index 0000000..d1b6bd1 --- /dev/null +++ b/apiserver/ratelimit_test.go @@ -0,0 +1,68 @@ +// Copyright 2009-2015 The freegeoip authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package apiserver + +import ( + "log" + "net/http" + "net/http/httptest" + "strconv" + "sync" + "testing" + "time" + + "github.com/fiorix/go-redis/redis" +) + +func TestRateLimiter(t *testing.T) { + counter := struct { + sync.Mutex + n int + }{} + hf := func(w http.ResponseWriter, r *http.Request) { + counter.Lock() + counter.n++ + counter.Unlock() + } + kmf := func(r *http.Request) string { + return "rate-limiter-test" + } + rl := &RateLimiter{ + Redis: redis.New(), + Max: 2, + Interval: time.Second, + KeyMaker: KeyMakerFunc(kmf), + Handler: http.HandlerFunc(hf), + } + mux := http.NewServeMux() + mux.Handle("/", rl) + s := httptest.NewServer(mux) + defer s.Close() + for i := 0; i < 3; i++ { + resp, err := http.Get(s.URL) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode == http.StatusServiceUnavailable { + t.Skip("redis unavailable, can't test") + } + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusForbidden && i != 2 { + t.Fatal(resp.Status) + } + } + lim, _ := strconv.Atoi(resp.Header.Get("X-RateLimit-Limit")) + rem, _ := strconv.Atoi(resp.Header.Get("X-RateLimit-Remaining")) + res, _ := strconv.Atoi(resp.Header.Get("X-RateLimit-Reset")) + switch { + case i == 0 && lim == 2 && rem == 1 && res > 0: + case (i == 1 || i == 2) && lim == 2 && rem == 0 && res > 0: + default: + log.Fatalf("Unexpected values: limit=%d, remaining=%d, reset=%d", + lim, rem, res) + } + } +} diff --git a/cmd/freegeoip/main.go b/cmd/freegeoip/main.go index 0ea55f4..5a65817 100644 --- a/cmd/freegeoip/main.go +++ b/cmd/freegeoip/main.go @@ -5,270 +5,13 @@ package main import ( - "flag" - "fmt" "log" - "net" - "net/http" - "net/url" - "runtime" - "strconv" - "strings" - "time" - _ "net/http/pprof" - - "github.com/fiorix/freegeoip" - "github.com/fiorix/go-redis/redis" - "github.com/gorilla/context" + "github.com/fiorix/freegeoip/apiserver" ) -// Version tag. -var Version = "3.0.6" - -var maxmindFile = "http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz" - func main() { - addr := flag.String("addr", ":8080", "Address in form of ip:port to listen on") - certFile := flag.String("cert", "", "X.509 certificate file") - keyFile := flag.String("key", "", "X.509 key file") - public := flag.String("public", "", "Public directory to serve at the / endpoint") - ipdb := flag.String("db", maxmindFile, "IP database file or URL") - updateIntvl := flag.Duration("update", 24*time.Hour, "Database update check interval") - retryIntvl := flag.Duration("retry", time.Hour, "Max time to wait before retrying update") - useXFF := flag.Bool("use-x-forwarded-for", false, "Use the X-Forwarded-For header when available") - silent := flag.Bool("silent", false, "Do not log requests to stderr") - redisAddr := flag.String("redis", "127.0.0.1:6379", "Redis address in form of ip:port[,ip:port] for quota") - redisTimeout := flag.Duration("redis-timeout", 500*time.Millisecond, "Redis read/write timeout") - quotaMax := flag.Int("quota-max", 0, "Max requests per source IP per interval; Set 0 to turn off") - quotaIntvl := flag.Duration("quota-interval", time.Hour, "Quota expiration interval") - version := flag.Bool("version", false, "Show version and exit") - pprof := flag.String("pprof", "", "Address in form of ip:port to listen on for pprof") - flag.Parse() - - if *version { - fmt.Printf("freegeoip v%s\n", Version) - return - } - - addrs := strings.Split(*redisAddr, ",") - rc, err := redis.Dial(addrs...) - if err != nil { + if err := apiserver.Run(); err != nil { log.Fatal(err) } - rc.Timeout = *redisTimeout - - db, err := openDB(*ipdb, *updateIntvl, *retryIntvl) - if err != nil { - log.Fatal(err) - } - - runtime.GOMAXPROCS(runtime.NumCPU()) - - encoders := map[string]http.Handler{ - "/csv/": freegeoip.NewHandler(db, &freegeoip.CSVEncoder{UseCRLF: true}), - "/xml/": freegeoip.NewHandler(db, &freegeoip.XMLEncoder{Indent: true}), - "/json/": freegeoip.NewHandler(db, &freegeoip.JSONEncoder{}), - } - - if *quotaMax > 0 { - seconds := int((*quotaIntvl).Seconds()) - for path, f := range encoders { - encoders[path] = userQuota(rc, *quotaMax, seconds, f, *silent) - } - } - - mux := http.NewServeMux() - for path, handler := range encoders { - mux.Handle(path, handler) - } - - if len(*public) > 0 { - mux.Handle("/", http.FileServer(http.Dir(*public))) - } - - handler := CORS(mux, "GET", "HEAD") - - if !*silent { - log.Println("freegeoip server starting on", *addr) - go logEvents(db) - handler = logHandler(handler) - } - - if *useXFF { - handler = freegeoip.ProxyHandler(handler) - } - - if len(*pprof) > 0 { - go func() { - log.Fatal(http.ListenAndServe(*pprof, nil)) - }() - } - - if len(*certFile) > 0 && len(*keyFile) > 0 { - err = http.ListenAndServeTLS(*addr, *certFile, *keyFile, handler) - } else { - err = http.ListenAndServe(*addr, handler) - } - if err != nil { - log.Fatal(err) - } -} - -// openDB opens and returns the IP database. -func openDB(dsn string, updateIntvl, maxRetryIntvl time.Duration) (db *freegeoip.DB, err error) { - u, err := url.Parse(dsn) - if err != nil || len(u.Scheme) == 0 { - db, err = freegeoip.Open(dsn) - } else { - db, err = freegeoip.OpenURL(dsn, updateIntvl, maxRetryIntvl) - } - return -} - -// CORS is an http handler that checks for allowed request methods (verbs) -// and adds CORS headers to all http responses. -// -// See http://en.wikipedia.org/wiki/Cross-origin_resource_sharing for details. -func CORS(f http.Handler, allow ...string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", - strings.Join(allow, ", ")+", OPTIONS") - if r.Method == "OPTIONS" { - w.WriteHeader(200) - return - } - for _, method := range allow { - if r.Method == method { - f.ServeHTTP(w, r) - return - } - } - w.Header().Set("Allow", strings.Join(allow, ", ")+", OPTIONS") - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), - http.StatusMethodNotAllowed) - }) -} - -// userQuota is a handler that provides a rate limiter to the freegeoip API. -// It allows qmax requests per qintvl, in seconds. -// -// If redis is not available it responds with service unavailable. -func userQuota(rc *redis.Client, qmax int, qintvl int, f http.Handler, silent bool) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var ip string - if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 { - ip = r.RemoteAddr[:idx] - } else { - ip = r.RemoteAddr - } - sreq, err := rc.Get(ip) - if err != nil { - if !silent { - context.Set(r, "log", err.Error()) - } - http.Error(w, "Try again later", - http.StatusServiceUnavailable) - return - } - if len(sreq) == 0 { - err = rc.SetEx(ip, qintvl, "1") - if err != nil { - if !silent { - context.Set(r, "log", err.Error()) - } - http.Error(w, "Try again later", - http.StatusServiceUnavailable) - return - } - f.ServeHTTP(w, r) - return - } - nreq, _ := strconv.Atoi(sreq) - if nreq >= qmax { - http.Error(w, "Quota exceeded", http.StatusForbidden) - return - } - _, err = rc.Incr(ip) - if err != nil && !silent { - context.Set(r, "log", err.Error()) - } - f.ServeHTTP(w, r) - }) -} - -// logEvents logs database events. -func logEvents(db *freegeoip.DB) { - for { - select { - case file := <-db.NotifyOpen(): - log.Println("database loaded:", file) - case err := <-db.NotifyError(): - log.Println("database error:", err) - case <-db.NotifyClose(): - return - } - } -} - -// logHandler logs http requests. -func logHandler(f http.Handler) http.Handler { - empty := "" - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := responseWriter{w, http.StatusOK, 0} - start := time.Now() - f.ServeHTTP(&resp, r) - elapsed := time.Since(start) - extra := context.Get(r, "log") - if extra != nil { - defer context.Clear(r) - } else { - extra = empty - } - log.Printf("%q %d %q %q %s %q %db in %s %q", - r.Proto, - resp.status, - r.Method, - r.URL.Path, - remoteIP(r), - r.Header.Get("User-Agent"), - resp.bytes, - elapsed, - extra, - ) - }) -} - -// remoteIP returns the client's address without the port number. -func remoteIP(r *http.Request) string { - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return host -} - -// responseWriter is an http.ResponseWriter that records the returned -// status and bytes written to the client. -type responseWriter struct { - http.ResponseWriter - status int - bytes int -} - -// Write implements the http.ResponseWriter interface. -func (f *responseWriter) Write(b []byte) (int, error) { - n, err := f.ResponseWriter.Write(b) - if err != nil { - return 0, err - } - f.bytes += n - return n, nil -} - -// WriteHeader implements the http.ResponseWriter interface. -func (f *responseWriter) WriteHeader(code int) { - f.status = code - f.ResponseWriter.WriteHeader(code) } diff --git a/db.go b/db.go index 20529dd..39a5ceb 100644 --- a/db.go +++ b/db.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "net/http" "os" @@ -26,7 +27,7 @@ var ( // ErrUnavailable may be returned by DB.Lookup when the database // points to a URL and is not yet available because it's being // downloaded in background. - ErrUnavailable = errors.New("No database available") + ErrUnavailable = errors.New("no database available") // Local cached copy of a database downloaded from a URL. defaultDB = filepath.Join(os.TempDir(), "freegeoip", "db.gz") @@ -72,7 +73,8 @@ func Open(dsn string) (db *DB, err error) { } // OpenURL creates and initializes a DB from a remote file. -// It automatically downloads and updates the file in background. +// It automatically downloads and updates the file in background, and +// keeps a local copy on $TMPDIR. func OpenURL(url string, updateInterval, maxRetryInterval time.Duration) (db *DB, err error) { db = &DB{ file: defaultDB, @@ -171,29 +173,21 @@ func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time) { } func (db *DB) autoUpdate(url string) { - var sleep time.Duration - var retrying bool + backoff := time.Second for { err := db.runUpdate(url) if err != nil { - db.sendError(fmt.Errorf("Database update failed: %s", err)) - if !retrying { - retrying = true - sleep = 5 * time.Second - } else { - sleep *= 2 - if sleep > db.maxRetryInterval { - sleep = db.maxRetryInterval - } - } + bs := backoff.Seconds() + ms := db.maxRetryInterval.Seconds() + backoff = time.Duration(math.Min(bs*math.E, ms)) * time.Second + db.sendError(fmt.Errorf("download failed (will retry in %s): %s", backoff, err)) } else { - retrying = false - sleep = db.updateInterval + backoff = db.updateInterval } select { case <-db.notifyQuit: return - case <-time.After(sleep): + case <-time.After(backoff): // Sleep till time for the next update attempt. } } diff --git a/encoder.go b/encoder.go index c787cb3..4c5dc39 100644 --- a/encoder.go +++ b/encoder.go @@ -205,7 +205,6 @@ type responseRecord struct { // // See the maxminddb documentation for supported languages. func newResponse(query *maxmindQuery, ip net.IP, lang []string) *responseRecord { - record := &responseRecord{ IP: ip.String(), CountryCode: query.Country.ISOCode,