Skip to content

Commit

Permalink
Merge pull request #39878 from hashicorp/b-wafv2-rule_json_error
Browse files Browse the repository at this point in the history
r/wafv2_web_acl: fix `rule_json` error when unmarshalling
  • Loading branch information
johnsonaj authored Oct 25, 2024
2 parents 2a778ec + 8961fd2 commit 7eaac40
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .changelog/39878.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
resource/aws_wafv2_web_acl: Fix unmarshal error for incompatible types in `rule_json`
```
70 changes: 66 additions & 4 deletions internal/service/wafv2/flex.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package wafv2

import (
"encoding/base64"
"encoding/json"
"fmt"
"reflect"
Expand Down Expand Up @@ -983,21 +984,82 @@ func expandHeaderMatchPattern(l []interface{}) *awstypes.HeaderMatchPattern {
}

func expandWebACLRulesJSON(rawRules string) ([]awstypes.Rule, error) {
var rules []awstypes.Rule

err := json.Unmarshal([]byte(rawRules), &rules)
var temp []any
err := json.Unmarshal([]byte(rawRules), &temp)
if err != nil {
return nil, fmt.Errorf("decoding JSON: %s", err)
}

for _, v := range temp {
walkWebACLJSON(reflect.ValueOf(v))
}

out, err := json.Marshal(temp)
if err != nil {
return nil, err
}

var rules []awstypes.Rule
err = json.Unmarshal(out, &rules)
if err != nil {
return nil, err
}

for i, r := range rules {
if reflect.DeepEqual(r, awstypes.Rule{}) {
if reflect.ValueOf(r).IsZero() {
return nil, fmt.Errorf("invalid ACL Rule supplied at index (%d)", i)
}
}
return rules, nil
}

func walkWebACLJSON(v reflect.Value) {
m := map[string][]struct {
key string
outputType any
}{
"ByteMatchStatement": {
{key: "SearchString", outputType: []byte{}},
},
}

for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
v = v.Elem()
}

switch v.Kind() {
case reflect.Map:
for _, k := range v.MapKeys() {
if val, ok := m[k.String()]; ok {
st := v.MapIndex(k).Interface().(map[string]any)
for _, va := range val {
if st[va.key] == nil {
continue
}
str := st[va.key]
switch reflect.ValueOf(va.outputType).Kind() {
case reflect.Slice, reflect.Array:
switch reflect.ValueOf(va.outputType).Type().Elem().Kind() {
case reflect.Uint8:
base64String := base64.StdEncoding.EncodeToString([]byte(str.(string)))
st[va.key] = base64String
default:
}
default:
}
}
} else {
walkWebACLJSON(v.MapIndex(k))
}
}
case reflect.Array, reflect.Slice:
for i := 0; i < v.Len(); i++ {
walkWebACLJSON(v.Index(i))
}
default:
}
}

func expandWebACLRules(l []interface{}) []awstypes.Rule {
if len(l) == 0 || l[0] == nil {
return nil
Expand Down
61 changes: 61 additions & 0 deletions internal/service/wafv2/flex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,61 @@ func Test_expandWebACLRulesJSON(t *testing.T) {
rawRules: `[{"Action":{"Count":{}},"Name":"rule-1","Priority":1,"Statement":{"RateBasedStatement":{"AggregateKeyType":"IP","EvaluationWindowSec":600,"Limit":10000,"ScopeDownStatement":{"GeoMatchStatement":{"CountryCodes":["US","NL"]}}}},"VisibilityConfig":{"CloudwatchMetricsEnabled":false,"MetricName":"friendly-rule-metric-name","SampledRequestsEnabled":false}},{}]`,
wantErr: true,
},
"valid object SearchString": {
rawRules: `[{"Name" : "test_rule0","Priority":0,"Statement":{"AndStatement":{"Statements":[{"ByteMatchStatement":{"SearchString":"test","FieldToMatch":{"SingleHeader":{"Name":"host"}},"TextTransformations":[{"Priority":0,"Type":"NONE"}],"PositionalConstraint":"EXACTLY"}}]},"ByteMatchStatement":{"SearchString":"test","FieldToMatch":{"SingleHeader":{"Name":"host"}},"TextTransformations":[{"Priority":0,"Type":"NONE"}],"PositionalConstraint":"EXACTLY"}},"Action":{"Block":{}},"VisibilityConfig":{"SampledRequestsEnabled":true,"CloudWatchMetricsEnabled":true,"MetricName":"test_rule0"}}]`,
want: []awstypes.Rule{
{
Name: aws.String("test_rule0"),
Priority: 0,
Action: &awstypes.RuleAction{
Block: &awstypes.BlockAction{},
},
VisibilityConfig: &awstypes.VisibilityConfig{
SampledRequestsEnabled: true,
CloudWatchMetricsEnabled: true,
MetricName: aws.String("test_rule0"),
},
Statement: &awstypes.Statement{
AndStatement: &awstypes.AndStatement{
Statements: []awstypes.Statement{
{
ByteMatchStatement: &awstypes.ByteMatchStatement{
SearchString: []byte("test"),
FieldToMatch: &awstypes.FieldToMatch{
SingleHeader: &awstypes.SingleHeader{
Name: aws.String("host"),
},
},
TextTransformations: []awstypes.TextTransformation{
{
Priority: 0,
Type: awstypes.TextTransformationType("NONE"),
},
},
PositionalConstraint: awstypes.PositionalConstraint("EXACTLY"),
},
},
},
},
ByteMatchStatement: &awstypes.ByteMatchStatement{
SearchString: []byte("test"),
FieldToMatch: &awstypes.FieldToMatch{
SingleHeader: &awstypes.SingleHeader{
Name: aws.String("host"),
},
},
TextTransformations: []awstypes.TextTransformation{
{
Priority: 0,
Type: awstypes.TextTransformationType("NONE"),
},
},
PositionalConstraint: awstypes.PositionalConstraint("EXACTLY"),
},
},
},
},
},
}

ignoreExportedOpts := cmpopts.IgnoreUnexported(
Expand All @@ -79,6 +134,12 @@ func Test_expandWebACLRulesJSON(t *testing.T) {
awstypes.RateBasedStatement{},
awstypes.GeoMatchStatement{},
awstypes.VisibilityConfig{},
awstypes.SingleHeader{},
awstypes.ByteMatchStatement{},
awstypes.FieldToMatch{},
awstypes.TextTransformation{},
awstypes.BlockAction{},
awstypes.AndStatement{},
)

for name, tc := range testCases {
Expand Down
85 changes: 85 additions & 0 deletions internal/service/wafv2/web_acl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3074,6 +3074,14 @@ func TestAccWAFV2WebACL_ruleJSON(t *testing.T) {
ImportStateVerifyIgnore: []string{"rule_json"},
ImportStateIdFunc: testAccWebACLImportStateIdFunc(resourceName),
},
{
Config: testAccWebACLConfig_JSONruleUpdate(webACLName),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckWebACLExists(ctx, resourceName, &v),
acctest.MatchResourceAttrRegionalARN(resourceName, names.AttrARN, "wafv2", regexache.MustCompile(`regional/webacl/.+$`)),
resource.TestCheckResourceAttrSet(resourceName, "rule_json"),
),
},
},
})
}
Expand Down Expand Up @@ -6160,3 +6168,80 @@ resource "aws_wafv2_web_acl" "test" {
}
`, rName)
}

