Skip to content

Commit

Permalink
Client Encryption Bulk: Adds the ability to pass container rid to SDK…
Browse files Browse the repository at this point in the history
… header for Bulk operations for container recreate scenarios (Azure#2404)

* Add support to inject/append custom header via RequestOptions.

* Update DotNetSDKAPI.json

* Update CosmosHeaderTests.cs

* Changes as per review comments.

* Update CosmosHeaderTests.cs

* Provide shallow copy function.

* Fix allows headers required by enc package to be passed during bulk operaton.

* Move to PREVIEW

* Fixes.

* Fixes.

* Update BatchAsyncBatcher.cs

* Update BatchAsyncContainerExecutor.cs

* using constants

* Update BatchAsyncBatcher.cs

* fixed names

* Update PartitionKeyRangeServerBatchRequest.cs

* Updated documentation.

* Fixes as per review comments.

* Update PartitionKeyRangeServerBatchRequestTests.cs

* fixes.

Co-authored-by: j82w <j82w@users.noreply.github.com>
  • Loading branch information
kr-santosh and j82w authored Sep 30, 2021
1 parent e342f48 commit 251a6e4
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 10 deletions.
10 changes: 10 additions & 0 deletions Microsoft.Azure.Cosmos/src/Batch/BatchAsyncBatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ internal class BatchAsyncBatcher
private readonly CosmosClientContext clientContext;
private long currentSize = 0;
private bool dispatched = false;
private bool isClientEncrypted = false;
private string intendedCollectionRidValue;

public bool IsEmpty => this.batchOperations.Count == 0;

Expand Down Expand Up @@ -86,6 +88,12 @@ public virtual bool TryAdd(ItemBatchOperation operation)
throw new ArgumentNullException(nameof(operation.Context));
}

if (operation.Context.IsClientEncrypted && !this.isClientEncrypted)
{
this.isClientEncrypted = true;
this.intendedCollectionRidValue = operation.Context.IntendedCollectionRidValue;
}

if (this.batchOperations.Count == this.maxBatchOperationCount)
{
DefaultTrace.TraceInformation($"Batch is full - Max operation count {this.maxBatchOperationCount} reached.");
Expand Down Expand Up @@ -224,6 +232,8 @@ internal virtual async Task<Tuple<PartitionKeyRangeServerBatchRequest, ArraySegm
this.maxBatchOperationCount,
ensureContinuousOperationIndexes: false,
serializerCore: this.serializerCore,
isClientEncrypted: this.isClientEncrypted,
intendedCollectionRidValue: this.intendedCollectionRidValue,
cancellationToken: cancellationToken).ConfigureAwait(false);
}
}
Expand Down
35 changes: 32 additions & 3 deletions Microsoft.Azure.Cosmos/src/Batch/BatchAsyncContainerExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,32 @@ public virtual async Task<TransactionalBatchOperationResult> AddAsync(
trace,
cancellationToken).ConfigureAwait(false);
BatchAsyncStreamer streamer = this.GetOrAddStreamerForPartitionKeyRange(resolvedPartitionKeyRangeId);

ItemBatchOperationContext context = new ItemBatchOperationContext(
resolvedPartitionKeyRangeId,
trace,
BatchAsyncContainerExecutor.GetRetryPolicy(this.cosmosContainer, operation.OperationType, this.retryOptions));

if (itemRequestOptions != null && itemRequestOptions.AddRequestHeaders != null)
{
// get the header value if any, passed by the encryption package.
Headers encryptionHeaders = new Headers();
itemRequestOptions.AddRequestHeaders?.Invoke(encryptionHeaders);

// make sure we set the Intended Collection Rid header when we have encrypted payload.
// This primarily would allow CosmosDB Encryption package to detect change in container referenced by a Client
// and prevent creating data with wrong Encryption Policy.
if (encryptionHeaders.TryGetValue(HttpConstants.HttpHeaders.IsClientEncrypted, out string encrypted))
{
context.IsClientEncrypted = bool.Parse(encrypted);

if (context.IsClientEncrypted && encryptionHeaders.TryGetValue(WFConstants.BackendHeaders.IntendedCollectionRid, out string ridValue))
{
context.IntendedCollectionRidValue = ridValue;
}
}
}

operation.AttachContext(context);
streamer.Add(operation);
return await context.OperationTask;
Expand Down Expand Up @@ -176,9 +198,16 @@ private static bool ValidateOperationEPK(
return true;
}

