Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 36 additions & 48 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"encoding/json"
"errors"
"fmt"
"iter"
Expand All @@ -26,22 +25,25 @@ import (
type AWSBudgetPlugin struct {
Logger hclog.Logger

config *PluginConfig
config *PluginConfig
awsBudgetClient *budgets.Client
}

type Validator interface {
Validate() error
}


type PluginConfig struct {
AccountId string `mapstructure:"account_id"`
AwsAccessKeyId string `mapstructure:"aws_access_key_id"`
AccountId string `mapstructure:"account_id"`
AwsAccessKeyId string `mapstructure:"aws_access_key_id"`
AwsSecretAccessKey string `mapstructure:"aws_secret_access_key"`
AwsSessionToken string `mapstructure:"aws_session_token"`
AssumeRoleArn string `mapstructure:"assume_role_arn"`
AwsSessionToken string `mapstructure:"aws_session_token"`
AssumeRoleArn string `mapstructure:"assume_role_arn"`
}

type SaturatedBudget struct {
Budget *types.Budget
Alerts *[]types.Notification
}

func (c *PluginConfig) Validate() error {
Expand All @@ -55,14 +57,14 @@ func (c *PluginConfig) Validate() error {
func loadAWSConfig(ctx context.Context, pluginConfig *PluginConfig) (*aws.Config, error) {
var awsConfig aws.Config
var err error

if pluginConfig.AwsAccessKeyId != "" && pluginConfig.AwsSecretAccessKey != "" && pluginConfig.AwsSessionToken != "" {
// Use credentials if in config
creds := aws.NewCredentialsCache(
credentials.NewStaticCredentialsProvider(
pluginConfig.AwsAccessKeyId,
pluginConfig.AwsSecretAccessKey,
pluginConfig.AwsSessionToken,
credentials.NewStaticCredentialsProvider(
pluginConfig.AwsAccessKeyId,
pluginConfig.AwsSecretAccessKey,
pluginConfig.AwsSessionToken,
),
)
awsConfig, err = config.LoadDefaultConfig(ctx, config.WithRegion(os.Getenv("AWS_REGION")), config.WithCredentialsProvider(creds))
Expand Down Expand Up @@ -98,7 +100,6 @@ func loadAWSConfig(ctx context.Context, pluginConfig *PluginConfig) (*aws.Config
return &awsConfig, nil
}


func (l *AWSBudgetPlugin) Configure(req *proto.ConfigureRequest) (*proto.ConfigureResponse, error) {
l.Logger.Info("Configuring AWS Budget Plugin")
pluginConfig := &PluginConfig{}
Expand Down Expand Up @@ -158,25 +159,20 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH
break
}

alertCount := 0

for _, err := range getNotificationsForBudget(ctx, l.awsBudgetClient, &l.config.AccountId, budget.BudgetName) {
if err != nil {
l.Logger.Error("unable to get notification", "error", err)
evalStatus = proto.ExecutionStatus_FAILURE
accumulatedErrors = errors.Join(accumulatedErrors, err)
break
}
alertCount += 1
alerts, err := getNotificationsForBudget(ctx, l.awsBudgetClient, &l.config.AccountId, budget.BudgetName)
if err != nil {
l.Logger.Error("unable to get notifications", "error", err)
evalStatus = proto.ExecutionStatus_FAILURE
accumulatedErrors = errors.Join(accumulatedErrors, err)
break
}

labels := map[string]string{
"provider": "aws",
"provider": "aws",
"type": "budget",
"account-id": l.config.AccountId,
"budget-name": aws.ToString(budget.BudgetName),
}

actors := []*proto.OriginActor{
{
Title: "The Continuous Compliance Framework",
Expand Down Expand Up @@ -225,11 +221,11 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH
Value: aws.ToString(budget.BillingViewArn),
},
{
Name: "alert-count",
Value: fmt.Sprintf("%v", alertCount),
Name: "alert-count",
Value: fmt.Sprintf("%v", len(*alerts)),
},
{
Name: "health-status",
Name: "health-status",
Value: aws.ToString((*string)(&budget.HealthStatus.Status)),
},
},
Expand All @@ -253,10 +249,10 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH

evidences := make([]*proto.Evidence, 0)

b, _ := json.Marshal(budget)
var budgetMap map[string]interface{}
_ = json.Unmarshal(b, &budgetMap)
budgetMap["AlertCount"] = alertCount
data := &SaturatedBudget{
Budget: &budget,
Alerts: alerts,
}

for _, policyPath := range request.GetPolicyPaths() {

Expand All @@ -275,7 +271,8 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH
actors,
activities,
)
evidence, err := processor.GenerateResults(ctx, policyPath, budgetMap)
evidence, err := processor.GenerateResults(ctx, policyPath, data)
l.Logger.Info(fmt.Sprintf("Evidence: %v", evidence))
evidences = slices.Concat(evidences, evidence)
if err != nil {
accumulatedErrors = errors.Join(accumulatedErrors, err)
Expand All @@ -290,7 +287,6 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH
continue
}


}

return &proto.EvalResponse{
Expand All @@ -314,21 +310,13 @@ func getBudgets(ctx context.Context, client *budgets.Client, accountId *string)
}
}

func getNotificationsForBudget(ctx context.Context, client *budgets.Client, accountId *string, budgetName *string) iter.Seq2[types.Notification, error] {
return func(yield func(types.Notification, error) bool) {
result, err := client.DescribeNotificationsForBudget(ctx, &budgets.DescribeNotificationsForBudgetInput{AccountId: accountId, BudgetName: budgetName})
if err != nil {
yield(types.Notification{}, err)
return
}

for _, notification := range result.Notifications {
if !yield(notification, nil) {
return
}
}
func getNotificationsForBudget(ctx context.Context, client *budgets.Client, accountId *string, budgetName *string) (*[]types.Notification, error) {
result, err := client.DescribeNotificationsForBudget(ctx, &budgets.DescribeNotificationsForBudgetInput{AccountId: accountId, BudgetName: budgetName})
if err != nil {
return nil, err
}

return &result.Notifications, nil
}

func main() {
Expand All @@ -352,4 +340,4 @@ func main() {
},
GRPCServer: goplugin.DefaultGRPCServer,
})
}
}