Skip to content

Commit

Permalink
feat(misconf): API Gateway V1 support for CloudFormation (#6874)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikpivkin authored Jun 8, 2024
1 parent bb88937 commit 8491469
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 9 deletions.
7 changes: 4 additions & 3 deletions pkg/iac/adapters/cloudformation/aws/apigateway/apigateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
func Adapt(cfFile parser.FileContext) apigateway.APIGateway {
return apigateway.APIGateway{
V1: v1.APIGateway{
APIs: nil,
DomainNames: nil,
APIs: adaptAPIsV1(cfFile),
DomainNames: adaptDomainNamesV1(cfFile),
},
V2: v2.APIGateway{
APIs: getApis(cfFile),
APIs: adaptAPIsV2(cfFile),
DomainNames: adaptDomainNamesV2(cfFile),
},
}
}
98 changes: 93 additions & 5 deletions pkg/iac/adapters/cloudformation/aws/apigateway/apigateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/aquasecurity/trivy/pkg/iac/adapters/cloudformation/testutil"
"github.com/aquasecurity/trivy/pkg/iac/providers/aws/apigateway"
v1 "github.com/aquasecurity/trivy/pkg/iac/providers/aws/apigateway/v1"
v2 "github.com/aquasecurity/trivy/pkg/iac/providers/aws/apigateway/v2"
"github.com/aquasecurity/trivy/pkg/iac/types"
)
Expand All @@ -19,24 +20,105 @@ func TestAdapt(t *testing.T) {
name: "complete",
source: `AWSTemplateFormatVersion: 2010-09-09
Resources:
MyApi:
MyRestApi:
Type: 'AWS::ApiGateway::RestApi'
Properties:
Description: A test API
Name: MyRestAPI
ApiResource:
Type: AWS::ApiGateway::Resource
Properties:
RestApiId: !Ref MyRestApi
MethodPOST:
Type: AWS::ApiGateway::Method
Properties:
RestApiId: !Ref MyRestApi
ResourceId: !Ref ApiResource
HttpMethod: POST
AuthorizationType: COGNITO_USER_POOLS
ApiKeyRequired: true
Stage:
Type: AWS::ApiGateway::Stage
Properties:
StageName: Prod
RestApiId: !Ref MyRestApi
TracingEnabled: true
AccessLogSetting:
DestinationArn: test-arn
MethodSettings:
- CacheDataEncrypted: true
CachingEnabled: true
HttpMethod: POST
MyDomainName:
Type: AWS::ApiGateway::DomainName
Properties:
DomainName: mydomainame.us-east-1.com
SecurityPolicy: "TLS_1_2"
MyApi2:
Type: 'AWS::ApiGatewayV2::Api'
Properties:
Name: MyApi
Name: MyApi2
ProtocolType: WEBSOCKET
MyStage:
MyStage2:
Type: 'AWS::ApiGatewayV2::Stage'
Properties:
StageName: Prod
ApiId: !Ref MyApi
ApiId: !Ref MyApi2
AccessLogSettings:
DestinationArn: some-arn
MyDomainName2:
Type: 'AWS::ApiGatewayV2::DomainName'
Properties:
DomainName: mydomainame.us-east-1.com
DomainNameConfigurations:
- SecurityPolicy: "TLS_1_2"
`,
expected: apigateway.APIGateway{
V1: v1.APIGateway{
APIs: []v1.API{
{
Name: types.StringTest("MyRestAPI"),
Stages: []v1.Stage{
{
Name: types.StringTest("Prod"),
XRayTracingEnabled: types.BoolTest(true),
AccessLogging: v1.AccessLogging{
CloudwatchLogGroupARN: types.StringTest("test-arn"),
},
RESTMethodSettings: []v1.RESTMethodSettings{
{
Method: types.StringTest("POST"),
CacheDataEncrypted: types.BoolTest(true),
CacheEnabled: types.BoolTest(true),
},
},
},
},
Resources: []v1.Resource{
{
Methods: []v1.Method{
{
HTTPMethod: types.StringTest("POST"),
AuthorizationType: types.StringTest("COGNITO_USER_POOLS"),
APIKeyRequired: types.BoolTest(true),
},
},
},
},
},
},
DomainNames: []v1.DomainName{
{
Name: types.StringTest("mydomainame.us-east-1.com"),
SecurityPolicy: types.StringTest("TLS_1_2"),
},
},
},
V2: v2.APIGateway{
APIs: []v2.API{
{
Name: types.StringTest("MyApi"),
Name: types.StringTest("MyApi2"),
ProtocolType: types.StringTest("WEBSOCKET"),
Stages: []v2.Stage{
{
Expand All @@ -48,6 +130,12 @@ Resources:
},
},
},
DomainNames: []v2.DomainName{
{
Name: types.StringTest("mydomainame.us-east-1.com"),
SecurityPolicy: types.StringTest("TLS_1_2"),
},
},
},
},
},
Expand Down
108 changes: 108 additions & 0 deletions pkg/iac/adapters/cloudformation/aws/apigateway/apiv1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package apigateway

import (
v1 "github.com/aquasecurity/trivy/pkg/iac/providers/aws/apigateway/v1"
"github.com/aquasecurity/trivy/pkg/iac/scanners/cloudformation/parser"
)

func adaptAPIsV1(fctx parser.FileContext) []v1.API {
var apis []v1.API

stages := make(map[string]*parser.Resource)
for _, stageResource := range fctx.GetResourcesByType("AWS::ApiGateway::Stage") {
restApiID := stageResource.GetStringProperty("RestApiId")
if restApiID.IsEmpty() {
continue
}

stages[restApiID.Value()] = stageResource
}

resources := make(map[string]*parser.Resource)
for _, resource := range fctx.GetResourcesByType("AWS::ApiGateway::Resource") {
restApiID := resource.GetStringProperty("RestApiId")
if restApiID.IsEmpty() {
continue
}

resources[restApiID.Value()] = resource
}

for _, apiResource := range fctx.GetResourcesByType("AWS::ApiGateway::RestApi") {

api := v1.API{
Metadata: apiResource.Metadata(),
Name: apiResource.GetStringProperty("Name"),
}

if stageResource, exists := stages[apiResource.ID()]; exists {
stage := v1.Stage{
Metadata: stageResource.Metadata(),
Name: stageResource.GetStringProperty("StageName"),
XRayTracingEnabled: stageResource.GetBoolProperty("TracingEnabled"),
}

if logSetting := stageResource.GetProperty("AccessLogSetting"); logSetting.IsNotNil() {
stage.AccessLogging = v1.AccessLogging{
Metadata: logSetting.Metadata(),
CloudwatchLogGroupARN: logSetting.GetStringProperty("DestinationArn"),
}
}

if methodSettings := stageResource.GetProperty("MethodSettings"); methodSettings.IsList() {
for _, methodSetting := range methodSettings.AsList() {
stage.RESTMethodSettings = append(stage.RESTMethodSettings, v1.RESTMethodSettings{
Metadata: methodSetting.Metadata(),
Method: methodSetting.GetStringProperty("HttpMethod"),
CacheDataEncrypted: methodSetting.GetBoolProperty("CacheDataEncrypted"),
CacheEnabled: methodSetting.GetBoolProperty("CachingEnabled"),
})
}
}

api.Stages = append(api.Stages, stage)
}

if resource, exists := resources[apiResource.ID()]; exists {
res := v1.Resource{
Metadata: resource.Metadata(),
}

for _, methodResource := range fctx.GetResourcesByType("AWS::ApiGateway::Method") {
resourceID := methodResource.GetStringProperty("ResourceId")
// TODO: handle RootResourceId
if resourceID.Value() != resource.ID() {
continue
}

res.Methods = append(res.Methods, v1.Method{
Metadata: methodResource.Metadata(),
HTTPMethod: methodResource.GetStringProperty("HttpMethod"),
AuthorizationType: methodResource.GetStringProperty("AuthorizationType"),
APIKeyRequired: methodResource.GetBoolProperty("ApiKeyRequired"),
})

}

api.Resources = append(api.Resources, res)
}

apis = append(apis, api)
}

return apis
}

func adaptDomainNamesV1(fctx parser.FileContext) []v1.DomainName {
var domainNames []v1.DomainName

for _, domainNameResource := range fctx.GetResourcesByType("AWS::ApiGateway::DomainName") {
domainNames = append(domainNames, v1.DomainName{
Metadata: domainNameResource.Metadata(),
Name: domainNameResource.GetStringProperty("DomainName"),
SecurityPolicy: domainNameResource.GetStringProperty("SecurityPolicy"),
})
}

return domainNames
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/types"
)

func getApis(cfFile parser.FileContext) (apis []v2.API) {
func adaptAPIsV2(cfFile parser.FileContext) (apis []v2.API) {

apiResources := cfFile.GetResourcesByType("AWS::ApiGatewayV2::Api")
for _, apiRes := range apiResources {
Expand Down Expand Up @@ -66,3 +66,26 @@ func getAccessLogging(r *parser.Resource) v2.AccessLogging {
CloudwatchLogGroupARN: destinationProp.AsStringValue(),
}
}

func adaptDomainNamesV2(fctx parser.FileContext) []v2.DomainName {
var domainNames []v2.DomainName

for _, domainNameResource := range fctx.GetResourcesByType("AWS::ApiGateway::DomainName") {

domainName := v2.DomainName{
Metadata: domainNameResource.Metadata(),
Name: domainNameResource.GetStringProperty("DomainName"),
SecurityPolicy: domainNameResource.GetStringProperty("SecurityPolicy"),
}

if domainNameCfgs := domainNameResource.GetProperty("DomainNameConfigurations"); domainNameCfgs.IsList() {
for _, domainNameCfg := range domainNameCfgs.AsList() {
domainName.SecurityPolicy = domainNameCfg.GetStringProperty("SecurityPolicy")
}
}

domainNames = append(domainNames, domainName)
}

return domainNames
}

0 comments on commit 8491469

Please sign in to comment.