Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MixedBulkWriteOperation should generate inserted document IDs at most once per batch #1484

Merged
merged 1 commit into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.internal.connection;

import com.mongodb.lang.Nullable;
import org.bson.BsonBinary;
import org.bson.BsonBinaryWriter;
import org.bson.BsonBoolean;
Expand Down Expand Up @@ -57,11 +58,17 @@ public class IdHoldingBsonWriter extends LevelCountingBsonWriter {
private LevelCountingBsonWriter idBsonBinaryWriter;
private BasicOutputBuffer outputBuffer;
private String currentFieldName;
private final BsonValue fallbackId;
private BsonValue id;
private boolean idFieldIsAnArray = false;

public IdHoldingBsonWriter(final BsonWriter bsonWriter) {
/**
* @param fallbackId The "_id" field value to use if the top-level document written via this {@link BsonWriter}
* does not have "_id". If {@code null}, then a new {@link BsonObjectId} is generated instead.
*/
public IdHoldingBsonWriter(final BsonWriter bsonWriter, @Nullable final BsonObjectId fallbackId) {
super(bsonWriter);
this.fallbackId = fallbackId;
}

@Override
Expand Down Expand Up @@ -99,7 +106,7 @@ public void writeEndDocument() {
}

if (getCurrentLevel() == 0 && id == null) {
id = new BsonObjectId();
id = fallbackId == null ? new BsonObjectId() : fallbackId;
writeObjectId(ID_FIELD_NAME, id.asObjectId().getValue());
}
super.writeEndDocument();
Expand Down Expand Up @@ -408,6 +415,15 @@ public void flush() {
super.flush();
}

/**
* Returns either the value of the "_id" field from the top-level document written via this {@link BsonWriter},
* provided that the document is not {@link RawBsonDocument},
* or the generated {@link BsonObjectId}.
* If the document is {@link RawBsonDocument}, then returns {@code null}.
* <p>
* {@linkplain #flush() Flushing} is not required before calling this method.</p>
*/
@Nullable
public BsonValue getId() {
return id;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.mongodb.internal.bulk.WriteRequestWithIndex;
import org.bson.BsonDocument;
import org.bson.BsonDocumentWrapper;
import org.bson.BsonObjectId;
import org.bson.BsonValue;
import org.bson.BsonWriter;
import org.bson.codecs.BsonValueCodecProvider;
Expand Down Expand Up @@ -191,10 +192,23 @@ public void encode(final BsonWriter writer, final WriteRequestWithIndex writeReq
InsertRequest insertRequest = (InsertRequest) writeRequestWithIndex.getWriteRequest();
BsonDocument document = insertRequest.getDocument();

IdHoldingBsonWriter idHoldingBsonWriter = new IdHoldingBsonWriter(writer);
getCodec(document).encode(idHoldingBsonWriter, document,
EncoderContext.builder().isEncodingCollectibleDocument(true).build());
insertedIds.put(writeRequestWithIndex.getIndex(), idHoldingBsonWriter.getId());
BsonValue documentId = insertedIds.compute(
writeRequestWithIndex.getIndex(),
(writeRequestIndex, writeRequestDocumentId) -> {
IdHoldingBsonWriter idHoldingBsonWriter = new IdHoldingBsonWriter(
writer,
// Reuse `writeRequestDocumentId` if it may have been generated
// by `IdHoldingBsonWriter` in a previous attempt.
// If its type is not `BsonObjectId`, we know it could not have been generated.
writeRequestDocumentId instanceof BsonObjectId ? writeRequestDocumentId.asObjectId() : null);
getCodec(document).encode(idHoldingBsonWriter, document,
EncoderContext.builder().isEncodingCollectibleDocument(true).build());
return idHoldingBsonWriter.getId();
});
if (documentId == null) {
// we must add an entry anyway because we rely on all the indexes being present
insertedIds.put(writeRequestWithIndex.getIndex(), null);
}
} else if (writeRequestWithIndex.getType() == WriteRequest.Type.UPDATE
|| writeRequestWithIndex.getType() == WriteRequest.Type.REPLACE) {
UpdateRequest update = (UpdateRequest) writeRequestWithIndex.getWriteRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ import static org.bson.BsonHelper.documentWithValuesOfEveryType
import static org.bson.BsonHelper.getBsonValues

class IdHoldingBsonWriterSpecification extends Specification {
private static final OBJECT_ID = new BsonObjectId()

def 'should write all types'() {
given:
def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
def document = documentWithValuesOfEveryType()

when:
Expand All @@ -47,18 +48,25 @@ class IdHoldingBsonWriterSpecification extends Specification {
!document.containsKey('_id')
encodedDocument.containsKey('_id')
idTrackingBsonWriter.getId() == encodedDocument.get('_id')
if (expectedIdNullIfMustBeGenerated != null) {
idTrackingBsonWriter.getId() == expectedIdNullIfMustBeGenerated
}

when:
encodedDocument.remove('_id')

then:
encodedDocument == document

where:
fallbackId << [null, OBJECT_ID]
expectedIdNullIfMustBeGenerated << [null, OBJECT_ID]
}

def 'should support all types for _id value'() {
given:
def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
def document = new BsonDocument()
document.put('_id', id)

Expand All @@ -71,12 +79,15 @@ class IdHoldingBsonWriterSpecification extends Specification {
idTrackingBsonWriter.getId() == id

where:
id << getBsonValues()
[id, fallbackId] << [
getBsonValues(),
[null, new BsonObjectId()]
].combinations()
}

def 'serialize document with list of documents that contain an _id field'() {
def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
def document = new BsonDocument('_id', new BsonObjectId())
.append('items', new BsonArray(Collections.singletonList(new BsonDocument('_id', new BsonObjectId()))))

Expand All @@ -86,11 +97,14 @@ class IdHoldingBsonWriterSpecification extends Specification {

then:
encodedDocument == document

where:
fallbackId << [null, new BsonObjectId()]
}

def 'serialize _id documents containing arrays'() {
def bsonBinaryWriter = new BsonBinaryWriter(new BasicOutputBuffer())
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter)
def idTrackingBsonWriter = new IdHoldingBsonWriter(bsonBinaryWriter, fallbackId)
BsonDocument document = BsonDocument.parse(json)

when:
Expand All @@ -102,7 +116,8 @@ class IdHoldingBsonWriterSpecification extends Specification {
encodedDocument == document

where:
json << ['{"_id": {"a": []}, "b": 123}',
[json, fallbackId] << [
['{"_id": {"a": []}, "b": 123}',
'{"_id": {"a": [1, 2]}, "b": 123}',
'{"_id": {"a": [[[[1]]]]}, "b": 123}',
'{"_id": {"a": [{"a": [1, 2]}]}, "b": 123}',
Expand All @@ -112,7 +127,9 @@ class IdHoldingBsonWriterSpecification extends Specification {
'{"_id": [1, 2], "b": 123}',
'{"_id": [[1], [[2]]], "b": 123}',
'{"_id": [{"a": 1}], "b": 123}',
'{"_id": [{"a": [{"b": 123}]}]}']
'{"_id": [{"a": [{"b": 123}]}]}'],
[null, new BsonObjectId()]
].combinations()
}

private static BsonDocument getEncodedDocument(BsonOutput buffer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,35 @@
import com.mongodb.MongoBulkWriteException;
import com.mongodb.MongoWriteConcernException;
import com.mongodb.MongoWriteException;
import com.mongodb.ServerAddress;
import com.mongodb.client.model.CreateCollectionOptions;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.ValidationOptions;
import com.mongodb.event.CommandListener;
import com.mongodb.event.CommandStartedEvent;
import org.bson.BsonArray;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
import org.bson.BsonString;
import org.bson.BsonValue;
import org.bson.Document;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.junit.Before;
import org.junit.Test;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet;
import static com.mongodb.ClusterFixture.serverVersionAtLeast;
import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry;
import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder;
import static java.lang.String.format;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
Expand Down Expand Up @@ -116,6 +130,55 @@ public void testWriteErrorDetailsIsPropagated() {
}
}

/**
* This test is not from the specification.
*/
@Test
@SuppressWarnings("try")
public void insertMustGenerateIdAtMostOnce() throws ExecutionException, InterruptedException {
assumeTrue(serverVersionAtLeast(4, 0));
assumeTrue(isDiscoverableReplicaSet());
ServerAddress primaryServerAddress = Fixture.getPrimary();
CompletableFuture<BsonValue> futureIdGeneratedByFirstInsertAttempt = new CompletableFuture<>();
CompletableFuture<BsonValue> futureIdGeneratedBySecondInsertAttempt = new CompletableFuture<>();
CommandListener commandListener = new CommandListener() {
@Override
public void commandStarted(final CommandStartedEvent event) {
if (event.getCommandName().equals("insert")) {
BsonValue generatedId = event.getCommand().getArray("documents").get(0).asDocument().get("_id");
if (!futureIdGeneratedByFirstInsertAttempt.isDone()) {
futureIdGeneratedByFirstInsertAttempt.complete(generatedId);
} else {
futureIdGeneratedBySecondInsertAttempt.complete(generatedId);
}
}
}
};
BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand"))
.append("mode", new BsonDocument("times", new BsonInt32(1)))
.append("data", new BsonDocument()
.append("failCommands", new BsonArray(singletonList(new BsonString("insert"))))
.append("errorLabels", new BsonArray(singletonList(new BsonString("RetryableWriteError"))))
.append("writeConcernError", new BsonDocument("code", new BsonInt32(91))
.append("errmsg", new BsonString("Replication is being shut down"))));
try (MongoClient client = MongoClients.create(getMongoClientSettingsBuilder()
.retryWrites(true)
.addCommandListener(commandListener)
.applyToServerSettings(builder -> builder.heartbeatFrequency(50, TimeUnit.MILLISECONDS))
.build());
FailPoint ignored = FailPoint.enable(failPointDocument, primaryServerAddress)) {
MongoCollection<MyDocument> coll = client.getDatabase(database.getName())
.getCollection(collection.getNamespace().getCollectionName(), MyDocument.class)
.withCodecRegistry(fromRegistries(
getDefaultCodecRegistry(),
fromProviders(PojoCodecProvider.builder().automatic(true).build())));
BsonValue insertedId = coll.insertOne(new MyDocument()).getInsertedId();
BsonValue idGeneratedByFirstInsertAttempt = futureIdGeneratedByFirstInsertAttempt.get();
assertEquals(idGeneratedByFirstInsertAttempt, insertedId);
assertEquals(idGeneratedByFirstInsertAttempt, futureIdGeneratedBySecondInsertAttempt.get());
}
}

private void setFailPoint() {
failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand"))
.append("mode", new BsonDocument("times", new BsonInt32(1)))
Expand All @@ -132,4 +195,15 @@ private void setFailPoint() {
private void disableFailPoint() {
getCollectionHelper().runAdminCommand(failPointDocument.append("mode", new BsonString("off")));
}

public static final class MyDocument {
private int v;

public MyDocument() {
}

public int getV() {
return v;
}
}
}