private static void AddHeadersToRequestMessage(RequestMessage requestMessage, string partitionKeyRangeId)
private static void AddHeadersToRequestMessage(RequestMessage requestMessage, PartitionKeyRangeServerBatchRequest partitionKeyRangeServerBatchRequest)
{
requestMessage.Headers.PartitionKeyRangeId = partitionKeyRangeId;
requestMessage.Headers.PartitionKeyRangeId = partitionKeyRangeServerBatchRequest.PartitionKeyRangeId;

if (partitionKeyRangeServerBatchRequest.IsClientEncrypted)
{
requestMessage.Headers.Add(HttpConstants.HttpHeaders.IsClientEncrypted, partitionKeyRangeServerBatchRequest.IsClientEncrypted.ToString());
requestMessage.Headers.Add(WFConstants.BackendHeaders.IntendedCollectionRid, partitionKeyRangeServerBatchRequest.IntendedCollectionRidValue);
}

requestMessage.Headers.Add(HttpConstants.HttpHeaders.ShouldBatchContinueOnError, bool.TrueString);
requestMessage.Headers.Add(HttpConstants.HttpHeaders.IsBatchAtomic, bool.FalseString);
requestMessage.Headers.Add(HttpConstants.HttpHeaders.IsBatchRequest, bool.TrueString);
Expand Down Expand Up @@ -247,7 +276,7 @@ private async Task<PartitionKeyRangeBatchExecutionResult> ExecuteAsync(
cosmosContainerCore: this.cosmosContainer,
feedRange: null,
streamPayload: serverRequestPayload,
requestEnricher: requestMessage => BatchAsyncContainerExecutor.AddHeadersToRequestMessage(requestMessage, serverRequest.PartitionKeyRangeId),
requestEnricher: requestMessage => BatchAsyncContainerExecutor.AddHeadersToRequestMessage(requestMessage, serverRequest),
trace: trace,
cancellationToken: cancellationToken).ConfigureAwait(false);

Expand Down
4 changes: 4 additions & 0 deletions Microsoft.Azure.Cosmos/src/Batch/ItemBatchOperationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ internal class ItemBatchOperationContext : IDisposable
{
public string PartitionKeyRangeId { get; private set; }

public bool IsClientEncrypted { get; set; }

public string IntendedCollectionRidValue { get; set; }

public BatchAsyncBatcher CurrentBatcher { get; set; }

public Task<TransactionalBatchOperationResult> OperationTask => this.taskCompletionSource.Task;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,34 @@ internal sealed class PartitionKeyRangeServerBatchRequest : ServerBatchRequest
/// Initializes a new instance of the <see cref="PartitionKeyRangeServerBatchRequest"/> class.
/// </summary>
/// <param name="partitionKeyRangeId">The partition key range id associated with all requests.</param>
/// <param name="isClientEncrypted"> If the operation has Encrypted data. </param>
/// <param name="intendedCollectionRidValue"> Intended Collection Rid value. </param>
/// <param name="maxBodyLength">Maximum length allowed for the request body.</param>
/// <param name="maxOperationCount">Maximum number of operations allowed in the request.</param>
/// <param name="serializerCore">Serializer to serialize user provided objects to JSON.</param>
public PartitionKeyRangeServerBatchRequest(
string partitionKeyRangeId,
bool isClientEncrypted,
string intendedCollectionRidValue,
int maxBodyLength,
int maxOperationCount,
CosmosSerializerCore serializerCore)
: base(maxBodyLength, maxOperationCount, serializerCore)
{
this.PartitionKeyRangeId = partitionKeyRangeId;
this.IsClientEncrypted = isClientEncrypted;
this.IntendedCollectionRidValue = intendedCollectionRidValue;
}

/// <summary>
/// Gets the PartitionKeyRangeId that applies to all operations in this request.
/// </summary>
public string PartitionKeyRangeId { get; }

public bool IsClientEncrypted { get; }

public string IntendedCollectionRidValue { get; }

/// <summary>
/// Creates an instance of <see cref="PartitionKeyRangeServerBatchRequest"/>.
/// In case of direct mode requests, all the operations are expected to belong to the same PartitionKeyRange.
Expand All @@ -43,6 +53,8 @@ public PartitionKeyRangeServerBatchRequest(
/// <param name="maxOperationCount">Maximum number of operations allowed in the request.</param>
/// <param name="ensureContinuousOperationIndexes">Whether to stop adding operations to the request once there is non-continuity in the operation indexes.</param>
/// <param name="serializerCore">Serializer to serialize user provided objects to JSON.</param>
/// <param name="isClientEncrypted"> Indicates if the request has encrypted data. </param>
/// <param name="intendedCollectionRidValue"> The intended collection Rid value. </param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> representing request cancellation.</param>
/// <returns>A newly created instance of <see cref="PartitionKeyRangeServerBatchRequest"/> and the overflow ItemBatchOperation not being processed.</returns>
public static async Task<Tuple<PartitionKeyRangeServerBatchRequest, ArraySegment<ItemBatchOperation>>> CreateAsync(
Expand All @@ -52,9 +64,18 @@ public static async Task<Tuple<PartitionKeyRangeServerBatchRequest, ArraySegment
int maxOperationCount,
bool ensureContinuousOperationIndexes,
CosmosSerializerCore serializerCore,
bool isClientEncrypted,
string intendedCollectionRidValue,
CancellationToken cancellationToken)
{
PartitionKeyRangeServerBatchRequest request = new PartitionKeyRangeServerBatchRequest(partitionKeyRangeId, maxBodyLength, maxOperationCount, serializerCore);
PartitionKeyRangeServerBatchRequest request = new PartitionKeyRangeServerBatchRequest(
partitionKeyRangeId,
isClientEncrypted,
intendedCollectionRidValue,
maxBodyLength,
maxOperationCount,
serializerCore);

ArraySegment<ItemBatchOperation> pendingOperations = await request.CreateBodyStreamAsync(operations, cancellationToken, ensureContinuousOperationIndexes);
return new Tuple<PartitionKeyRangeServerBatchRequest, ArraySegment<ItemBatchOperation>>(request, pendingOperations);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System.Collections.Generic;
using System.Net;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;

Expand Down Expand Up @@ -121,6 +122,114 @@ public async Task CreateItemAsync_WithBulk()
}
}

[TestMethod]
public async Task CreateItemAsyncValidateIntendedCollRid_WithBulk()
{
Container container = await this.database.CreateContainerAsync(Guid.NewGuid().ToString(), "/pk", 10000);

List<Task<ItemResponse<ToDoActivity>>> tasks = new List<Task<ItemResponse<ToDoActivity>>>();

ContainerInlineCore containerInternal = (ContainerInlineCore)container;

string rid = await containerInternal.GetCachedRIDAsync(forceRefresh: false, NoOpTrace.Singleton, cancellationToken: default);

// case 1. use wrong rid by using a stale rid.
ItemRequestOptions itemRequestOptions = new ItemRequestOptions()
{
AddRequestHeaders = (headers) =>
{
headers[Documents.HttpConstants.HttpHeaders.IsClientEncrypted] = bool.TrueString;
headers[Documents.WFConstants.BackendHeaders.IntendedCollectionRid] = rid;
}
};

// delete the container.
using (await this.database.GetContainer(container.Id).DeleteContainerStreamAsync())
{ }

// recreate with same id.
await this.database.CreateContainerAsync(container.Id, "/pk", 10000);


for (int i = 0; i < 2; i++)
{
tasks.Add(ExecuteCreateAsync(container, CreateItem(i.ToString()), itemRequestOptions));
}

try
{
await Task.WhenAll(tasks);
Assert.Fail("Bulk execution should have failed. ");
}
catch(CosmosException ex)
{
if(ex.StatusCode == HttpStatusCode.Created || ex.SubStatusCode != 1024)
{
Assert.Fail("Bulk execution should have failed with these specific status codes. ");
}
}

// case 2.
tasks.Clear();

// should ignore if the item is not encrypted.
itemRequestOptions = new ItemRequestOptions()
{
AddRequestHeaders = (headers) =>
{
headers[Documents.HttpConstants.HttpHeaders.IsClientEncrypted] = bool.FalseString;
headers[Documents.WFConstants.BackendHeaders.IntendedCollectionRid] = rid;
}
};

for (int i = 0; i < 2; i++)
{
tasks.Add(ExecuteCreateAsync(container, CreateItem(i.ToString()), itemRequestOptions));
}

await Task.WhenAll(tasks);

for (int i = 0; i < 2; i++)
{
Task<ItemResponse<ToDoActivity>> task = tasks[i];
ItemResponse<ToDoActivity> result = await task;
Assert.IsTrue(result.Headers.RequestCharge > 0);
Assert.IsFalse(string.IsNullOrEmpty(result.Diagnostics.ToString()));
Assert.AreEqual(HttpStatusCode.Created, result.StatusCode);
}

// case 3.
tasks.Clear();

// use the correct rid.
rid = await containerInternal.GetCachedRIDAsync(forceRefresh: false, NoOpTrace.Singleton, cancellationToken: default);

itemRequestOptions = new ItemRequestOptions()
{
AddRequestHeaders = (headers) =>
{
headers[Documents.HttpConstants.HttpHeaders.IsClientEncrypted] = bool.TrueString;
headers[Documents.WFConstants.BackendHeaders.IntendedCollectionRid] = rid;
}
};

for (int i = 3; i < 8; i++)
{
tasks.Add(ExecuteCreateAsync(container, CreateItem(i.ToString()), itemRequestOptions));
}

await Task.WhenAll(tasks);

for (int i = 0; i < 5; i++)
{
Task<ItemResponse<ToDoActivity>> task = tasks[i];
ItemResponse<ToDoActivity> result = await task;
Assert.IsTrue(result.Headers.RequestCharge > 0);
Assert.IsFalse(string.IsNullOrEmpty(result.Diagnostics.ToString()));
Assert.AreEqual(HttpStatusCode.Created, result.StatusCode);
}
}

[TestMethod]
public async Task CreateItemJObjectWithoutPK_WithBulk()
{
Expand Down Expand Up @@ -528,9 +637,9 @@ private async Task CreateLargeItemStreamWithBulk(int appxItemSize)
}
}

private static Task<ItemResponse<ToDoActivity>> ExecuteCreateAsync(Container container, ToDoActivity item)
private static Task<ItemResponse<ToDoActivity>> ExecuteCreateAsync(Container container, ToDoActivity item, ItemRequestOptions itemRequestOptions = null)
{
return container.CreateItemAsync<ToDoActivity>(item, new PartitionKey(item.pk));
return container.CreateItemAsync<ToDoActivity>(item, new PartitionKey(item.pk), itemRequestOptions);
}

private static Task<ItemResponse<JObject>> ExecuteCreateAsync(Container container, JObject item)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ public async Task ReadManyTestWithIncorrectIntendedContainerRid()
{
AddRequestHeaders = (headers) =>
{
headers["x-ms-cosmos-is-client-encrypted"] = bool.TrueString;
headers["x-ms-cosmos-intended-collection-rid"] = "iCoRrecTrID=";
headers[Documents.HttpConstants.HttpHeaders.IsClientEncrypted] = bool.TrueString;
headers[Documents.WFConstants.BackendHeaders.IntendedCollectionRid] = "iCoRrecTrID=";
}
};

Expand Down Expand Up @@ -307,8 +307,8 @@ public async Task ReadManyTestWithIncorrectIntendedContainerRid()
{
AddRequestHeaders = (headers) =>
{
headers["x-ms-cosmos-is-client-encrypted"] = bool.TrueString;
headers["x-ms-cosmos-intended-collection-rid"] = rid;
headers[Documents.HttpConstants.HttpHeaders.IsClientEncrypted] = bool.TrueString;
headers[Documents.WFConstants.BackendHeaders.IntendedCollectionRid] = rid;
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public async Task FitsAllOperations()
2,
false,
MockCosmosUtil.Serializer,
isClientEncrypted: false,
intendedCollectionRidValue: null,
default(CancellationToken));

Assert.AreEqual(operations.Count, request.Operations.Count);
Expand Down Expand Up @@ -65,6 +67,8 @@ public async Task OverflowsBasedOnCount()
1,
false,
MockCosmosUtil.Serializer,
isClientEncrypted: false,
intendedCollectionRidValue: null,
default(CancellationToken));

Assert.AreEqual(1, request.Operations.Count);
Expand Down Expand Up @@ -96,6 +100,8 @@ public async Task OverflowsBasedOnCount_WithOffset()
1,
false,
MockCosmosUtil.Serializer,
isClientEncrypted: false,
intendedCollectionRidValue: null,
default(CancellationToken));

Assert.AreEqual(1, request.Operations.Count);
Expand Down Expand Up @@ -170,6 +176,8 @@ private static async Task<Tuple<PartitionKeyRangeServerBatchRequest, ArraySegmen
maxServerRequestOperationCount,
false,
MockCosmosUtil.Serializer,
isClientEncrypted: false,
intendedCollectionRidValue: null,
default(CancellationToken));
}
}
Expand Down

0 comments on commit 251a6e4

Please sign in to comment.