From 11d5e76e1ba5147efbca61e8bf418b6279d3be2f Mon Sep 17 00:00:00 2001 From: Mark Rushakoff Date: Wed, 11 Jul 2018 15:30:07 -0700 Subject: [PATCH] chore(prometheus): don't use global prometheus registry And add tests. --- prometheus/auth_service.go | 16 ++-- prometheus/auth_service_test.go | 144 ++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 7 deletions(-) create mode 100644 prometheus/auth_service_test.go diff --git a/prometheus/auth_service.go b/prometheus/auth_service.go index 809279c8a36..fee0bf15c7b 100644 --- a/prometheus/auth_service.go +++ b/prometheus/auth_service.go @@ -38,11 +38,6 @@ func NewAuthorizationService() *AuthorizationService { }, []string{"method", "error"}), } - prometheus.MustRegister( - s.requestCount, - s.requestDuration, - ) - return s } @@ -50,7 +45,7 @@ func NewAuthorizationService() *AuthorizationService { func (s *AuthorizationService) FindAuthorizationByID(ctx context.Context, id platform.ID) (a *platform.Authorization, err error) { defer func(start time.Time) { labels := prometheus.Labels{ - "method": "FindAuthorizationsByID", + "method": "FindAuthorizationByID", "error": fmt.Sprint(err != nil), } s.requestCount.With(labels).Add(1) @@ -63,7 +58,7 @@ func (s *AuthorizationService) FindAuthorizationByID(ctx context.Context, id pla func (s *AuthorizationService) FindAuthorizationByToken(ctx context.Context, t string) (a *platform.Authorization, err error) { defer func(start time.Time) { labels := prometheus.Labels{ - "method": "FindAuthorizationsByToken", + "method": "FindAuthorizationByToken", "error": fmt.Sprint(err != nil), } s.requestCount.With(labels).Add(1) @@ -113,3 +108,10 @@ func (s *AuthorizationService) DeleteAuthorization(ctx context.Context, id platf return s.AuthorizationService.DeleteAuthorization(ctx, id) } + +func (s *AuthorizationService) PrometheusCollectors() []prometheus.Collector { + return []prometheus.Collector{ + s.requestCount, + s.requestDuration, + } +} diff --git a/prometheus/auth_service_test.go b/prometheus/auth_service_test.go new file mode 100644 index 00000000000..38a6569e280 --- /dev/null +++ b/prometheus/auth_service_test.go @@ -0,0 +1,144 @@ +package prometheus_test + +import ( + "context" + "errors" + "testing" + + "github.com/influxdata/platform" + "github.com/influxdata/platform/kit/prom" + "github.com/influxdata/platform/kit/prom/promtest" + "github.com/influxdata/platform/prometheus" +) + +// authzSvc is a test helper that returns its Err from every method on the AuthorizationService interface. +type authzSvc struct { + Err error +} + +var _ platform.AuthorizationService = (*authzSvc)(nil) + +func (a *authzSvc) FindAuthorizationByID(context.Context, platform.ID) (*platform.Authorization, error) { + return nil, a.Err +} + +func (a *authzSvc) FindAuthorizationByToken(context.Context, string) (*platform.Authorization, error) { + return nil, a.Err +} + +func (a *authzSvc) FindAuthorizations(context.Context, platform.AuthorizationFilter, ...platform.FindOptions) ([]*platform.Authorization, int, error) { + return nil, 0, a.Err +} + +func (a *authzSvc) CreateAuthorization(context.Context, *platform.Authorization) error { + return a.Err +} + +func (a *authzSvc) DeleteAuthorization(context.Context, platform.ID) error { + return a.Err +} + +func TestAuthorizationService_Metrics(t *testing.T) { + a := new(authzSvc) + + svc := prometheus.NewAuthorizationService() + svc.AuthorizationService = a + reg := prom.NewRegistry() + reg.MustRegister(svc.PrometheusCollectors()...) + + ctx := context.Background() + id := platform.ID{1} + + if _, err := svc.FindAuthorizationByID(ctx, id); err != nil { + t.Fatal(err) + } + mfs := promtest.MustGather(t, reg) + m := promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizationByID", "error": "false"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + if _, err := svc.FindAuthorizationByToken(ctx, ""); err != nil { + t.Fatal(err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizationByToken", "error": "false"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + if _, _, err := svc.FindAuthorizations(ctx, platform.AuthorizationFilter{}); err != nil { + t.Fatal(err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizations", "error": "false"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + if err := svc.CreateAuthorization(ctx, nil); err != nil { + t.Fatal(err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "CreateAuthorization", "error": "false"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + if err := svc.DeleteAuthorization(ctx, nil); err != nil { + t.Fatal(err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "DeleteAuthorization", "error": "false"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + forced := errors.New("forced error") + a.Err = forced + + if _, err := svc.FindAuthorizationByID(ctx, id); err != forced { + t.Fatalf("expected forced error, got %v", err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizationByID", "error": "true"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + if _, err := svc.FindAuthorizationByToken(ctx, ""); err != forced { + t.Fatalf("expected forced error, got %v", err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizationByToken", "error": "true"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + if _, _, err := svc.FindAuthorizations(ctx, platform.AuthorizationFilter{}); err != forced { + t.Fatalf("expected forced error, got %v", err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "FindAuthorizations", "error": "true"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + if err := svc.CreateAuthorization(ctx, nil); err != forced { + t.Fatalf("expected forced error, got %v", err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "CreateAuthorization", "error": "true"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } + + if err := svc.DeleteAuthorization(ctx, nil); err != forced { + t.Fatalf("expected forced error, got %v", err) + } + mfs = promtest.MustGather(t, reg) + m = promtest.MustFindMetric(t, mfs, "auth_prometheus_requests_total", map[string]string{"method": "DeleteAuthorization", "error": "true"}) + if got := m.GetCounter().GetValue(); got != 1 { + t.Fatalf("exp 1 request, got %v", got) + } +}