diff --git a/CHANGELOG.md b/CHANGELOG.md index 1decf86b920..222af2f4860 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio - **Prometheus Scaler:** Introduce skipping of certificate check for unsigned certs ([#2310](https://github.com/kedacore/keda/issues/2310)) - **Event Hubs Scaler:** Support Azure Active Direcotry Pod & Workload Identity for Storage Blobs ([#3569](https://github.com/kedacore/keda/issues/3569)) - **Metrics API Scaler:** Add unsafeSsl paramater to skip certificate validation when connecting over HTTPS ([#3728](https://github.com/kedacore/keda/discussions/3728)) +- **NATS Scalers:** Support HTTPS protocol in NATS Scalers ([#3805](https://github.com/kedacore/keda/issues/3805)) ### Fixes diff --git a/pkg/scalers/nats_jetstream_scaler.go b/pkg/scalers/nats_jetstream_scaler.go index 4d944e00604..b038551bd95 100644 --- a/pkg/scalers/nats_jetstream_scaler.go +++ b/pkg/scalers/nats_jetstream_scaler.go @@ -21,6 +21,8 @@ import ( const ( jetStreamMetricType = "External" defaultJetStreamLagThreshold = 10 + natsHTTPProtocol = "http" + natsHTTPSProtocol = "https" ) type natsJetStreamScaler struct { @@ -108,11 +110,6 @@ func NewNATSJetStreamScaler(config *ScalerConfig) (Scaler, error) { func parseNATSJetStreamMetadata(config *ScalerConfig) (natsJetStreamMetadata, error) { meta := natsJetStreamMetadata{} - var err error - meta.monitoringEndpoint, err = GetFromAuthOrMeta(config, "natsServerMonitoringEndpoint") - if err != nil { - return meta, err - } if config.TriggerMetadata["account"] == "" { return meta, errors.New("no account name given") @@ -149,17 +146,34 @@ func parseNATSJetStreamMetadata(config *ScalerConfig) (natsJetStreamMetadata, er } meta.scalerIndex = config.ScalerIndex + + natsServerEndpoint, err := GetFromAuthOrMeta(config, "natsServerMonitoringEndpoint") + if err != nil { + return meta, err + } + useHTTPS := false + if val, ok := config.TriggerMetadata["useHttps"]; ok { + useHTTPS, err = strconv.ParseBool(val) + if err != nil { + return meta, fmt.Errorf("useHTTPS parsing error %s", err.Error()) + } + } + meta.monitoringEndpoint = getNATSJetStreamEndpoint(useHTTPS, natsServerEndpoint, meta.account) + return meta, nil } -func (s *natsJetStreamScaler) getNATSJetStreamEndpoint() string { - return fmt.Sprintf("http://%s/jsz?acc=%s&consumers=true&config=true", s.metadata.monitoringEndpoint, s.metadata.account) +func getNATSJetStreamEndpoint(useHTTPS bool, natsServerEndpoint string, account string) string { + protocol := natsHTTPProtocol + if useHTTPS { + protocol = natsHTTPSProtocol + } + + return fmt.Sprintf("%s://%s/jsz?acc=%s&consumers=true&config=true", protocol, natsServerEndpoint, account) } func (s *natsJetStreamScaler) IsActive(ctx context.Context) (bool, error) { - monitoringEndpoint := s.getNATSJetStreamEndpoint() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, monitoringEndpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.metadata.monitoringEndpoint, nil) if err != nil { return false, err } @@ -216,7 +230,7 @@ func (s *natsJetStreamScaler) GetMetricSpecForScaling(context.Context) []v2.Metr } func (s *natsJetStreamScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.getNATSJetStreamEndpoint(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.metadata.monitoringEndpoint, nil) if err != nil { return nil, err } diff --git a/pkg/scalers/nats_jetstream_scaler_test.go b/pkg/scalers/nats_jetstream_scaler_test.go index 542fa6b4b1b..3d75ed548cc 100644 --- a/pkg/scalers/nats_jetstream_scaler_test.go +++ b/pkg/scalers/nats_jetstream_scaler_test.go @@ -3,7 +3,10 @@ package scalers import ( "context" "net/http" + "strings" "testing" + + "github.com/stretchr/testify/assert" ) type parseNATSJetStreamMetadataTestData struct { @@ -30,13 +33,15 @@ var testNATSJetStreamMetadata = []parseNATSJetStreamMetadataTestData{ // Missing nats server monitoring endpoint, should fail {map[string]string{"account": "$G", "stream": "mystream"}, map[string]string{}, true}, // All good. - {map[string]string{"natsServerMonitoringEndpoint": "nats.nats:8222", "account": "$G", "stream": "mystream", "consumer": "pull_consumer"}, map[string]string{}, false}, + {map[string]string{"natsServerMonitoringEndpoint": "nats.nats:8222", "account": "$G", "stream": "mystream", "consumer": "pull_consumer", "useHttps": "true"}, map[string]string{}, false}, // All good + activationLagThreshold {map[string]string{"natsServerMonitoringEndpoint": "nats.nats:8222", "account": "$G", "stream": "mystream", "consumer": "pull_consumer", "activationLagThreshold": "10"}, map[string]string{}, false}, // natsServerMonitoringEndpoint is defined in authParams {map[string]string{"account": "$G", "stream": "mystream", "consumer": "pull_consumer"}, map[string]string{"natsServerMonitoringEndpoint": "nats.nats:8222"}, false}, // Missing nats server monitoring endpoint , should fail {map[string]string{"account": "$G", "stream": "mystream", "consumer": "pull_consumer"}, map[string]string{"natsServerMonitoringEndpoint": ""}, true}, + // Misconfigured https, should fail + {map[string]string{"natsServerMonitoringEndpoint": "nats.nats:8222", "account": "$G", "stream": "mystream", "consumer": "pull_consumer", "useHttps": "error"}, map[string]string{}, true}, } var natsJetStreamMetricIdentifiers = []natsJetStreamMetricIdentifier{ @@ -50,7 +55,7 @@ func TestNATSJetStreamParseMetadata(t *testing.T) { if err != nil && !testData.isError { t.Error("Expected success but got error", err) } else if testData.isError && err == nil { - t.Error("Expected error but got success") + t.Error("Expected error but got success" + testData.authParams["natsServerMonitoringEndpoint"] + "foo") } } } @@ -75,3 +80,15 @@ func TestNATSJetStreamGetMetricSpecForScaling(t *testing.T) { } } } + +func TestGetNATSJetStreamEndpointHTTPS(t *testing.T) { + endpoint := getNATSJetStreamEndpoint(true, "nats.nats:8222", "$G") + + assert.True(t, strings.HasPrefix(endpoint, "https:")) +} + +func TestGetNATSJetStreamEndpointHTTP(t *testing.T) { + endpoint := getNATSJetStreamEndpoint(false, "nats.nats:8222", "$G") + + assert.True(t, strings.HasPrefix(endpoint, "http:")) +} diff --git a/pkg/scalers/stan_scaler.go b/pkg/scalers/stan_scaler.go index c0324357488..fc8cf697bb4 100644 --- a/pkg/scalers/stan_scaler.go +++ b/pkg/scalers/stan_scaler.go @@ -45,18 +45,21 @@ type stanScaler struct { } type stanMetadata struct { - natsServerMonitoringEndpoint string - queueGroup string - durableName string - subject string - lagThreshold int64 - activationLagThreshold int64 - scalerIndex int + monitoringEndpoint string + stanChannelsEndpoint string + queueGroup string + durableName string + subject string + lagThreshold int64 + activationLagThreshold int64 + scalerIndex int } const ( - stanMetricType = "External" - defaultStanLagThreshold = 10 + stanMetricType = "External" + defaultStanLagThreshold = 10 + natsStreamingHTTPProtocol = "http" + natsStreamingHTTPSProtocol = "https" ) // NewStanScaler creates a new stanScaler @@ -82,11 +85,6 @@ func NewStanScaler(config *ScalerConfig) (Scaler, error) { func parseStanMetadata(config *ScalerConfig) (stanMetadata, error) { meta := stanMetadata{} - var err error - meta.natsServerMonitoringEndpoint, err = GetFromAuthOrMeta(config, "natsServerMonitoringEndpoint") - if err != nil { - return meta, err - } if config.TriggerMetadata["queueGroup"] == "" { return meta, errors.New("no queue group given") @@ -123,25 +121,39 @@ func parseStanMetadata(config *ScalerConfig) (stanMetadata, error) { } meta.scalerIndex = config.ScalerIndex + + var err error + useHTTPS := false + if val, ok := config.TriggerMetadata["useHttps"]; ok { + useHTTPS, err = strconv.ParseBool(val) + if err != nil { + return meta, fmt.Errorf("useHTTPS parsing error %s", err.Error()) + } + } + natsServerEndpoint, err := GetFromAuthOrMeta(config, "natsServerMonitoringEndpoint") + if err != nil { + return meta, err + } + meta.stanChannelsEndpoint = getSTANChannelsEndpoint(useHTTPS, natsServerEndpoint) + meta.monitoringEndpoint = getMonitoringEndpoint(meta.stanChannelsEndpoint, meta.subject) + return meta, nil } // IsActive determines if we need to scale from zero func (s *stanScaler) IsActive(ctx context.Context) (bool, error) { - monitoringEndpoint := s.getMonitoringEndpoint() - - req, err := http.NewRequestWithContext(ctx, "GET", monitoringEndpoint, nil) + req, err := http.NewRequestWithContext(ctx, "GET", s.metadata.monitoringEndpoint, nil) if err != nil { return false, err } resp, err := s.httpClient.Do(req) if err != nil { - s.logger.Error(err, "Unable to access the nats streaming broker monitoring endpoint", "natsServerMonitoringEndpoint", s.metadata.natsServerMonitoringEndpoint) + s.logger.Error(err, "Unable to access the nats streaming broker monitoring endpoint", "natsServerMonitoringEndpoint", s.metadata.monitoringEndpoint) return false, err } if resp.StatusCode == 404 { - req, err := http.NewRequestWithContext(ctx, "GET", s.getSTANChannelsEndpoint(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", s.metadata.stanChannelsEndpoint, nil) if err != nil { return false, err } @@ -151,9 +163,9 @@ func (s *stanScaler) IsActive(ctx context.Context) (bool, error) { } defer baseResp.Body.Close() if baseResp.StatusCode == 404 { - s.logger.Info("Streaming broker endpoint returned 404. Please ensure it has been created", "url", monitoringEndpoint, "channelName", s.metadata.subject) + s.logger.Info("Streaming broker endpoint returned 404. Please ensure it has been created", "url", s.metadata.monitoringEndpoint, "channelName", s.metadata.subject) } else { - s.logger.Info("Unable to connect to STAN. Please ensure you have configured the ScaledObject with the correct endpoint.", "baseResp.StatusCode", baseResp.StatusCode, "natsServerMonitoringEndpoint", s.metadata.natsServerMonitoringEndpoint) + s.logger.Info("Unable to connect to STAN. Please ensure you have configured the ScaledObject with the correct endpoint.", "baseResp.StatusCode", baseResp.StatusCode, "monitoringEndpoint", s.metadata.monitoringEndpoint) } return false, err @@ -167,12 +179,16 @@ func (s *stanScaler) IsActive(ctx context.Context) (bool, error) { return s.hasPendingMessage() || s.getMaxMsgLag() > s.metadata.activationLagThreshold, nil } -func (s *stanScaler) getSTANChannelsEndpoint() string { - return "http://" + s.metadata.natsServerMonitoringEndpoint + "/streaming/channelsz" +func getSTANChannelsEndpoint(useHTTPS bool, natsServerEndpoint string) string { + protocol := natsStreamingHTTPProtocol + if useHTTPS { + protocol = natsStreamingHTTPSProtocol + } + return fmt.Sprintf("%s://%s/streaming/channelsz", protocol, natsServerEndpoint) } -func (s *stanScaler) getMonitoringEndpoint() string { - return s.getSTANChannelsEndpoint() + "?channel=" + s.metadata.subject + "&subs=1" +func getMonitoringEndpoint(stanChannelsEndpoint string, subject string) string { + return fmt.Sprintf("%s?channel=%s&subs=1", stanChannelsEndpoint, subject) } func (s *stanScaler) getMaxMsgLag() int64 { @@ -227,14 +243,14 @@ func (s *stanScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *stanScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - req, err := http.NewRequestWithContext(ctx, "GET", s.getMonitoringEndpoint(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", s.metadata.monitoringEndpoint, nil) if err != nil { return nil, err } resp, err := s.httpClient.Do(req) if err != nil { - s.logger.Error(err, "Unable to access the nats streaming broker monitoring endpoint", "natsServerMonitoringEndpoint", s.metadata.natsServerMonitoringEndpoint) + s.logger.Error(err, "Unable to access the nats streaming broker monitoring endpoint", "monitoringEndpoint", s.metadata.monitoringEndpoint) return []external_metrics.ExternalMetricValue{}, err } diff --git a/pkg/scalers/stan_scaler_test.go b/pkg/scalers/stan_scaler_test.go index e52a1ceb22d..be665cd3a8e 100644 --- a/pkg/scalers/stan_scaler_test.go +++ b/pkg/scalers/stan_scaler_test.go @@ -3,7 +3,10 @@ package scalers import ( "context" "net/http" + "strings" "testing" + + "github.com/stretchr/testify/assert" ) type parseStanMetadataTestData struct { @@ -28,13 +31,15 @@ var testStanMetadata = []parseStanMetadataTestData{ // Missing nats server monitoring endpoint, should fail {map[string]string{"queueGroup": "grp1", "subject": "mySubject"}, map[string]string{}, true}, // All good. - {map[string]string{"natsServerMonitoringEndpoint": "stan-nats-ss", "queueGroup": "grp1", "durableName": "ImDurable", "subject": "mySubject"}, map[string]string{}, false}, + {map[string]string{"natsServerMonitoringEndpoint": "stan-nats-ss", "queueGroup": "grp1", "durableName": "ImDurable", "subject": "mySubject", "useHttps": "true"}, map[string]string{}, false}, // All good + activationLagThreshold {map[string]string{"natsServerMonitoringEndpoint": "stan-nats-ss", "queueGroup": "grp1", "durableName": "ImDurable", "subject": "mySubject", "activationLagThreshold": "10"}, map[string]string{}, false}, // natsServerMonitoringEndpoint is defined in authParams {map[string]string{"queueGroup": "grp1", "durableName": "ImDurable", "subject": "mySubject"}, map[string]string{"natsServerMonitoringEndpoint": "stan-nats-ss"}, false}, // Missing nats server monitoring endpoint , should fail {map[string]string{"queueGroup": "grp1", "durableName": "ImDurable", "subject": "mySubject"}, map[string]string{"natsServerMonitoringEndpoint": ""}, true}, + // Misconfigured https, should fail + {map[string]string{"natsServerMonitoringEndpoint": "stan-nats-ss", "queueGroup": "grp1", "durableName": "ImDurable", "subject": "mySubject", "useHttps": "error"}, map[string]string{}, true}, } var stanMetricIdentifiers = []stanMetricIdentifier{ @@ -73,3 +78,15 @@ func TestStanGetMetricSpecForScaling(t *testing.T) { } } } + +func TestGetSTANChannelsEndpointHTTPS(t *testing.T) { + endpoint := getSTANChannelsEndpoint(true, "stan-nats-ss") + + assert.True(t, strings.HasPrefix(endpoint, "https:")) +} + +func TestGetSTANChannelsEndpointHTTP(t *testing.T) { + endpoint := getSTANChannelsEndpoint(false, "stan-nats-ss") + + assert.True(t, strings.HasPrefix(endpoint, "http:")) +}