fix: migrate aws/credentials.go to use NewSession, same functionality but now supports error (#9878)

(cherry picked from commit fde6374)
sspaink authored and reimda committed Oct 7, 2021
1 parent 3dd21de commit 068eb10
Showing 8 changed files with 83 additions and 51 deletions.
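
For context, a minimal sketch (not part of the commit) of the aws-sdk-go v1 API difference this migration is about: the deprecated session.New constructor cannot report configuration errors and defers them to the first service request, while session.NewSession returns them immediately. Both constructors and signatures are from the SDK; the region value is illustrative.

package main

import (
	"log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
)

func main() {
	cfg := &aws.Config{Region: aws.String("us-east-1")}

	// Deprecated pattern: no error to inspect; a bad configuration only
	// surfaces later, on the first API request made with the session.
	oldSess := session.New(cfg)
	_ = oldSess

	// Pattern adopted by this commit: same functionality, but configuration
	// errors are returned up front so callers can propagate them.
	sess, err := session.NewSession(cfg)
	if err != nil {
		log.Fatalf("creating AWS session: %v", err)
	}
	_ = sess
}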
15 changes: 9 additions & 6 deletions config/aws/credentials.go

@@ -21,15 +21,15 @@ type CredentialConfig struct {
 	WebIdentityTokenFile string `toml:"web_identity_token_file"`
 }
 
-func (c *CredentialConfig) Credentials() client.ConfigProvider {
+func (c *CredentialConfig) Credentials() (client.ConfigProvider, error) {
 	if c.RoleARN != "" {
 		return c.assumeCredentials()
 	}
 
 	return c.rootCredentials()
 }
 
-func (c *CredentialConfig) rootCredentials() client.ConfigProvider {
+func (c *CredentialConfig) rootCredentials() (client.ConfigProvider, error) {
 	config := &aws.Config{
 		Region: aws.String(c.Region),
 	}
@@ -42,11 +42,14 @@ func (c *CredentialConfig) rootCredentials() client.ConfigProvider {
 		config.Credentials = credentials.NewSharedCredentials(c.Filename, c.Profile)
 	}
 
-	return session.New(config)
+	return session.NewSession(config)
 }
 
-func (c *CredentialConfig) assumeCredentials() client.ConfigProvider {
-	rootCredentials := c.rootCredentials()
+func (c *CredentialConfig) assumeCredentials() (client.ConfigProvider, error) {
+	rootCredentials, err := c.rootCredentials()
+	if err != nil {
+		return nil, err
+	}
 	config := &aws.Config{
 		Region:   aws.String(c.Region),
 		Endpoint: &c.EndpointURL,
@@ -58,5 +61,5 @@ func (c *CredentialConfig) assumeCredentials() client.ConfigProvider {
 		config.Credentials = stscreds.NewCredentials(rootCredentials, c.RoleARN)
 	}
 
-	return session.New(config)
+	return session.NewSession(config)
 }

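With Credentials() now returning (client.ConfigProvider, error), every call site must propagate the error instead of ignoring it. A sketch of the pattern each plugin change below follows, drawn from the plugins/outputs/cloudwatch hunk (receiver and client types are that plugin's own):

func (c *CloudWatch) Connect() error {
	// Resolve AWS credentials first; fail Connect early on a bad
	// configuration rather than at request time.
	p, err := c.CredentialConfig.Credentials()
	if err != nil {
		return err
	}
	c.svc = cloudwatch.New(p)
	return nil
}
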
6 changes: 5 additions & 1 deletion plugins/inputs/cloudwatch/cloudwatch.go

@@ -288,7 +288,11 @@ func (c *CloudWatch) initializeCloudWatch() error {
 	}
 
 	loglevel := aws.LogOff
-	c.client = cwClient.New(c.CredentialConfig.Credentials(), cfg.WithLogLevel(loglevel))
+	p, err := c.CredentialConfig.Credentials()
+	if err != nil {
+		return err
+	}
+	c.client = cwClient.New(p, cfg.WithLogLevel(loglevel))
 
 	// Initialize regex matchers for each Dimension value.
 	for _, m := range c.Metrics {

31 changes: 19 additions & 12 deletions plugins/inputs/kinesis_consumer/kinesis_consumer.go

@@ -153,24 +153,31 @@ func (k *KinesisConsumer) SetParser(parser parsers.Parser) {
 }
 
 func (k *KinesisConsumer) connect(ac telegraf.Accumulator) error {
-	client := kinesis.New(k.CredentialConfig.Credentials())
+	p, err := k.CredentialConfig.Credentials()
+	if err != nil {
+		return err
+	}
+	client := kinesis.New(p)
 
 	k.checkpoint = &noopCheckpoint{}
 	if k.DynamoDB != nil {
-		var err error
+		p, err := (&internalaws.CredentialConfig{
+			Region:      k.Region,
+			AccessKey:   k.AccessKey,
+			SecretKey:   k.SecretKey,
+			RoleARN:     k.RoleARN,
+			Profile:     k.Profile,
+			Filename:    k.Filename,
+			Token:       k.Token,
+			EndpointURL: k.EndpointURL,
+		}).Credentials()
+		if err != nil {
+			return err
+		}
 		k.checkpoint, err = ddb.New(
 			k.DynamoDB.AppName,
 			k.DynamoDB.TableName,
-			ddb.WithDynamoClient(dynamodb.New((&internalaws.CredentialConfig{
-				Region:      k.Region,
-				AccessKey:   k.AccessKey,
-				SecretKey:   k.SecretKey,
-				RoleARN:     k.RoleARN,
-				Profile:     k.Profile,
-				Filename:    k.Filename,
-				Token:       k.Token,
-				EndpointURL: k.EndpointURL,
-			}).Credentials())),
+			ddb.WithDynamoClient(dynamodb.New(p)),
 			ddb.WithMaxInterval(time.Second*10),
 		)
 		if err != nil {

18 changes: 11 additions & 7 deletions plugins/outputs/cloudwatch/cloudwatch.go

@@ -177,12 +177,12 @@ var sampleConfig = `
   ## Namespace for the CloudWatch MetricDatums
   namespace = "InfluxData/Telegraf"
 
-  ## If you have a large amount of metrics, you should consider to send statistic 
-  ## values instead of raw metrics which could not only improve performance but 
-  ## also save AWS API cost. If enable this flag, this plugin would parse the required 
-  ## CloudWatch statistic fields (count, min, max, and sum) and send them to CloudWatch. 
-  ## You could use basicstats aggregator to calculate those fields. If not all statistic 
-  ## fields are available, all fields would still be sent as raw metrics. 
+  ## If you have a large amount of metrics, you should consider to send statistic
+  ## values instead of raw metrics which could not only improve performance but
+  ## also save AWS API cost. If enable this flag, this plugin would parse the required
+  ## CloudWatch statistic fields (count, min, max, and sum) and send them to CloudWatch.
+  ## You could use basicstats aggregator to calculate those fields. If not all statistic
+  ## fields are available, all fields would still be sent as raw metrics.
   # write_statistics = false
 
   ## Enable high resolution metrics of 1 second (if not enabled, standard resolution are of 60 seconds precision)
@@ -198,7 +198,11 @@ func (c *CloudWatch) Description() string {
 }
 
 func (c *CloudWatch) Connect() error {
-	c.svc = cloudwatch.New(c.CredentialConfig.Credentials())
+	p, err := c.CredentialConfig.Credentials()
+	if err != nil {
+		return err
+	}
+	c.svc = cloudwatch.New(p)
 	return nil
 }
 

12 changes: 8 additions & 4 deletions plugins/outputs/cloudwatch_logs/cloudwatch_logs.go

@@ -108,12 +108,12 @@ region = "us-east-1"
 
 ## Cloud watch log group. Must be created in AWS cloudwatch logs upfront!
 ## For example, you can specify the name of the k8s cluster here to group logs from all cluster in oine place
-log_group = "my-group-name" 
+log_group = "my-group-name"
 
 ## Log stream in log group
 ## Either log group name or reference to metric attribute, from which it can be parsed:
 ## tag:<TAG_NAME> or field:<FIELD_NAME>. If log stream is not exist, it will be created.
-## Since AWS is not automatically delete logs streams with expired logs entries (i.e. empty log stream) 
+## Since AWS is not automatically delete logs streams with expired logs entries (i.e. empty log stream)
 ## you need to put in place appropriate house-keeping (https://forums.aws.amazon.com/thread.jspa?threadID=178855)
 log_stream = "tag:location"
 
@@ -126,7 +126,7 @@ log_data_metric_name = "docker_log"
 
 ## Specify from which metric attribute the log data should be retrieved:
 ## tag:<TAG_NAME> or field:<FIELD_NAME>.
 ## I.e., if you are using docker_log plugin to stream logs from container, then
-## specify log_data_source = "field:message" 
+## specify log_data_source = "field:message"
 log_data_source = "field:message"
 `
@@ -187,7 +187,11 @@ func (c *CloudWatchLogs) Connect() error {
 	var logGroupsOutput = &cloudwatchlogs.DescribeLogGroupsOutput{NextToken: &dummyToken}
 	var err error
 
-	c.svc = cloudwatchlogs.New(c.CredentialConfig.Credentials())
+	p, err := c.CredentialConfig.Credentials()
+	if err != nil {
+		return err
+	}
+	c.svc = cloudwatchlogs.New(p)
 	if c.svc == nil {
 		return fmt.Errorf("can't create cloudwatch logs service endpoint")
 	}

8 changes: 6 additions & 2 deletions plugins/outputs/kinesis/kinesis.go

@@ -126,9 +126,13 @@ func (k *KinesisOutput) Connect() error {
 		k.Log.Infof("Establishing a connection to Kinesis in %s", k.Region)
 	}
 
-	svc := kinesis.New(k.CredentialConfig.Credentials())
+	p, err := k.CredentialConfig.Credentials()
+	if err != nil {
+		return err
+	}
+	svc := kinesis.New(p)
 
-	_, err := svc.DescribeStreamSummary(&kinesis.DescribeStreamSummaryInput{
+	_, err = svc.DescribeStreamSummary(&kinesis.DescribeStreamSummaryInput{
 		StreamName: aws.String(k.StreamName),
 	})
 	k.svc = svc

28 changes: 17 additions & 11 deletions plugins/outputs/timestream/timestream.go

@@ -57,7 +57,7 @@ const MaxRecordsPerCall = 100
 var sampleConfig = `
   ## Amazon Region
   region = "us-east-1"
-  
+
   ## Amazon Credentials
   ## Credentials are loaded in the following order:
   ## 1) Web identity provider credentials via STS if role_arn and web_identity_token_file are specified
@@ -75,7 +75,7 @@ var sampleConfig = `
   #role_session_name = ""
   #profile = ""
   #shared_credential_file = ""
-  
+
   ## Endpoint to make request against, the correct endpoint is automatically
   ## determined and this option should only be set if you wish to override the
   ## default.
@@ -88,7 +88,7 @@ var sampleConfig = `
 
   ## Specifies if the plugin should describe the Timestream database upon starting
   ## to validate if it has access necessary permissions, connection, etc., as a safety check.
-  ## If the describe operation fails, the plugin will not start 
+  ## If the describe operation fails, the plugin will not start
   ## and therefore the Telegraf agent will not start.
   describe_database_on_start = false
 
@@ -97,17 +97,17 @@ var sampleConfig = `
   ## For example, consider the following data in line protocol format:
   ## weather,location=us-midwest,season=summer temperature=82,humidity=71 1465839830100400200
   ## airquality,location=us-west no2=5,pm25=16 1465839830100400200
-  ## where weather and airquality are the measurement names, location and season are tags, 
+  ## where weather and airquality are the measurement names, location and season are tags,
   ## and temperature, humidity, no2, pm25 are fields.
   ## In multi-table mode:
   ##  - first line will be ingested to table named weather
   ##  - second line will be ingested to table named airquality
   ##  - the tags will be represented as dimensions
   ##  - first table (weather) will have two records:
-  ##      one with measurement name equals to temperature, 
+  ##      one with measurement name equals to temperature,
   ##      another with measurement name equals to humidity
   ##  - second table (airquality) will have two records:
-  ##      one with measurement name equals to no2, 
+  ##      one with measurement name equals to no2,
   ##      another with measurement name equals to pm25
   ##  - the Timestream tables from the example will look like this:
   ##      TABLE "weather":
@@ -141,7 +141,7 @@ var sampleConfig = `
   ## Specifies the Timestream table where the metrics will be uploaded.
   # single_table_name = "yourTableNameHere"
 
-  ## Only valid and required for mapping_mode = "single-table" 
+  ## Only valid and required for mapping_mode = "single-table"
   ## Describes what will be the Timestream dimension name for the Telegraf
   ## measurement name.
   # single_table_dimension_name_for_telegraf_measurement_name = "namespace"
@@ -169,9 +169,12 @@ var sampleConfig = `
 `
 
 // WriteFactory function provides a way to mock the client instantiation for testing purposes.
-var WriteFactory = func(credentialConfig *internalaws.CredentialConfig) WriteClient {
-	configProvider := credentialConfig.Credentials()
-	return timestreamwrite.New(configProvider)
+var WriteFactory = func(credentialConfig *internalaws.CredentialConfig) (WriteClient, error) {
+	configProvider, err := credentialConfig.Credentials()
+	if err != nil {
+		return nil, err
+	}
+	return timestreamwrite.New(configProvider), nil
 }
 
 func (t *Timestream) Connect() error {
@@ -221,7 +224,10 @@ func (t *Timestream) Connect() error {
 
 	t.Log.Infof("Constructing Timestream client for '%s' mode", t.MappingMode)
 
-	svc := WriteFactory(&t.CredentialConfig)
+	svc, err := WriteFactory(&t.CredentialConfig)
+	if err != nil {
+		return err
+	}
 
 	if t.DescribeDatabaseOnStart {
 		t.Log.Infof("Describing database '%s' in region '%s'", t.DatabaseName, t.Region)

16 changes: 8 additions & 8 deletions plugins/outputs/timestream/timestream_test.go

@@ -2,14 +2,15 @@ package timestream_test
 
 import (
 	"fmt"
-	"github.com/aws/aws-sdk-go/aws/awserr"
 	"reflect"
 	"sort"
 	"strconv"
 	"strings"
 	"testing"
 	"time"
 
+	"github.com/aws/aws-sdk-go/aws/awserr"
+
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/timestreamwrite"
 	"github.com/influxdata/telegraf"
@@ -53,10 +54,9 @@ func (m *mockTimestreamClient) DescribeDatabase(*timestreamwrite.DescribeDatabas
 
 func TestConnectValidatesConfigParameters(t *testing.T) {
 	assertions := assert.New(t)
-	ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) ts.WriteClient {
-		return &mockTimestreamClient{}
+	ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) (ts.WriteClient, error) {
+		return &mockTimestreamClient{}, nil
 	}
-
 	// checking base arguments
 	noDatabaseName := ts.Timestream{Log: testutil.Logger{}}
 	assertions.Contains(noDatabaseName.Connect().Error(), "DatabaseName")
@@ -182,11 +182,11 @@ func (m *mockTimestreamErrorClient) DescribeDatabase(*timestreamwrite.DescribeDa
 func TestThrottlingErrorIsReturnedToTelegraf(t *testing.T) {
 	assertions := assert.New(t)
 
-	ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) ts.WriteClient {
+	ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) (ts.WriteClient, error) {
 		return &mockTimestreamErrorClient{
 			awserr.New(timestreamwrite.ErrCodeThrottlingException,
 				"Throttling Test", nil),
-		}
+		}, nil
 	}
 	plugin := ts.Timestream{
 		MappingMode: ts.MappingModeMultiTable,
@@ -210,11 +210,11 @@ func TestThrottlingErrorIsReturnedToTelegraf(t *testing.T) {
 func TestRejectedRecordsErrorResultsInMetricsBeingSkipped(t *testing.T) {
 	assertions := assert.New(t)
 
-	ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) ts.WriteClient {
+	ts.WriteFactory = func(credentialConfig *internalaws.CredentialConfig) (ts.WriteClient, error) {
 		return &mockTimestreamErrorClient{
 			awserr.New(timestreamwrite.ErrCodeRejectedRecordsException,
 				"RejectedRecords Test", nil),
-		}
+		}, nil
 	}
 	plugin := ts.Timestream{
 		MappingMode: ts.MappingModeMultiTable,
