Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 77 additions & 7 deletions cmd/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@ import (
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"

"github.com/Showmax/go-fqdn"
"github.com/alecthomas/kingpin/v2"
"github.com/cenkalti/backoff/v4"
"github.com/fsnotify/fsnotify"
"github.com/prometheus-community/pushprox/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand All @@ -41,13 +44,13 @@ import (
)

var (
myFqdn = kingpin.Flag("fqdn", "FQDN to register with").Default(fqdn.Get()).String()
proxyURL = kingpin.Flag("proxy-url", "Push proxy to talk to.").Required().String()
caCertFile = kingpin.Flag("tls.cacert", "<file> CA certificate to verify peer against").String()
tlsCert = kingpin.Flag("tls.cert", "<cert> Client certificate file").String()
tlsKey = kingpin.Flag("tls.key", "<key> Private key file").String()
metricsAddr = kingpin.Flag("metrics-addr", "Serve Prometheus metrics at this address").Default(":9369").String()

myFqdn = kingpin.Flag("fqdn", "FQDN to register with").Default(fqdn.Get()).String()
proxyURL = kingpin.Flag("proxy-url", "Push proxy to talk to.").Required().String()
caCertFile = kingpin.Flag("tls.cacert", "<file> CA certificate to verify peer against").String()
tlsCert = kingpin.Flag("tls.cert", "<cert> Client certificate file").String()
tlsKey = kingpin.Flag("tls.key", "<key> Private key file").String()
metricsAddr = kingpin.Flag("metrics-addr", "Serve Prometheus metrics at this address").Default(":9369").String()
bearerTokenPath = kingpin.Flag("bearer-token-path", "<path> Path to file containing bearer token to authenticate requests").String()
retryInitialWait = kingpin.Flag("proxy.retry.initial-wait", "Amount of time to wait after proxy failure").Default("1s").Duration()
retryMaxWait = kingpin.Flag("proxy.retry.max-wait", "Maximum amount of time to wait between proxy poll retries").Default("5s").Duration()
)
Expand All @@ -71,6 +74,10 @@ var (
Help: "Number of poll errors",
},
)
// bearerToken holds the current token string used for authentication.
// Access must be synchronized to avoid race conditions between the watcher and HTTP handlers.
bearerToken string
bearerTokenMutex sync.RWMutex
)

func init() {
Expand Down Expand Up @@ -114,6 +121,14 @@ func (c *Coordinator) doScrape(request *http.Request, client *http.Client) {
c.handleErr(request, client, err)
return
}
bearerTokenMutex.RLock()
token := bearerToken
bearerTokenMutex.RUnlock()

if token != "" {
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}

ctx, cancel := context.WithTimeout(request.Context(), timeout)
defer cancel()
request = request.WithContext(ctx)
Expand Down Expand Up @@ -225,6 +240,56 @@ func (c *Coordinator) loop(bo backoff.BackOff, client *http.Client) {
}
}

// I need to automatically reload the bearer token if it changes
// i.e. if it a short lived kubernetes token one
func watchBearerTokenFile(path string, logger *slog.Logger) {
loadBearerToken := func() {
tokenBytes, err := os.ReadFile(path)
if err != nil {
logger.Error("Failed to read bearer token", "path", path, "err", err)
return
}
bearerTokenMutex.Lock()
bearerToken = strings.TrimSpace(string(tokenBytes)) // also trim spaces/newlines
bearerTokenMutex.Unlock()
logger.Info("Bearer token loaded from file", "path", path)
}

loadBearerToken() // initial load

watcher, err := fsnotify.NewWatcher()
if err != nil {
logger.Error("Failed to create fsnotify watcher", "err", err)
os.Exit(1)
}
defer watcher.Close()

tokenDir := filepath.Dir(path)
if err := watcher.Add(tokenDir); err != nil {
logger.Error("Failed to watch token directory", "dir", tokenDir, "err", err)
os.Exit(1)
}

for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Name == path &&
(event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) {
logger.Info("Bearer token file changed, reloading", "event", event)
loadBearerToken()
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
logger.Warn("fsnotify error", "err", err)
}
}
}

func main() {
promslogConfig := promslog.Config{}
flag.AddFlags(kingpin.CommandLine, &promslogConfig)
Expand Down Expand Up @@ -276,6 +341,11 @@ func main() {
}()
}

// Set bearer token based on path
if *bearerTokenPath != "" {
go watchBearerTokenFile(*bearerTokenPath, coordinator.logger)
}

transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Expand Down
217 changes: 208 additions & 9 deletions cmd/client/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,104 @@
package main

import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/prometheus/common/promslog"
)

func prepareTest() (*httptest.Server, Coordinator) {
// This test server acts as the proxyURL
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "GET /index.html HTTP/1.0\n\nOK")
switch r.URL.Path {
case "/poll":
// On /poll, respond with an HTTP request serialized in the body
var buf bytes.Buffer
req, _ := http.NewRequest("GET", fmt.Sprintf("http://%s/", *myFqdn), nil)
req.Header.Set("id", "test-scrape-id")
req.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", "10")
req.Write(&buf)
w.WriteHeader(http.StatusOK)
_, _ = w.Write(buf.Bytes())
case "/push":
// Accept pushed scrape results, just respond OK
io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
}))

