Skip to content

Commit

Permalink
feat: Add OAuth extensions for kafka scaler (kedacore#4486)
Browse files Browse the repository at this point in the history
Signed-off-by: qvalentin <valentin.theodor@web.de>
  • Loading branch information
qvalentin authored May 18, 2023
1 parent 8c9dd7f commit 2c77385
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio

- **Azure Data Exporer Scaler**: Use azidentity SDK ([#4489](https://github.com/kedacore/keda/issues/4489))
- **GCP PubSub Scaler**: Make it more flexible for metrics ([#4243](https://github.com/kedacore/keda/issues/4243))
- **Kafka Scaler:** Add support for OAuth extensions ([#4544](https://github.com/kedacore/keda/issues/4544))

### Fixes

Expand Down
15 changes: 14 additions & 1 deletion pkg/scalers/kafka_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type kafkaMetadata struct {
// OAUTHBEARER
scopes []string
oauthTokenEndpointURI string
oauthExtensions map[string]string

// TLS
enableTLS bool
Expand Down Expand Up @@ -163,6 +164,18 @@ func parseKafkaAuthParams(config *ScalerConfig, meta *kafkaMetadata) error {
return errors.New("no oauth token endpoint uri given")
}
meta.oauthTokenEndpointURI = strings.TrimSpace(config.AuthParams["oauthTokenEndpointUri"])

meta.oauthExtensions = make(map[string]string)
oauthExtensionsRaw := config.AuthParams["oauthExtensions"]
if oauthExtensionsRaw != "" {
for _, extension := range strings.Split(oauthExtensionsRaw, ",") {
splittedExtension := strings.Split(extension, "=")
if len(splittedExtension) != 2 {
return errors.New("invalid OAuthBearer extension, must be of format key=value")
}
meta.oauthExtensions[splittedExtension[0]] = splittedExtension[1]
}
}
}
} else {
return fmt.Errorf("err SASL mode %s given", mode)
Expand Down Expand Up @@ -382,7 +395,7 @@ func getKafkaClients(metadata kafkaMetadata) (sarama.Client, sarama.ClusterAdmin

if metadata.saslType == KafkaSASLTypeOAuthbearer {
config.Net.SASL.Mechanism = sarama.SASLTypeOAuth
config.Net.SASL.TokenProvider = OAuthBearerTokenProvider(metadata.username, metadata.password, metadata.oauthTokenEndpointURI, metadata.scopes)
config.Net.SASL.TokenProvider = OAuthBearerTokenProvider(metadata.username, metadata.password, metadata.oauthTokenEndpointURI, metadata.scopes, metadata.oauthExtensions)
}

client, err := sarama.NewClient(metadata.bootstrapServers, config)
Expand Down
6 changes: 4 additions & 2 deletions pkg/scalers/kafka_scaler_oauth_token_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

type TokenProvider struct {
tokenSource oauth2.TokenSource
extensions map[string]string
}

func OAuthBearerTokenProvider(clientID, clientSecret, tokenURL string, scopes []string) sarama.AccessTokenProvider {
func OAuthBearerTokenProvider(clientID, clientSecret, tokenURL string, scopes []string, extensions map[string]string) sarama.AccessTokenProvider {
cfg := clientcredentials.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Expand All @@ -22,6 +23,7 @@ func OAuthBearerTokenProvider(clientID, clientSecret, tokenURL string, scopes []

return &TokenProvider{
tokenSource: cfg.TokenSource(context.Background()),
extensions: extensions,
}
}

Expand All @@ -31,5 +33,5 @@ func (t *TokenProvider) Token() (*sarama.AccessToken, error) {
return nil, err
}

return &sarama.AccessToken{Token: token.AccessToken}, nil
return &sarama.AccessToken{Token: token.AccessToken, Extensions: t.extensions}, nil
}
11 changes: 11 additions & 0 deletions pkg/scalers/kafka_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ var parseKafkaOAuthbrearerAuthParamsTestDataset = []parseKafkaAuthParamsTestData
{map[string]string{"sasl": "foo", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, true, false},
// failure, SASL OAUTHBEARER + TLS missing oauthTokenEndpointUri
{map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "", "tls": "disable"}, true, false},
// success, SASL OAUTHBEARER + extension
{map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar"}, false, false},
// success, SASL OAUTHBEARER + multiple extensions
{map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar,extension_baz=baz"}, false, false},
// failure, SASL OAUTHBEARER + bad extension
{map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar,extension_bazbaz"}, true, false},
}

var kafkaMetricIdentifiers = []kafkaMetricIdentifier{
Expand Down Expand Up @@ -384,6 +390,11 @@ func TestKafkaOAuthbrearerAuthParams(t *testing.T) {
t.Errorf("Expected scopes to be set to %v but got %v\n", strings.Count(testData.authParams["scopes"], ","), len(meta.scopes))
}
}
if err == nil && testData.authParams["oauthExtensions"] != "" {
if len(meta.oauthExtensions) != strings.Count(testData.authParams["oauthExtensions"], ",")+1 {
t.Errorf("Expected number of extensions to be set to %v but got %v\n", strings.Count(testData.authParams["oauthExtensions"], ",")+1, len(meta.oauthExtensions))
}
}
}
}

Expand Down

0 comments on commit 2c77385

Please sign in to comment.