Skip to content

Commit ffed2da

Browse files
author
Divjot Arora
committed
Use testing library for client side encryption spec tests.
GODRIVER-1270 Change-Id: I9fcc5e280e11ef8a4fa1b4923b90807646b00df7
1 parent 8ec63ac commit ffed2da

10 files changed

+337
-1209
lines changed

mongo/client_side_encryption_spec_test.go

Lines changed: 0 additions & 1130 deletions
This file was deleted.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
// +build cse
8+
9+
package integration
10+
11+
import (
12+
"path"
13+
"testing"
14+
)
15+
16+
const (
17+
encryptionSpecName = "client-side-encryption"
18+
)
19+
20+
func TestClientSideEncryption(t *testing.T) {
21+
for _, fileName := range jsonFilesInDir(t, path.Join(dataPath, encryptionSpecName)) {
22+
t.Run(fileName, func(t *testing.T) {
23+
runSpecTestFile(t, encryptionSpecName, fileName)
24+
})
25+
}
26+
}

mongo/integration/cmd_monitoring_helpers_test.go

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package integration
88

99
import (
1010
"go.mongodb.org/mongo-driver/bson"
11+
"go.mongodb.org/mongo-driver/bson/bsontype"
1112
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1213
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
1314
"go.mongodb.org/mongo-driver/x/bsonx"
@@ -54,6 +55,13 @@ func compareValues(mt *mtest.T, key string, expected, actual bson.RawValue) {
5455
}
5556
case bson.TypeEmbeddedDocument:
5657
e := expected.Document()
58+
if typeVal, err := e.LookupErr("$$type"); err == nil {
59+
// $$type represents a type assertion
60+
// for example {field: {$$type: "binData"}} should assert that "field" is an element with a binary value
61+
assertType(mt, actual.Type, typeVal.StringValue())
62+
return
63+
}
64+
5765
a := actual.Document()
5866
compareDocs(mt, e, a)
5967
case bson.TypeArray:
@@ -66,6 +74,61 @@ func compareValues(mt *mtest.T, key string, expected, actual bson.RawValue) {
6674
}
6775
}
6876