func testAccWebACLConfig_JSONruleUpdate(rName string) string {
return fmt.Sprintf(`
resource "aws_wafv2_web_acl" "test" {
name = %[1]q
description = %[1]q
scope = "REGIONAL"
default_action {
allow {}
}
visibility_config {
cloudwatch_metrics_enabled = false
metric_name = "friendly-metric-name"
sampled_requests_enabled = false
}
rule_json = jsonencode([
{
Name = "rule-1",
Priority = 1,
Action = {
Count = {}
},
Statement = {
RateBasedStatement = {
Limit = 10000,
AggregateKeyType = "IP",
EvaluationWindowSec = 600,
ScopeDownStatement = {
GeoMatchStatement = {
CountryCodes = ["US", "NL"]
},
},
},
},
VisibilityConfig = {
CloudwatchMetricsEnabled = false,
MetricName = "test-metric-name",
SampledRequestsEnabled = false,
},
},
{
"Name" : "test_rule0",
"Priority" : 0,
"Statement" : {
"ByteMatchStatement" : {
"SearchString" : "test",
"FieldToMatch" : {
"SingleHeader" : {
"Name" : "host"
}
},
"TextTransformations" : [
{
"Priority" : 0,
"Type" : "NONE"
}
],
"PositionalConstraint" : "EXACTLY"
}
},
"Action" : {
"Block" : {}
},
"VisibilityConfig" : {
"SampledRequestsEnabled" : true,
"CloudWatchMetricsEnabled" : true,
"MetricName" : "test_rule0"
}
}
])
}
`, rName)
}

0 comments on commit 7eaac40

Please sign in to comment.