c := Coordinator{logger: promslog.NewNopLogger()}
*proxyURL = ts.URL
*proxyURL = ts.URL + "/"
*myFqdn = "test.local" // Set fqdn to test.local for matching hostnames

return ts, c
}

func TestDoScrape(t *testing.T) {
func TestDoScrape_Success(t *testing.T) {
ts, c := prepareTest()
defer ts.Close()

req, err := http.NewRequest("GET", ts.URL, nil)
// Setup a test target server that will be scraped by doScrape
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify Authorization header if set
auth := r.Header.Get("Authorization")
if auth != "" && auth != "Bearer dummy-token" {
t.Errorf("unexpected Authorization header: %s", auth)
}
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "OK")
}))
defer targetServer.Close()

// Override myFqdn to match targetServer hostname
u, err := url.Parse(targetServer.URL)
if err != nil {
t.Fatal(err)
}
*myFqdn = u.Hostname()

// Prepare a scrape request targeting the test target server
req, err := http.NewRequest("GET", targetServer.URL, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("id", "scrape-id-123")
req.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", "10")

// Set bearerToken for authorization testing
bearerTokenMutex.Lock()
bearerToken = "dummy-token"
bearerTokenMutex.Unlock()

c.doScrape(req, targetServer.Client())
}

func TestDoScrape_FailWrongFQDN(t *testing.T) {
ts, c := prepareTest()
defer ts.Close()

req, err := http.NewRequest("GET", "http://wronghost.local", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Add("X-Prometheus-Scrape-Timeout-Seconds", "10.0")
*myFqdn = ts.URL
req.Header.Set("id", "fail-id")
req.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", "10")

// This should cause handleErr due to fqdn mismatch
c.doScrape(req, ts.Client())
}

Expand All @@ -57,10 +126,140 @@ func TestHandleErr(t *testing.T) {
c.handleErr(req, ts.Client(), errors.New("test error"))
}

func TestLoop(t *testing.T) {
func TestDoPush_ErrorOnInvalidProxyURL(t *testing.T) {
c := Coordinator{logger: promslog.NewNopLogger()}
*proxyURL = "http://%41:8080" // invalid URL (percent-encoding issue)

resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("test")),
Header: http.Header{},
}
req, _ := http.NewRequest("GET", "http://example.com", nil)
err := c.doPush(resp, req, http.DefaultClient)
if err == nil {
t.Errorf("expected error on invalid proxy URL, got nil")
}
}

func TestDoPoll(t *testing.T) {
ts, c := prepareTest()
defer ts.Close()
if err := c.doPoll(ts.Client()); err != nil {

err := c.doPoll(ts.Client())
if err != nil {
t.Fatalf("doPoll failed: %v", err)
}
}

func TestLoopWithBackoff(t *testing.T) {
var count int
var mu sync.Mutex
done := make(chan struct{})
var once sync.Once

bo := backoffForTest(3)

go func() {
err := backoff.RetryNotify(func() error {
mu.Lock()
defer mu.Unlock()
count++
if count > 2 {
// safe close
once.Do(func() { close(done) })
return errors.New("forced error to stop retry")
}
return errors.New("temporary error")
}, bo, func(err error, d time.Duration) {
// No-op
})

if err != nil {
// safe even if already closed
once.Do(func() { close(done) })
}
}()

select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("loop test timed out")
}
}

func backoffForTest(maxRetries int) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.InitialInterval = 1 * time.Millisecond
b.MaxInterval = 5 * time.Millisecond
b.MaxElapsedTime = 10 * time.Millisecond
return backoff.WithMaxRetries(b, uint64(maxRetries))
}

func TestWatchBearerTokenFile(t *testing.T) {
// This function is hard to test fully without fsnotify events,
// but we can test the initial loading of the token file.

// Create a temporary file with a token
tmpfile := t.TempDir() + "/tokenfile"
tokenContent := "file-token\n"
if err := os.WriteFile(tmpfile, []byte(tokenContent), 0600); err != nil {
t.Fatal(err)
}

logger := promslog.NewNopLogger()

// Run watchBearerTokenFile in a goroutine; it will load token initially
go func() {
// This will block watching the directory, so we only wait shortly
watchBearerTokenFile(tmpfile, logger)
}()

// Wait briefly for the token to load
time.Sleep(100 * time.Millisecond)

bearerTokenMutex.RLock()
defer bearerTokenMutex.RUnlock()
if bearerToken != strings.TrimSpace(tokenContent) {
t.Errorf("expected bearer token %q, got %q", strings.TrimSpace(tokenContent), bearerToken)
}
}

func TestBearerTokenHeader(t *testing.T) {
token := "dummy-token"
bearerTokenMutex.Lock()
bearerToken = token
bearerTokenMutex.Unlock()

var receivedToken string

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedToken = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

// Ensure myFqdn matches the test server's hostname
u, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
*myFqdn = u.Hostname()

req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}

// Set required headers for doScrape to accept this request
req.Header.Set("id", "token-test-id")
req.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", "10")

c := Coordinator{logger: promslog.NewNopLogger()}
c.doScrape(req, ts.Client())

expected := "Bearer dummy-token"
if receivedToken != expected {
t.Fatalf("expected %q, got %q", expected, receivedToken)
}
}
Loading