Skip to content

Commit

Permalink
Enable Azure Workload Identity to authorize against RabbitMQ manageme… (
Browse files Browse the repository at this point in the history
#4657)

Signed-off-by: Jakub Adamus <jakub.adamus@vivantis.cz>
Signed-off-by: KratkyZobak <kratky@zobak.cz>
Co-authored-by: Jakub Adamus <jakub.adamus@vivantis.cz>
Co-authored-by: Jakub Adamus <krarky@zobak.cz>
  • Loading branch information
3 people authored Jun 21, 2023
1 parent 56517ff commit 39d6094
Show file tree
Hide file tree
Showing 12 changed files with 494 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio
- **Pulsar Scaler**: Improve error messages for unsuccessful connections ([#4563](https://github.com/kedacore/keda/issues/4563))
- **Security**: Enable secret scanning in GitHub repo
- **RabbitMQ Scaler**: Add support for `unsafeSsl` in trigger metadata ([#4448](https://github.com/kedacore/keda/issues/4448))
- **RabbitMQ Scaler**: Add support for `workloadIdentityResource` and utilize AzureAD Workload Identity for HTTP authorization ([#4716](https://github.com/kedacore/keda/issues/4716))
- **PostgreSQL Scaler**: Replace `lib/pq` with `pgx` ([#4704](https://github.com/kedacore/keda/issues/4704))
- **Prometheus Metrics**: Add new metric with KEDA build info ([#4647](https://github.com/kedacore/keda/issues/4647))
- **Prometheus Scaler**: Add support for Google Managed Prometheus ([#4675](https://github.com/kedacore/keda/pull/4675))
Expand Down
54 changes: 46 additions & 8 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
"github.com/kedacore/keda/v2/pkg/scalers/azure"
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

Expand Down Expand Up @@ -59,6 +61,7 @@ type rabbitMQScaler struct {
connection *amqp.Connection
channel *amqp.Channel
httpClient *http.Client
azureOAuth *azure.ADWorkloadIdentityTokenProvider
logger logr.Logger
}

Expand All @@ -85,6 +88,10 @@ type rabbitMQMetadata struct {
keyPassword string
enableTLS bool
unsafeSsl bool

// token provider for azure AD
workloadIdentityClientID string
workloadIdentityResource string
}

type queueInfo struct {
Expand Down Expand Up @@ -233,6 +240,13 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {

meta.keyPassword = config.AuthParams["keyPassword"]

if config.PodIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload {
if config.AuthParams["workloadIdentityResource"] != "" {
meta.workloadIdentityClientID = config.PodIdentity.IdentityID
meta.workloadIdentityResource = config.AuthParams["workloadIdentityResource"]
}
}

certGiven := meta.cert != ""
keyGiven := meta.key != ""
if certGiven != keyGiven {
Expand Down Expand Up @@ -264,6 +278,10 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
}
}

if meta.protocol == amqpProtocol && config.AuthParams["workloadIdentityResource"] != "" {
return nil, fmt.Errorf("workload identity is not supported for amqp protocol currently")
}

// Resolve queueName
if val, ok := config.TriggerMetadata["queueName"]; ok {
meta.queueName = val
Expand Down Expand Up @@ -464,9 +482,9 @@ func (s *rabbitMQScaler) Close(context.Context) error {
return nil
}

func (s *rabbitMQScaler) getQueueStatus() (int64, float64, error) {
func (s *rabbitMQScaler) getQueueStatus(ctx context.Context) (int64, float64, error) {
if s.metadata.protocol == httpProtocol {
info, err := s.getQueueInfoViaHTTP()
info, err := s.getQueueInfoViaHTTP(ctx)
if err != nil {
return -1, -1, err
}
Expand All @@ -488,12 +506,32 @@ func (s *rabbitMQScaler) getQueueStatus() (int64, float64, error) {
return int64(items.Messages), 0, nil
}

func getJSON(s *rabbitMQScaler, url string) (queueInfo, error) {
func getJSON(ctx context.Context, s *rabbitMQScaler, url string) (queueInfo, error) {
var result queueInfo
r, err := s.httpClient.Get(url)

request, err := http.NewRequest("GET", url, nil)
if err != nil {
return result, err
}

if s.metadata.workloadIdentityResource != "" {
if s.azureOAuth == nil {
s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.workloadIdentityClientID, s.metadata.workloadIdentityResource)
}

err = s.azureOAuth.Refresh()
if err != nil {
return result, err
}

request.Header.Set("Authorization", "Bearer "+s.azureOAuth.OAuthToken())
}

r, err := s.httpClient.Do(request)
if err != nil {
return result, err
}

defer r.Body.Close()

if r.StatusCode == 200 {
Expand All @@ -518,7 +556,7 @@ func getJSON(s *rabbitMQScaler, url string) (queueInfo, error) {
return result, fmt.Errorf("error requesting rabbitMQ API status: %s, response: %s, from: %s", r.Status, body, url)
}

func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) {
func (s *rabbitMQScaler) getQueueInfoViaHTTP(ctx context.Context) (*queueInfo, error) {
parsedURL, err := url.Parse(s.metadata.host)

if err != nil {
Expand Down Expand Up @@ -547,7 +585,7 @@ func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) {
}

var info queueInfo
info, err = getJSON(s, getQueueInfoManagementURI)
info, err = getJSON(ctx, s, getQueueInfoManagementURI)

if err != nil {
return nil, err
Expand All @@ -572,8 +610,8 @@ func (s *rabbitMQScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpe
}

// GetMetricsAndActivity returns value for a supported metric and an error if there is a problem getting the metric
func (s *rabbitMQScaler) GetMetricsAndActivity(_ context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
messages, publishRate, err := s.getQueueStatus()
func (s *rabbitMQScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
messages, publishRate, err := s.getQueueStatus(ctx)
if err != nil {
return []external_metrics.ExternalMetricValue{}, false, s.anonymizeRabbitMQError(err)
}
Expand Down
38 changes: 26 additions & 12 deletions pkg/scalers/rabbitmq_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"time"

"github.com/stretchr/testify/assert"

"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
)

const (
Expand All @@ -24,10 +26,12 @@ type parseRabbitMQMetadataTestData struct {
}

type parseRabbitMQAuthParamTestData struct {
metadata map[string]string
authParams map[string]string
isError bool
enableTLS bool
metadata map[string]string
podIdentity v1alpha1.AuthPodIdentity
authParams map[string]string
isError bool
enableTLS bool
workloadIdentity bool
}

type rabbitMQMetricIdentifier struct {
Expand Down Expand Up @@ -134,19 +138,23 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
}

var testRabbitMQAuthParamData = []parseRabbitMQAuthParamTestData{
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true, false},
// success, TLS cert/key and assumed public CA
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true, false},
// success, TLS cert/key + key password and assumed public CA
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true, false},
// success, TLS CA only
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true, false},
// failure, TLS missing cert
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true, false},
// failure, TLS missing key
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true, false},
// failure, TLS invalid
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true, false},
// success, WorkloadIdentity
{map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "http"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: "client-id"}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, false, false, true},
// failure, WoekloadIdentity not supported for amqp
{map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "amqp"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: "client-id"}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, true, false, false},
}
var rabbitMQMetricIdentifiers = []rabbitMQMetricIdentifier{
{&testRabbitMQMetadata[1], 0, "s0-rabbitmq-sample"},
Expand Down Expand Up @@ -177,7 +185,7 @@ func TestRabbitMQParseMetadata(t *testing.T) {

func TestRabbitMQParseAuthParamData(t *testing.T) {
for _, testData := range testRabbitMQAuthParamData {
metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams, PodIdentity: testData.podIdentity})
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
Expand All @@ -201,6 +209,12 @@ func TestRabbitMQParseAuthParamData(t *testing.T) {
t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["keyPassword"], metadata.key)
}
}
if metadata != nil && metadata.workloadIdentityClientID != "" && !testData.workloadIdentity {
t.Errorf("Expected workloadIdentity to be disabled but got %v as client ID and %v as resource\n", metadata.workloadIdentityClientID, metadata.workloadIdentityResource)
}
if metadata != nil && metadata.workloadIdentityClientID == "" && testData.workloadIdentity {
t.Error("Expected workloadIdentity to be enabled but was not\n")
}
}
}

Expand Down
63 changes: 53 additions & 10 deletions tests/scalers/rabbitmq/rabbitmq_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ data:
default_vhost = {{.VHostName}}
management.tcp.port = 15672
management.tcp.ip = 0.0.0.0
{{if .EnableOAuth}}
auth_backends.1 = rabbit_auth_backend_internal
auth_backends.2 = rabbit_auth_backend_oauth2
auth_backends.3 = rabbit_auth_backend_amqp
auth_oauth2.resource_server_id = {{.OAuthClientID}}
auth_oauth2.scope_prefix = rabbitmq.
auth_oauth2.additional_scopes_key = {{.OAuthScopesKey}}
auth_oauth2.jwks_url = {{.OAuthJwksURI}}
{{end}}
enabled_plugins: |
[rabbitmq_management].
---
Expand Down Expand Up @@ -158,33 +167,67 @@ spec:
`
)

type RabbitOAuthConfig struct {
Enable bool
ClientID string
ScopesKey string
JwksURI string
}

func WithoutOAuth() RabbitOAuthConfig {
return RabbitOAuthConfig{
Enable: false,
}
}

func WithAzureADOAuth(tenantID string, clientID string) RabbitOAuthConfig {
return RabbitOAuthConfig{
Enable: true,
ClientID: clientID,
ScopesKey: "roles",
JwksURI: fmt.Sprintf("https://login.microsoftonline.com/%s/discovery/keys", tenantID),
}
}

type templateData struct {
Namespace string
Connection string
QueueName string
HostName, VHostName string
Username, Password string
MessageCount int
EnableOAuth bool
OAuthClientID string
OAuthScopesKey string
OAuthJwksURI string
}

func RMQInstall(t *testing.T, kc *kubernetes.Clientset, namespace, user, password, vhost string) {
func RMQInstall(t *testing.T, kc *kubernetes.Clientset, namespace, user, password, vhost string, oauth RabbitOAuthConfig) {
helper.CreateNamespace(t, kc, namespace)
data := templateData{
Namespace: namespace,
VHostName: vhost,
Username: user,
Password: password,
Namespace: namespace,
VHostName: vhost,
Username: user,
Password: password,
EnableOAuth: oauth.Enable,
OAuthClientID: oauth.ClientID,
OAuthScopesKey: oauth.ScopesKey,
OAuthJwksURI: oauth.JwksURI,
}

helper.KubectlApplyWithTemplate(t, data, "rmqDeploymentTemplate", deploymentTemplate)
}

func RMQUninstall(t *testing.T, namespace, user, password, vhost string) {
func RMQUninstall(t *testing.T, namespace, user, password, vhost string, oauth RabbitOAuthConfig) {
data := templateData{
Namespace: namespace,
VHostName: vhost,
Username: user,
Password: password,
Namespace: namespace,
VHostName: vhost,
Username: user,
Password: password,
EnableOAuth: oauth.Enable,
OAuthClientID: oauth.ClientID,
OAuthScopesKey: oauth.ScopesKey,
OAuthJwksURI: oauth.JwksURI,
}

helper.KubectlDeleteWithTemplate(t, data, "rmqDeploymentTemplate", deploymentTemplate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

// Load environment variables from .env file
var _ = godotenv.Load("../../.env")
var _ = godotenv.Load("../../../.env")

const (
testName = "rmq-queue-amqp-test"
Expand Down Expand Up @@ -79,7 +79,7 @@ func TestScaler(t *testing.T) {
kc := GetKubernetesClient(t)
data, templates := getTemplateData()

RMQInstall(t, kc, rmqNamespace, user, password, vhost)
RMQInstall(t, kc, rmqNamespace, user, password, vhost, WithoutOAuth())
CreateKubernetesResources(t, kc, testNamespace, data, templates)

assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
Expand All @@ -92,7 +92,7 @@ func TestScaler(t *testing.T) {
// cleanup
t.Log("--- cleaning up ---")
DeleteKubernetesResources(t, testNamespace, data, templates)
RMQUninstall(t, rmqNamespace, user, password, vhost)
RMQUninstall(t, rmqNamespace, user, password, vhost, WithoutOAuth())
}

func getTemplateData() (templateData, []Template) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

// Load environment variables from .env file
var _ = godotenv.Load("../../.env")
var _ = godotenv.Load("../../../.env")

const (
testName = "rmq-queue-amqp-vhost-test"
Expand Down Expand Up @@ -79,7 +79,7 @@ func TestScaler(t *testing.T) {
kc := GetKubernetesClient(t)
data, templates := getTemplateData()

RMQInstall(t, kc, rmqNamespace, user, password, vhost)
RMQInstall(t, kc, rmqNamespace, user, password, vhost, WithoutOAuth())
CreateKubernetesResources(t, kc, testNamespace, data, templates)

assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
Expand All @@ -92,7 +92,7 @@ func TestScaler(t *testing.T) {
// cleanup
t.Log("--- cleaning up ---")
DeleteKubernetesResources(t, testNamespace, data, templates)
RMQUninstall(t, rmqNamespace, user, password, vhost)
RMQUninstall(t, rmqNamespace, user, password, vhost, WithoutOAuth())
}

func getTemplateData() (templateData, []Template) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

// Load environment variables from .env file
var _ = godotenv.Load("../../.env")
var _ = godotenv.Load("../../../.env")

const (
testName = "rmq-queue-http-test"
Expand Down Expand Up @@ -80,7 +80,7 @@ func TestScaler(t *testing.T) {
kc := GetKubernetesClient(t)
data, templates := getTemplateData()

RMQInstall(t, kc, rmqNamespace, user, password, vhost)
RMQInstall(t, kc, rmqNamespace, user, password, vhost, WithoutOAuth())
CreateKubernetesResources(t, kc, testNamespace, data, templates)

assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
Expand All @@ -91,7 +91,7 @@ func TestScaler(t *testing.T) {
// cleanup
t.Log("--- cleaning up ---")
DeleteKubernetesResources(t, testNamespace, data, templates)
RMQUninstall(t, rmqNamespace, user, password, vhost)
RMQUninstall(t, rmqNamespace, user, password, vhost, WithoutOAuth())
}

func getTemplateData() (templateData, []Template) {
Expand Down
Loading

0 comments on commit 39d6094

Please sign in to comment.