77+
// helper for $$type assertions
78+
func assertType(mt *mtest.T, actual bsontype.Type, typeStr string) {
79+
mt.Helper()
80+
81+
var expected bsontype.Type
82+
switch typeStr {
83+
case "double":
84+
expected = bsontype.Double
85+
case "string":
86+
expected = bsontype.String
87+
case "object":
88+
expected = bsontype.EmbeddedDocument
89+
case "array":
90+
expected = bsontype.Array
91+
case "binData":
92+
expected = bsontype.Binary
93+
case "undefined":
94+
expected = bsontype.Undefined
95+
case "objectId":
96+
expected = bsontype.ObjectID
97+
case "boolean":
98+
expected = bsontype.Boolean
99+
case "date":
100+
expected = bsontype.DateTime
101+
case "null":
102+
expected = bsontype.Null
103+
case "regex":
104+
expected = bsontype.Regex
105+
case "dbPointer":
106+
expected = bsontype.DBPointer
107+
case "javascript":
108+
expected = bsontype.JavaScript
109+
case "symbol":
110+
expected = bsontype.Symbol
111+
case "javascriptWithScope":
112+
expected = bsontype.CodeWithScope
113+
case "int":
114+
expected = bsontype.Int32
115+
case "timestamp":
116+
expected = bsontype.Timestamp
117+
case "long":
118+
expected = bsontype.Int64
119+
case "decimal":
120+
expected = bsontype.Decimal128
121+
case "minKey":
122+
expected = bsontype.MinKey
123+
case "maxKey":
124+
expected = bsontype.MaxKey
125+
default:
126+
mt.Fatalf("unrecognized type string: %v", typeStr)
127+
}
128+
129+
assert.Equal(mt, expected, actual, "BSON type mismatch; expected %v, got %v", expected, actual)
130+
}
131+
69132
// compare expected and actual BSON documents. comparison succeeds if actual contains each element in expected.
70133
func compareDocs(mt *mtest.T, expected, actual bson.Raw) {
71134
mt.Helper()
@@ -80,6 +143,12 @@ func compareDocs(mt *mtest.T, expected, actual bson.Raw) {
80143

81144
eVal := e.Value()
82145
if doc, ok := eVal.DocumentOK(); ok {
146+
// special $$type assertion
147+
if typeVal, err := doc.LookupErr("$$type"); err == nil {
148+
assertType(mt, aVal.Type, typeVal.StringValue())
149+
continue
150+
}
151+
83152
// nested doc
84153
compareDocs(mt, doc, aVal.Document())
85154
continue
@@ -124,8 +193,9 @@ func checkExpectations(mt *mtest.T, expectations []*expectation, id0 bsonx.Doc,
124193
assert.Equal(mt, bson.RawValue{}, actualVal, "expected value for key %s to be nil but got %v", key, actualVal)
125194
continue
126195
}
127-
if key == "ordered" {
196+
if key == "ordered" || key == "cursor" {
128197
// TODO: some tests specify that "ordered" must be a key in the event but ordered isn't a valid option for some of these cases (e.g. insertOne)
198+
// TODO: some FLE tests specify "cursor" subdocument for listCollections
129199
continue
130200
}
131201

@@ -153,12 +223,17 @@ func checkExpectations(mt *mtest.T, expectations []*expectation, id0 bsonx.Doc,
153223
assert.Equal(mt, expectedID, actualID,
154224
"session ID mismatch for session %v; expected %v, got %v", sessName, expectedID, actualID)
155225
case "getMore":
156-
expectedID := val.Int64()
157-
// ignore placeholder cursor ID (42)
158-
if expectedID != 42 {
226+
expectedID, ok := val.Int64OK()
227+
if ok {
228+
// ignore placeholder ID (42)
229+
if expectedID == 42 {
230+
continue
231+
}
159232
actualID := actualVal.Int64()
160-
assert.Equal(mt, expectedID, actualID, "cursor ID mismatch; expected %v, got %v", expectedID, actualID)
233+
assert.Equal(mt, expectedID, actualID, "expected cursor ID; expected %v, got %v", expectedID, actualID)
234+
continue
161235
}
236+
compareValues(mt, key, val, actualVal)
162237
case "readConcern":
163238
expectedRc := val.Document()
164239
actualRc := actualVal.Document()

mongo/integration/collection_test.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,10 @@ func TestCollection(t *testing.T) {
179179
})
180180
mt.Run("write error", func(mt *mtest.T) {
181181
cappedOpts := bson.D{{"capped", true}, {"size", 64 * 1024}}
182-
capped := mt.CreateCollectionWithOptions("deleteOne_capped", cappedOpts)
182+
capped := mt.CreateCollection(mtest.Collection{
183+
Name: "deleteOne_capped",
184+
CreateOpts: cappedOpts,
185+
}, true)
183186
_, err := capped.DeleteOne(mtest.Background, bson.D{{"x", 1}})
184187

185188
we, ok := err.(mongo.WriteException)
@@ -225,7 +228,10 @@ func TestCollection(t *testing.T) {
225228
})
226229
mt.Run("write error", func(mt *mtest.T) {
227230
cappedOpts := bson.D{{"capped", true}, {"size", 64 * 1024}}
228-
capped := mt.CreateCollectionWithOptions("deleteMany_capped", cappedOpts)
231+
capped := mt.CreateCollection(mtest.Collection{
232+
Name: "deleteMany_capped",
233+
CreateOpts: cappedOpts,
234+
}, true)
229235
_, err := capped.DeleteMany(mtest.Background, bson.D{{"x", 1}})
230236

231237
we, ok := err.(mongo.WriteException)
@@ -882,7 +888,10 @@ func TestCollection(t *testing.T) {
882888
doc := mongo.NewDeleteOneModel().SetFilter(bson.D{{"x", 1}})
883889
models := []mongo.WriteModel{doc, doc}
884890
cappedOpts := bson.D{{"capped", true}, {"size", 64 * 1024}}
885-
capped := mt.CreateCollectionWithOptions("delete_write_errors", cappedOpts)
891+
capped := mt.CreateCollection(mtest.Collection{
892+
Name: "delete_write_errors",
893+
CreateOpts: cappedOpts,
894+
}, true)
886895

887896
testCases := []struct {
888897
name string

mongo/integration/crud_helpers_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ func createUpdate(mt *mtest.T, updateVal bson.RawValue) interface{} {
4848
return nil
4949
}
5050

51+
// kill all open sessions on the server. This function uses mt.GlobalClient() because killAllSessions is not allowed
52+
// for clients configured with specific options (e.g. client side encryption).
5153
func killSessions(mt *mtest.T) {
5254
mt.Helper()
5355

54-
err := mt.Client.Database("admin").RunCommand(mtest.Background, bson.D{
56+
err := mt.GlobalClient().Database("admin").RunCommand(mtest.Background, bson.D{
5557
{"killAllSessions", bson.A{}},
5658
}, options.RunCmd().SetReadPreference(mtest.PrimaryRp)).Err()
5759
if err == nil {

mongo/integration/database_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestDatabase(t *testing.T) {
7070
lcNamesOpts := mtest.NewOptions().MinServerVersion("4.0")
7171
mt.RunOpts("list collection names", lcNamesOpts, func(mt *mtest.T) {
7272
collName := "lcNamesCollection"
73-
mt.CreateCollection(collName, true)
73+
mt.CreateCollection(mtest.Collection{Name: collName}, true)
7474

7575
testCases := []struct {
7676
name string
@@ -115,11 +115,11 @@ func TestDatabase(t *testing.T) {
115115
for _, tc := range testCases {
116116
tcOpts := mtest.NewOptions().Topologies(tc.expectedTopology)
117117
mt.RunOpts(tc.name, tcOpts, func(mt *mtest.T) {
118-
mt.CreateCollection(listCollUncapped, true)
119-
mt.CreateCollectionWithOptions(listCollCapped, bson.D{
120-
{"capped", true},
121-
{"size", 64 * 1024},
122-
})
118+
mt.CreateCollection(mtest.Collection{Name: listCollUncapped}, true)
119+
mt.CreateCollection(mtest.Collection{
120+
Name: listCollCapped,
121+
CreateOpts: bson.D{{"capped", true}, {"size", 64 * 1024}},
122+
}, true)
123123

124124
filter := bson.D{}
125125
if tc.cappedOnly {

mongo/integration/json_helpers_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package integration
99
import (
1010
"io/ioutil"
1111
"math"
12+
"os"
1213
"path"
1314
"strings"
1415
"testing"
@@ -23,6 +24,11 @@ import (
2324
"go.mongodb.org/mongo-driver/mongo/writeconcern"
2425
)
2526

27+
const (
28+
awsAccessKeyID = "AWS_ACCESS_KEY_ID"
29+
awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
30+
)
31+
2632
// Helper functions to do read JSON spec test files and convert JSON objects into the appropriate driver types.
2733
// Functions in this file should take testing.TB rather than testing.T/mtest.T for generality because they
2834
// do not do any database communication.
@@ -82,6 +88,8 @@ func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions {
8288
clientOpts.SetHeartbeatInterval(hfms)
8389
case "retryReads":
8490
clientOpts.SetRetryReads(opt.Boolean())
91+
case "autoEncryptOpts":
92+
clientOpts.SetAutoEncryptionOptions(createAutoEncryptionOptions(t, opt.Document()))
8593
default:
8694
t.Fatalf("unrecognized client option: %v", name)
8795
}
@@ -90,6 +98,87 @@ func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions {
9098
return clientOpts
9199
}
92100

101+
func createAutoEncryptionOptions(t testing.TB, opts bson.Raw) *options.AutoEncryptionOptions {
102+
t.Helper()
103+
104+
aeo := options.AutoEncryption()
105+
var kvnsFound bool
106+
elems, _ := opts.Elements()
107+
108+
for _, elem := range elems {
109+
name := elem.Key()
110+
opt := elem.Value()
111+
112+
switch name {
113+
case "kmsProviders":
114+
aeo.SetKmsProviders(createKmsProvidersMap(t, opt.Document()))
115+
case "schemaMap":
116+
var schemaMap map[string]interface{}
117+
err := bson.Unmarshal(opt.Document(), &schemaMap)
118+
if err != nil {
119+
t.Fatalf("error creating schema map: %v", err)
120+
}
121+
122+
aeo.SetSchemaMap(schemaMap)
123+
case "keyVaultNamespace":
124+
kvnsFound = true
125+
aeo.SetKeyVaultNamespace(opt.StringValue())
126+
case "bypassAutoEncryption":
127+
aeo.SetBypassAutoEncryption(opt.Boolean())
128+
default:
129+
t.Fatalf("unrecognized auto encryption option: %v", name)
130+
}
131+
}
132+
if !kvnsFound {
133+
aeo.SetKeyVaultNamespace("admin.datakeys")
134+
}
135+
136+
return aeo
137+
}
138+
139+
func createKmsProvidersMap(t testing.TB, opts bson.Raw) map[string]map[string]interface{} {
140+
t.Helper()
141+
142+
// aws: value is always empty object. create new map value from access key ID and secret access key
143+
// local: value is {"key": primitive.Binary}. transform to {"key": []byte}
144+
145+
kmsMap := make(map[string]map[string]interface{})
146+
elems, _ := opts.Elements()
147+
148+
for _, elem := range elems {
149+
provider := elem.Key()
150+
providerOpt := elem.Value()
151+
152+
switch provider {
153+
case "aws":
154+
keyID := os.Getenv(awsAccessKeyID)
155+
if keyID == "" {
156+
t.Fatalf("%s env var not set", awsAccessKeyID)
157+
}
158+
secretAccessKey := os.Getenv(awsSecretAccessKey)
159+
if secretAccessKey == "" {
160+
t.Fatalf("%s env var not set", awsSecretAccessKey)
161+
}
162+
163+
awsMap := map[string]interface{}{
164+
"accessKeyId": keyID,
165+
"secretAccessKey": secretAccessKey,
166+
}
167+
kmsMap["aws"] = awsMap
168+
case "local":
169+
_, key := providerOpt.Document().Lookup("key").Binary()
170+
localMap := map[string]interface{}{
171+
"key": key,
172+
}
173+
kmsMap["local"] = localMap
174+
default:
175+
t.Fatalf("unrecognized KMS provider: %v", provider)
176+
}
177+
}
178+
179+
return kmsMap
180+
}
181+
93182
// create session options from a map
94183
func createSessionOptions(t testing.TB, opts bson.Raw) *options.SessionOptions {
95184
t.Helper()

0 commit comments

Comments
 (0)