diff --git a/.github/docker-compose.yml b/.github/docker-compose.yml new file mode 100644 index 0000000..f06f00a --- /dev/null +++ b/.github/docker-compose.yml @@ -0,0 +1,8 @@ +version: '3' + +services: + dynamodb: + image: amazon/dynamodb-local:latest + ports: + - "8880:8000" + command: "-jar DynamoDBLocal.jar -sharedDb -inMemory" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..f90120b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,25 @@ +name: CI + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: 'stable' + - name: Start DynamoDB Local + run: docker compose -f '.github/docker-compose.yml' up -d + - name: Test + run: go test -v -race -cover -coverpkg=./... ./... + env: + DYNAMO_TEST_ENDPOINT: 'http://localhost:8880' + DYNAMO_TEST_REGION: local + DYNAMO_TEST_TABLE: 'TestDB-%' + AWS_ACCESS_KEY_ID: dummy + AWS_SECRET_ACCESS_KEY: dummy + AWS_REGION: local diff --git a/README.md b/README.md index b4ac10c..8cf6fa5 100644 --- a/README.md +++ b/README.md @@ -232,38 +232,23 @@ err := db.Table("Books").Get("ID", 555).One(dynamo.AWSEncoding(&someBook)) ### Integration tests -By default, tests are run in offline mode. Create a table called `TestDB`, with a number partition key called `UserID` and a string sort key called `Time`. It also needs a Global Secondary Index called `Msg-Time-index` with a string partition key called `Msg` and a string sort key called `Time`. +By default, tests are run in offline mode. In order to run the integration tests, some environment variables need to be set. -Change the table name with the environment variable `DYNAMO_TEST_TABLE`. You must specify `DYNAMO_TEST_REGION`, setting it to the AWS region where your test table is. - - - ```bash -DYNAMO_TEST_REGION=us-west-2 go test github.com/guregu/dynamo/... 
-cover - ``` - -If you want to use [DynamoDB Local](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DynamoDBLocal.html) to run local tests, specify `DYNAMO_TEST_ENDPOINT`. - - ```bash -DYNAMO_TEST_REGION=us-west-2 DYNAMO_TEST_ENDPOINT=http://localhost:8000 go test github.com/guregu/dynamo/... -cover - ``` - -Example of using [aws-cli](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Tools.CLI.html) to create a table for testing. +To run the tests against [DynamoDB Local](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DynamoDBLocal.html): ```bash -aws dynamodb create-table \ - --table-name TestDB \ - --attribute-definitions \ - AttributeName=UserID,AttributeType=N \ - AttributeName=Time,AttributeType=S \ - AttributeName=Msg,AttributeType=S \ - --key-schema \ - AttributeName=UserID,KeyType=HASH \ - AttributeName=Time,KeyType=RANGE \ - --global-secondary-indexes \ - IndexName=Msg-Time-index,KeySchema=[{'AttributeName=Msg,KeyType=HASH'},{'AttributeName=Time,KeyType=RANGE'}],Projection={'ProjectionType=ALL'} \ - --billing-mode PAY_PER_REQUEST \ - --region us-west-2 \ - --endpoint-url http://localhost:8000 # using DynamoDB local +# Use Docker to run DynamoDB local on port 8880 +docker compose -f '.github/docker-compose.yml' up -d + +# Run the tests with a fresh table +# The tables will be created automatically +DYNAMO_TEST_ENDPOINT='http://localhost:8880' \ + DYNAMO_TEST_REGION='local' \ + DYNAMO_TEST_TABLE='TestDB-%' \ # the % will be replaced by the current timestamp + AWS_ACCESS_KEY_ID='dummy' \ + AWS_SECRET_ACCESS_KEY='dummy' \ + AWS_REGION='local' \ + go test -v -race ./... -cover -coverpkg=./... 
``` ### License diff --git a/batch_test.go b/batch_test.go index 398e455..e18502d 100644 --- a/batch_test.go +++ b/batch_test.go @@ -11,7 +11,10 @@ func TestBatchGetWrite(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table1 := testDB.Table(testTableWidgets) + table2 := testDB.Table(testTableSprockets) + tables := []Table{table1, table2} + totalBatchSize := batchSize * len(tables) items := make([]interface{}, batchSize) widgets := make(map[int]widget) @@ -28,10 +31,19 @@ func TestBatchGetWrite(t *testing.T) { keys[i] = Keys{i, now} } + var batches []*BatchWrite + for _, table := range tables { + b := table.Batch().Write().Put(items...) + batches = append(batches, b) + } + batch1 := batches[0] + for _, b := range batches[1:] { + batch1.Merge(b) + } var wcc ConsumedCapacity - wrote, err := table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run() - if wrote != batchSize { - t.Error("unexpected wrote:", wrote, "≠", batchSize) + wrote, err := batch1.ConsumedCapacity(&wcc).Run() + if wrote != totalBatchSize { + t.Error("unexpected wrote:", wrote, "≠", totalBatchSize) } if err != nil { t.Error("unexpected error:", err) @@ -41,20 +53,29 @@ func TestBatchGetWrite(t *testing.T) { } // get all - var results []widget + var gets []*BatchGet + for _, table := range tables { + b := table.Batch("UserID", "Time"). + Get(keys...). + Project("UserID", "Time"). + Consistent(true) + gets = append(gets, b) + } + var cc ConsumedCapacity - err = table.Batch("UserID", "Time"). - Get(keys...). - Project("UserID", "Time"). - Consistent(true). - ConsumedCapacity(&cc). 
- All(&results) + get1 := gets[0].ConsumedCapacity(&cc) + for _, b := range gets[1:] { + get1.Merge(b) + } + + var results []widget + err = get1.All(&results) if err != nil { t.Error("unexpected error:", err) } - if len(results) != batchSize { - t.Error("expected", batchSize, "results, got", len(results)) + if len(results) != totalBatchSize { + t.Error("expected", totalBatchSize, "results, got", len(results)) } if cc.Total == 0 { @@ -72,26 +93,31 @@ func TestBatchGetWrite(t *testing.T) { } // delete both - wrote, err = table.Batch("UserID", "Time").Write(). - Delete(keys...).Run() - if wrote != batchSize { - t.Error("unexpected wrote:", wrote, "≠", batchSize) + wrote, err = table1.Batch("UserID", "Time").Write(). + Delete(keys...). + DeleteInRange(table2, "UserID", "Time", keys...). + Run() + if wrote != totalBatchSize { + t.Error("unexpected wrote:", wrote, "≠", totalBatchSize) } if err != nil { t.Error("unexpected error:", err) } // get both again - results = nil - err = table.Batch("UserID", "Time"). - Get(keys...). - Consistent(true). - All(&results) - if err != ErrNotFound { - t.Error("expected ErrNotFound, got", err) - } - if len(results) != 0 { - t.Error("expected 0 results, got", len(results)) + { + var results []widget + err = table1.Batch("UserID", "Time"). + Get(keys...). + FromRange(table2, "UserID", "Time", keys...). + Consistent(true). 
+ All(&results) + if err != ErrNotFound { + t.Error("expected ErrNotFound, got", err) + } + if len(results) != 0 { + t.Error("expected 0 results, got", len(results)) + } } } @@ -99,7 +125,7 @@ func TestBatchGetEmptySets(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) now := time.Now().UnixNano() / 1000000000 id := int(now) @@ -150,7 +176,7 @@ func TestBatchGetEmptySets(t *testing.T) { } func TestBatchEmptyInput(t *testing.T) { - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) var out []any err := table.Batch("UserID", "Time").Get().All(&out) if err != ErrNoInput { diff --git a/batchget.go b/batchget.go index c5c9cf5..5a26385 100644 --- a/batchget.go +++ b/batchget.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" @@ -42,12 +43,12 @@ func (table Table) Batch(hashAndRangeKeyName ...string) Batch { // BatchGet is a BatchGetItem operation. type BatchGet struct { - batch Batch - reqs []*Query - projection string - consistent bool + batch Batch + reqs []*Query + projections map[string][]string // table → paths + projection []string // default paths + consistent bool - subber err error cc *ConsumedCapacity } @@ -62,17 +63,29 @@ func (b Batch) Get(keys ...Keyed) *BatchGet { batch: b, err: b.err, } - bg.add(keys) - return bg + return bg.And(keys...) } -// And adds more keys to be gotten. +// And adds more keys to be gotten from the default table. +// To get items from other tables, use [BatchGet.From] or [BatchGet.FromRange]. func (bg *BatchGet) And(keys ...Keyed) *BatchGet { - bg.add(keys) - return bg + return bg.add(bg.batch.table, bg.batch.hashKey, bg.batch.rangeKey, keys...) +} + +// From adds more keys to be gotten from the given table. +// The given table's primary key must be a hash key (partition key) only. 
+// For tables with a range key (sort key) primary key, use [BatchGet.FromRange]. +func (bg *BatchGet) From(table Table, hashKey string, keys ...Keyed) *BatchGet { + return bg.add(table, hashKey, "", keys...) +} + +// FromRange adds more keys to be gotten from the given table. +// For tables without a range key (sort key) primary key, use [BatchGet.From]. +func (bg *BatchGet) FromRange(table Table, hashKey, rangeKey string, keys ...Keyed) *BatchGet { + return bg.add(table, hashKey, rangeKey, keys...) } -func (bg *BatchGet) add(keys []Keyed) { +func (bg *BatchGet) add(table Table, hashKey string, rangeKey string, keys ...Keyed) *BatchGet { for _, key := range keys { if key == nil { bg.setError(errors.New("dynamo: batch: the Keyed interface must not be nil")) @@ -85,23 +98,69 @@ func (bg *BatchGet) add(keys []Keyed) { } bg.reqs = append(bg.reqs, get) } + return bg } // Project limits the result attributes to the given paths. +// This will apply to all tables, but can be overridden by [BatchGet.ProjectTable] to set specific per-table projections. func (bg *BatchGet) Project(paths ...string) *BatchGet { - var expr string - for i, p := range paths { - if i != 0 { - expr += ", " + bg.projection = paths + return bg +} + +// ProjectTable limits the result attributes to the given paths for the given table. +func (bg *BatchGet) ProjectTable(table Table, paths ...string) *BatchGet { + return bg.project(table.Name(), paths...) +} + +func (bg *BatchGet) project(table string, paths ...string) *BatchGet { + if bg.projections == nil { + bg.projections = make(map[string][]string) + } + bg.projections[table] = paths + return bg +} + +func (bg *BatchGet) projectionFor(table string) []string { + if proj := bg.projections[table]; proj != nil { + return proj + } + if bg.projection != nil { + return bg.projection + } + return nil +} + +// Merge copies operations and settings from src to this batch get. 
+func (bg *BatchGet) Merge(src *BatchGet) *BatchGet { + bg.reqs = append(bg.reqs, src.reqs...) + bg.consistent = bg.consistent || src.consistent + this := bg.batch.table.Name() + for table, proj := range src.projections { + if this == table { + continue + } + bg.mergeProjection(table, proj) + } + if len(src.projection) > 0 { + if that := src.batch.table.Name(); that != this { + bg.mergeProjection(that, src.projection) } - name, err := bg.escape(p) - bg.setError(err) - expr += name } - bg.projection = expr return bg } +func (bg *BatchGet) mergeProjection(table string, proj []string) { + current := bg.projections[table] + merged := current + for _, path := range proj { + if !slices.Contains(current, path) { + merged = append(merged, path) + } + } + bg.project(table, merged...) +} + // Consistent will, if on is true, make this batch use a strongly consistent read. // Reads are eventually consistent by default. // Strongly consistent reads are more resource-heavy than eventually consistent reads. @@ -118,7 +177,7 @@ func (bg *BatchGet) ConsumedCapacity(cc *ConsumedCapacity) *BatchGet { // All executes this request and unmarshals all results to out, which must be a pointer to a slice. func (bg *BatchGet) All(out interface{}) error { - iter := newBGIter(bg, unmarshalAppendTo(out), bg.err) + iter := newBGIter(bg, unmarshalAppendTo(out), nil, bg.err) for iter.Next(out) { } return iter.Err() @@ -126,7 +185,7 @@ func (bg *BatchGet) All(out interface{}) error { // AllWithContext executes this request and unmarshals all results to out, which must be a pointer to a slice. func (bg *BatchGet) AllWithContext(ctx context.Context, out interface{}) error { - iter := newBGIter(bg, unmarshalAppendTo(out), bg.err) + iter := newBGIter(bg, unmarshalAppendTo(out), nil, bg.err) for iter.NextWithContext(ctx, out) { } return iter.Err() @@ -134,7 +193,13 @@ func (bg *BatchGet) AllWithContext(ctx context.Context, out interface{}) error { // Iter returns a results iterator for this batch. 
func (bg *BatchGet) Iter() Iter { - return newBGIter(bg, unmarshalItem, bg.err) + return newBGIter(bg, unmarshalItem, nil, bg.err) +} + +// IterWithTable is like [BatchGet.Iter], but will update the value pointed by tablePtr after each iteration. +// This can be useful when getting from multiple tables to determine which table the latest item came from. +func (bg *BatchGet) IterWithTable(tablePtr *string) Iter { + return newBGIter(bg, unmarshalItem, tablePtr, bg.err) } func (bg *BatchGet) input(start int) *dynamodb.BatchGetItemInput { @@ -147,12 +212,12 @@ func (bg *BatchGet) input(start int) *dynamodb.BatchGetItemInput { } in := &dynamodb.BatchGetItemInput{ - RequestItems: make(map[string]*dynamodb.KeysAndAttributes, 1), + RequestItems: make(map[string]*dynamodb.KeysAndAttributes), } - if bg.projection != "" { - for _, get := range bg.reqs[start:end] { - get.Project(get.projection) + for _, get := range bg.reqs[start:end] { + if proj := bg.projectionFor(get.table.Name()); proj != nil { + get.Project(proj...) 
bg.setError(get.err) } } @@ -160,22 +225,19 @@ func (bg *BatchGet) input(start int) *dynamodb.BatchGetItemInput { in.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) } - var kas *dynamodb.KeysAndAttributes for _, get := range bg.reqs[start:end] { + table := get.table.Name() + kas := in.RequestItems[table] if kas == nil { kas = get.keysAndAttribs() + if bg.consistent { + kas.ConsistentRead = &bg.consistent + } + in.RequestItems[table] = kas continue } kas.Keys = append(kas.Keys, get.keys()) } - if bg.projection != "" { - kas.ProjectionExpression = &bg.projection - kas.ExpressionAttributeNames = bg.nameExpr - } - if bg.consistent { - kas.ConsistentRead = &bg.consistent - } - in.RequestItems[bg.batch.table.Name()] = kas return in } @@ -188,8 +250,10 @@ func (bg *BatchGet) setError(err error) { // bgIter is the iterator for Batch Get operations type bgIter struct { bg *BatchGet + track *string // table out value input *dynamodb.BatchGetItemInput output *dynamodb.BatchGetItemOutput + got []batchGot err error idx int total int @@ -198,13 +262,19 @@ type bgIter struct { unmarshal unmarshalFunc } -func newBGIter(bg *BatchGet, fn unmarshalFunc, err error) *bgIter { +type batchGot struct { + table string + item Item +} + +func newBGIter(bg *BatchGet, fn unmarshalFunc, track *string, err error) *bgIter { if err == nil && len(bg.reqs) == 0 { err = ErrNoInput } iter := &bgIter{ bg: bg, + track: track, err: err, backoff: backoff.NewExponentialBackOff(), unmarshal: fn, @@ -230,16 +300,14 @@ func (itr *bgIter) NextWithContext(ctx context.Context, out interface{}) bool { return false } - tableName := itr.bg.batch.table.Name() - redo: // can we use results we already have? 
- if itr.output != nil && itr.idx < len(itr.output.Responses[tableName]) { - items := itr.output.Responses[tableName] - item := items[itr.idx] - itr.err = itr.unmarshal(item, out) + if itr.output != nil && itr.idx < len(itr.got) { + got := itr.got[itr.idx] + itr.err = itr.unmarshal(got.item, out) itr.idx++ itr.total++ + itr.trackTable(got.table) return itr.err == nil } @@ -248,12 +316,15 @@ redo: itr.input = itr.bg.input(itr.processed) } - if itr.output != nil && itr.idx >= len(itr.output.Responses[tableName]) { - var unprocessed int - if itr.output.UnprocessedKeys != nil && itr.output.UnprocessedKeys[tableName] != nil { - unprocessed = len(itr.output.UnprocessedKeys[tableName].Keys) + if itr.output != nil && itr.idx >= len(itr.got) { + for _, req := range itr.input.RequestItems { + itr.processed += len(req.Keys) + } + if itr.output.UnprocessedKeys != nil { + for _, keys := range itr.output.UnprocessedKeys { + itr.processed -= len(keys.Keys) + } } - itr.processed += len(itr.input.RequestItems[tableName].Keys) - unprocessed // have we exhausted all results? if len(itr.output.UnprocessedKeys) == 0 { // yes, try to get next inner batch of 100 items @@ -291,10 +362,27 @@ redo: } } + itr.got = itr.got[:0] + for table, resp := range itr.output.Responses { + for _, item := range resp { + itr.got = append(itr.got, batchGot{ + table: table, + item: item, + }) + } + } + // we've got unprocessed results, marshal one goto redo } +func (itr *bgIter) trackTable(next string) { + if itr.track == nil { + return + } + *itr.track = next +} + // Err returns the error encountered, if any. // You should check this after Next is finished. func (itr *bgIter) Err() error { diff --git a/batchwrite.go b/batchwrite.go index 028f049..742f739 100644 --- a/batchwrite.go +++ b/batchwrite.go @@ -15,11 +15,16 @@ const maxWriteOps = 25 // BatchWrite is a BatchWriteItem operation. 
type BatchWrite struct { batch Batch - ops []*dynamodb.WriteRequest + ops []batchWrite err error cc *ConsumedCapacity } +type batchWrite struct { + table string + op *dynamodb.WriteRequest +} + // Write creates a new batch write request, to which // puts and deletes can be added. func (b Batch) Write() *BatchWrite { @@ -29,33 +34,73 @@ func (b Batch) Write() *BatchWrite { } } -// Put adds put operations for items to this batch. +// Put adds put operations for items to this batch using the default table. func (bw *BatchWrite) Put(items ...interface{}) *BatchWrite { + return bw.PutIn(bw.batch.table, items...) +} + +// PutIn adds put operations for items to this batch using the given table. +// This can be useful for writing to multiple different tables. +func (bw *BatchWrite) PutIn(table Table, items ...interface{}) *BatchWrite { + name := table.Name() for _, item := range items { encoded, err := marshalItem(item) bw.setError(err) - bw.ops = append(bw.ops, &dynamodb.WriteRequest{PutRequest: &dynamodb.PutRequest{ - Item: encoded, - }}) + bw.ops = append(bw.ops, batchWrite{ + table: name, + op: &dynamodb.WriteRequest{PutRequest: &dynamodb.PutRequest{ + Item: encoded, + }}, + }) } return bw } -// Delete adds delete operations for the given keys to this batch. +// Delete adds delete operations for the given keys to this batch, using the default table. func (bw *BatchWrite) Delete(keys ...Keyed) *BatchWrite { + return bw.deleteIn(bw.batch.table, bw.batch.hashKey, bw.batch.rangeKey, keys...) +} + +// DeleteIn adds delete operations for the given keys to this batch, using the given table. +// hashKey must be the name of the primary key hash (partition) attribute. +// This function is for tables with a hash key (partition key) only. +// For tables including a range key (sort key) primary key, use [BatchWrite.DeleteInRange] instead. +func (bw *BatchWrite) DeleteIn(table Table, hashKey string, keys ...Keyed) *BatchWrite { + return bw.deleteIn(table, hashKey, "", keys...) 
+} + +// DeleteInRange adds delete operations for the given keys to this batch, using the given table. +// hashKey must be the name of the primary key hash (partition) attribute, rangeKey must be the name of the primary key range (sort) attribute. +// This function is for tables with a hash key (partition key) and range key (sort key). +// For tables without a range key primary key, use [BatchWrite.DeleteIn] instead. +func (bw *BatchWrite) DeleteInRange(table Table, hashKey, rangeKey string, keys ...Keyed) *BatchWrite { + return bw.deleteIn(table, hashKey, rangeKey, keys...) +} + +func (bw *BatchWrite) deleteIn(table Table, hashKey, rangeKey string, keys ...Keyed) *BatchWrite { + name := table.Name() for _, key := range keys { - del := bw.batch.table.Delete(bw.batch.hashKey, key.HashKey()) - if rk := key.RangeKey(); bw.batch.rangeKey != "" && rk != nil { - del.Range(bw.batch.rangeKey, rk) + del := table.Delete(hashKey, key.HashKey()) + if rk := key.RangeKey(); rangeKey != "" && rk != nil { + del.Range(rangeKey, rk) bw.setError(del.err) } - bw.ops = append(bw.ops, &dynamodb.WriteRequest{DeleteRequest: &dynamodb.DeleteRequest{ - Key: del.key(), - }}) + bw.ops = append(bw.ops, batchWrite{ + table: name, + op: &dynamodb.WriteRequest{DeleteRequest: &dynamodb.DeleteRequest{ + Key: del.key(), + }}, + }) } return bw } +// Merge copies operations from src to this batch. +func (bw *BatchWrite) Merge(src *BatchWrite) *BatchWrite { + bw.ops = append(bw.ops, src.ops...) + return bw +} + // ConsumedCapacity will measure the throughput capacity consumed by this operation and add it to cc. func (bw *BatchWrite) ConsumedCapacity(cc *ConsumedCapacity) *BatchWrite { bw.cc = cc @@ -72,6 +117,10 @@ func (bw *BatchWrite) Run() (wrote int, err error) { return bw.RunWithContext(ctx) } +// RunWithContext executes this batch. +// For batches with more than 25 operations, an error could indicate that +// some records have been written and some have not. 
Consult the wrote +// return amount to figure out which operations have succeeded. func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) { if bw.err != nil { return 0, bw.err @@ -108,12 +157,21 @@ func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) } } - unprocessed := res.UnprocessedItems[bw.batch.table.Name()] - wrote += len(ops) - len(unprocessed) - if len(unprocessed) == 0 { + wrote += len(ops) + if len(res.UnprocessedItems) == 0 { break } - ops = unprocessed + + ops = ops[:0] + for tableName, unprocessed := range res.UnprocessedItems { + wrote -= len(unprocessed) + for _, op := range unprocessed { + ops = append(ops, batchWrite{ + table: tableName, + op: op, + }) + } + } // need to sleep when re-requesting, per spec if err := aws.SleepWithContext(ctx, boff.NextBackOff()); err != nil { @@ -126,11 +184,13 @@ func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) return wrote, nil } -func (bw *BatchWrite) input(ops []*dynamodb.WriteRequest) *dynamodb.BatchWriteItemInput { +func (bw *BatchWrite) input(ops []batchWrite) *dynamodb.BatchWriteItemInput { + items := make(map[string][]*dynamodb.WriteRequest) + for _, op := range ops { + items[op.table] = append(items[op.table], op.op) + } input := &dynamodb.BatchWriteItemInput{ - RequestItems: map[string][]*dynamodb.WriteRequest{ - bw.batch.table.Name(): ops, - }, + RequestItems: items, } if bw.cc != nil { input.ReturnConsumedCapacity = aws.String(dynamodb.ReturnConsumedCapacityIndexes) diff --git a/db_test.go b/db_test.go index deddd62..e9cd8c2 100644 --- a/db_test.go +++ b/db_test.go @@ -1,50 +1,124 @@ package dynamo import ( + "errors" + "fmt" + "log" "os" + "strconv" + "strings" "testing" "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" ) var ( - testDB *DB - testTable = "TestDB" + testDB *DB + testTableWidgets = "TestDB" 
+ testTableSprockets = "TestDB-Sprockets" ) var dummyCreds = credentials.NewStaticCredentials("dummy", "dummy", "") const offlineSkipMsg = "DYNAMO_TEST_REGION not set" -func init() { - // os.Setenv("DYNAMO_TEST_REGION", "us-west-2") - if region := os.Getenv("DYNAMO_TEST_REGION"); region != "" { - var endpoint *string - if dte := os.Getenv("DYNAMO_TEST_ENDPOINT"); dte != "" { - endpoint = aws.String(dte) - } +// widget is the data structure used for integration tests +type widget struct { + UserID int `dynamo:",hash"` + Time time.Time `dynamo:",range" index:"Msg-Time-index,range"` + Msg string `index:"Msg-Time-index,hash"` + Count int + Meta map[string]string + StrPtr *string `dynamo:",allowempty"` +} + +func TestMain(m *testing.M) { + var endpoint, region *string + if dte := os.Getenv("DYNAMO_TEST_ENDPOINT"); dte != "" { + endpoint = &dte + } + if dtr := os.Getenv("DYNAMO_TEST_REGION"); dtr != "" { + region = &dtr + } + if endpoint != nil && region == nil { + dtr := "local" + region = &dtr + } + if region != nil { testDB = New(session.Must(session.NewSession()), &aws.Config{ - Region: aws.String(region), + Region: region, Endpoint: endpoint, // LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody), }) } + + timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10) + var offline bool if table := os.Getenv("DYNAMO_TEST_TABLE"); table != "" { - testTable = table + offline = false + // Test-% --> Test-1707708680863 + table = strings.ReplaceAll(table, "%", timestamp) + testTableWidgets = table + } + if table := os.Getenv("DYNAMO_TEST_TABLE2"); table != "" { + table = strings.ReplaceAll(table, "%", timestamp) + testTableSprockets = table + } else if !offline { + testTableSprockets = testTableWidgets + "-Sprockets" + } + + if !offline && testTableWidgets == testTableSprockets { + panic(fmt.Sprintf("DYNAMO_TEST_TABLE must not equal DYNAMO_TEST_TABLE2. 
got DYNAMO_TEST_TABLE=%q and DYNAMO_TEST_TABLE2=%q", + testTableWidgets, testTableSprockets)) + } + + var shouldCreate bool + switch os.Getenv("DYNAMO_TEST_CREATE_TABLE") { + case "1", "true", "yes": + shouldCreate = true + case "0", "false", "no": + shouldCreate = false + default: + shouldCreate = endpoint != nil + } + + var created []Table + if testDB != nil { + for _, name := range []string{testTableWidgets, testTableSprockets} { + table := testDB.Table(name) + log.Println("Checking test table:", name) + _, err := table.Describe().Run() + switch { + case isTableNotExistsErr(err) && shouldCreate: + log.Println("Creating test table:", name) + if err := testDB.CreateTable(name, widget{}).Run(); err != nil { + panic(err) + } + created = append(created, testDB.Table(name)) + case err != nil: + panic(err) + } + } + } + + code := m.Run() + defer os.Exit(code) + + for _, table := range created { + log.Println("Deleting test table:", table.Name()) + if err := table.DeleteTable().Run(); err != nil { + log.Println("Error deleting test table:", table.Name(), err) + } } } -// widget is the data structure used for integration tests -type widget struct { - UserID int `dynamo:",hash"` - Time time.Time `dynamo:",range"` - Msg string - Count int - Meta map[string]string - StrPtr *string `dynamo:",allowempty"` +func isTableNotExistsErr(err error) bool { + var ae awserr.Error + return errors.As(err, &ae) && ae.Code() == "ResourceNotFoundException" } func TestListTables(t *testing.T) { @@ -60,13 +134,13 @@ func TestListTables(t *testing.T) { found := false for _, t := range tables { - if t == testTable { + if t == testTableWidgets { found = true break } } if !found { - t.Error("couldn't find testTable", testTable, "in:", tables) + t.Error("couldn't find testTable", testTableWidgets, "in:", tables) } } diff --git a/delete_test.go b/delete_test.go index 751565d..9d11dd7 100644 --- a/delete_test.go +++ b/delete_test.go @@ -10,7 +10,7 @@ func TestDelete(t *testing.T) { if testDB == nil 
{ t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to delete later item := widget{ diff --git a/describetable_test.go b/describetable_test.go index 34bc50e..9e22173 100644 --- a/describetable_test.go +++ b/describetable_test.go @@ -8,7 +8,7 @@ func TestDescribeTable(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) desc, err := table.Describe().Run() if err != nil { @@ -16,8 +16,8 @@ func TestDescribeTable(t *testing.T) { return } - if desc.Name != testTable { - t.Error("wrong name:", desc.Name, "≠", testTable) + if desc.Name != testTableWidgets { + t.Error("wrong name:", desc.Name, "≠", testTableWidgets) } if desc.HashKey != "UserID" || desc.RangeKey != "Time" { t.Error("bad keys:", desc.HashKey, desc.RangeKey) diff --git a/put_test.go b/put_test.go index f0fd48a..fe2ad63 100644 --- a/put_test.go +++ b/put_test.go @@ -12,7 +12,7 @@ func TestPut(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) type widget2 struct { widget @@ -62,7 +62,7 @@ func TestPut(t *testing.T) { t.Errorf("bad old value. 
%#v ≠ %#v", oldValue, item) } - if cc.Total < 1 || cc.Table < 1 || cc.TableName != testTable { + if cc.Total < 1 || cc.Table < 1 || cc.TableName != testTableWidgets { t.Errorf("bad consumed capacity: %#v", cc) } @@ -77,7 +77,7 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) type awsWidget struct { XUserID int `dynamodbav:"UserID"` diff --git a/query_test.go b/query_test.go index 398a359..f1d7992 100644 --- a/query_test.go +++ b/query_test.go @@ -13,7 +13,7 @@ func TestGetAllCount(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to make sure there is at least one item := widget{ @@ -146,7 +146,7 @@ func TestQueryPaging(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) widgets := []interface{}{ widget{ @@ -193,7 +193,7 @@ func TestQueryMagicLEK(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) widgets := []interface{}{ widget{ @@ -276,7 +276,7 @@ func TestQueryBadKeys(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) t.Run("hash key", func(t *testing.T) { var v interface{} diff --git a/scan_test.go b/scan_test.go index 75cccbf..9941b6a 100644 --- a/scan_test.go +++ b/scan_test.go @@ -13,7 +13,7 @@ func TestScan(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to make sure there is at least one item := widget{ @@ -107,7 +107,7 @@ func TestScanPaging(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := 
testDB.Table(testTableWidgets) // prepare data insert := make([]interface{}, 10) @@ -126,11 +126,9 @@ func TestScanPaging(t *testing.T) { widgets := [10]widget{} itr := table.Scan().Consistent(true).SearchLimit(1).Iter() for i := 0; i < len(widgets); i++ { - more := itr.Next(&widgets[i]) + itr.Next(&widgets[i]) if itr.Err() != nil { t.Error("unexpected error", itr.Err()) - } - if !more { break } itr = table.Scan().StartFrom(itr.LastEvaluatedKey()).SearchLimit(1).Iter() @@ -146,22 +144,17 @@ func TestScanPaging(t *testing.T) { const segments = 2 ctx := context.Background() widgets := [10]widget{} - itr := table.Scan().Consistent(true).SearchLimit(1).IterParallel(ctx, segments) - for i := 0; i < len(widgets)/segments; i++ { - var more bool - for j := 0; j < segments; j++ { - more = itr.Next(&widgets[i*segments+j]) - if !more && j != segments-1 { - t.Error("bad number of results from parallel scan") - } + limit := int64(len(widgets) / segments) + itr := table.Scan().Consistent(true).SearchLimit(limit).IterParallel(ctx, segments) + for i := 0; i < len(widgets); { + for ; i < len(widgets) && itr.Next(&widgets[i]); i++ { } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) - } - if !more { break } - itr = table.Scan().SearchLimit(1).IterParallelStartFrom(ctx, itr.LastEvaluatedKeys()) + t.Logf("parallel chunk: %d", i) + itr = table.Scan().SearchLimit(limit).IterParallelStartFrom(ctx, itr.LastEvaluatedKeys()) } for i, w := range widgets { if w.UserID == 0 && w.Time.IsZero() { @@ -175,7 +168,7 @@ func TestScanMagicLEK(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) widgets := []interface{}{ widget{ diff --git a/ttl_test.go b/ttl_test.go index 9ded4e3..9ffcafc 100644 --- a/ttl_test.go +++ b/ttl_test.go @@ -8,7 +8,7 @@ func TestDescribeTTL(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) 
desc, err := table.DescribeTTL().Run() if err != nil { diff --git a/tx_test.go b/tx_test.go index 626d026..0fd28e0 100644 --- a/tx_test.go +++ b/tx_test.go @@ -21,7 +21,7 @@ func TestTx(t *testing.T) { widget2 := widget{UserID: 69, Time: date2, Msg: "cat"} // basic write & check - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) tx := testDB.WriteTx() var cc, ccold ConsumedCapacity tx.Idempotent(true) @@ -184,7 +184,7 @@ func TestTxRetry(t *testing.T) { date1 := time.Date(1999, 1, 1, 1, 1, 1, 0, time.UTC) widget1 := widget{UserID: 69, Time: date1, Msg: "dog", Count: 0} - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) if err := table.Put(widget1).Run(); err != nil { t.Fatal(err) } diff --git a/update_test.go b/update_test.go index 56dce4d..ae22824 100644 --- a/update_test.go +++ b/update_test.go @@ -13,7 +13,7 @@ func TestUpdate(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) type widget2 struct { widget @@ -168,7 +168,7 @@ func TestUpdateNil(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to make sure there is at least one item := widget{ @@ -223,7 +223,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) type widget2 struct { widget diff --git a/updatetable_test.go b/updatetable_test.go index 04c2c1b..472b641 100644 --- a/updatetable_test.go +++ b/updatetable_test.go @@ -9,7 +9,7 @@ func _TestUpdateTable(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) desc, err := table.UpdateTable().CreateIndex(Index{ Name: "test123", @@ -32,8 +32,8 @@ func _TestUpdateTable(t *testing.T) { if err != nil { t.Error(err) } - if 
desc.Name != testTable { - t.Error("wrong name:", desc.Name, "≠", testTable) + if desc.Name != testTableWidgets { + t.Error("wrong name:", desc.Name, "≠", testTableWidgets) } if desc.Status != UpdatingStatus { t.Error("bad status:", desc.Status, "≠", UpdatingStatus)