Skip to content

Commit

Permalink
Merge pull request #39 from danquack/dynamic-creds
Browse files Browse the repository at this point in the history
Allow for role based credentials
  • Loading branch information
joeirimpan authored Dec 12, 2023
2 parents 443868d + 9cb22a3 commit b731ef9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 24 deletions.
25 changes: 13 additions & 12 deletions messenger/pinpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,23 @@ func NewPinpoint(cfg []byte, l *onelog.Logger) (Messenger, error) {
if c.AppID == "" {
return nil, fmt.Errorf("invalid app_id")
}
if c.Region == "" {
return nil, fmt.Errorf("invalid region")

config := &aws.Config{
MaxRetries: aws.Int(3),
}
if c.AccessKey == "" {
return nil, fmt.Errorf("invalid access_key")
if c.AccessKey != "" && c.SecretKey != "" {
config.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")
}
if c.SecretKey == "" {
return nil, fmt.Errorf("invalid secret_key")
if c.Region != "" {
config.Region = &c.Region
}

sess := session.Must(session.NewSession())
svc := pinpoint.New(sess,
aws.NewConfig().
WithCredentials(credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")).
WithRegion(c.Region),
)
var sess = session.Must(session.NewSession(config))
err := checkCredentials(sess)
if err != nil {
return nil, err
}
svc := pinpoint.New(sess)

return pinpointMessenger{
client: svc,
Expand Down
34 changes: 22 additions & 12 deletions messenger/ses.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ses"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/francoispqt/onelog"
"github.com/knadh/smtppool"
)
Expand Down Expand Up @@ -106,30 +107,39 @@ func (s sesMessenger) Close() error {
return nil
}

func checkCredentials(sess *session.Session) error {
// Create a SES service client.
svc := sts.New(sess)
// Call the GetCallerIdentity API to check credentials
params := &sts.GetCallerIdentityInput{}
_, err := svc.GetCallerIdentity(params)
return err
}

// NewAWSSES creates new instance of pinpoint
func NewAWSSES(cfg []byte, l *onelog.Logger) (Messenger, error) {
var c sesCfg
if err := json.Unmarshal(cfg, &c); err != nil {
return nil, err
}

if c.Region == "" {
return nil, fmt.Errorf("invalid region")
config := &aws.Config{
MaxRetries: aws.Int(3),
}
if c.AccessKey == "" {
return nil, fmt.Errorf("invalid access_key")
if c.AccessKey != "" && c.SecretKey != "" {
config.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")
}
if c.SecretKey == "" {
return nil, fmt.Errorf("invalid secret_key")
if c.Region != "" {
config.Region = &c.Region
}

sess := session.Must(session.NewSession())
svc := ses.New(sess,
aws.NewConfig().
WithCredentials(credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")).
WithRegion(c.Region),
)
var sess = session.Must(session.NewSession(config))
err := checkCredentials(sess)
if err != nil {
return nil, err
}

svc := ses.New(sess)
return sesMessenger{
client: svc,
cfg: c,
Expand Down

0 comments on commit b731ef9

Please sign in to comment.