From 9eb60870a736c9c8089eedfa3f4626f24a0c93f4 Mon Sep 17 00:00:00 2001 From: sc978345 <115834537+sc978345@users.noreply.github.com> Date: Tue, 15 Oct 2024 13:29:52 -0500 Subject: [PATCH] [Internal] Query: Fixes issue with distributed query GetItemQueryStreamIterator (#4798) ## Description This change fixes an issue with GetItemQueryStreamIterator for distributed query where containerRid was not getting set in the responseHeaders, leading to an exception when building the response. We must set the container resource id in the CosmosQueryContext, prior to building the DistributedQueryPipelineStage. ## Type of change - [] Bug fix (non-breaking change which fixes an issue) --------- Co-authored-by: neildsh <35383880+neildsh@users.noreply.github.com> --- .../CosmosQueryExecutionContextFactory.cs | 18 ++--- .../Query/DistributedQueryClientTests.cs | 79 ++++++++++++++++++- 2 files changed, 86 insertions(+), 11 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Query/Core/Pipeline/CosmosQueryExecutionContextFactory.cs b/Microsoft.Azure.Cosmos/src/Query/Core/Pipeline/CosmosQueryExecutionContextFactory.cs index cb59ebbd80..88389f1bea 100644 --- a/Microsoft.Azure.Cosmos/src/Query/Core/Pipeline/CosmosQueryExecutionContextFactory.cs +++ b/Microsoft.Azure.Cosmos/src/Query/Core/Pipeline/CosmosQueryExecutionContextFactory.cs @@ -91,6 +91,15 @@ private static async Task> TryCreateCoreContextAsy { using (ITrace createQueryPipelineTrace = trace.StartChild("Create Query Pipeline", TraceComponent.Query, Tracing.TraceLevel.Info)) { + CosmosQueryClient cosmosQueryClient = cosmosQueryContext.QueryClient; + + ContainerQueryProperties containerQueryProperties = await cosmosQueryClient.GetCachedContainerQueryPropertiesAsync( + cosmosQueryContext.ResourceLink, + inputParameters.PartitionKey, + createQueryPipelineTrace, + cancellationToken); + cosmosQueryContext.ContainerResourceId = containerQueryProperties.ResourceId; + if (inputParameters.EnableDistributedQueryGatewayMode && cosmosQueryContext.ResourceTypeEnum == Documents.ResourceType.Document && cosmosQueryContext.OperationTypeEnum == Documents.OperationType.Query) @@ -152,15 +161,6 @@ private static async Task> TryCreateCoreContextAsy } } - CosmosQueryClient cosmosQueryClient = cosmosQueryContext.QueryClient; - - ContainerQueryProperties containerQueryProperties = await cosmosQueryClient.GetCachedContainerQueryPropertiesAsync( - cosmosQueryContext.ResourceLink, - inputParameters.PartitionKey, - createQueryPipelineTrace, - cancellationToken); - cosmosQueryContext.ContainerResourceId = containerQueryProperties.ResourceId; - Documents.PartitionKeyRange targetRange = await TryGetTargetRangeOptimisticDirectExecutionAsync( inputParameters, queryPlanFromContinuationToken, diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Query/DistributedQueryClientTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Query/DistributedQueryClientTests.cs index e771824eaa..3901f4287e 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Query/DistributedQueryClientTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Query/DistributedQueryClientTests.cs @@ -2,9 +2,11 @@ namespace Microsoft.Azure.Cosmos.EmulatorTests.Query { using System; using System.Collections.Generic; + using System.IO; using System.Linq; using System.Threading.Tasks; using Microsoft.Azure.Cosmos.CosmosElements; + using Microsoft.Azure.Cosmos.Json; using Microsoft.VisualStudio.TestTools.UnitTesting; [TestClass] @@ -135,6 +137,31 @@ public void TestDistributedQueryGatewayModeOverride() Assert.IsNull(Environment.GetEnvironmentVariable(ConfigurationManager.DistributedQueryGatewayModeEnabled)); } + [TestMethod] + public async Task StreamIteratorTestsAsync() + { + static Task ImplementationAsync(Container container, IReadOnlyList _) + { + int[] pageSizes = new[] { DocumentCount }; + + TestCase[] testCases = new[] + { + MakeTest( + "SELECT VALUE c.x FROM c WHERE c.x < 200", + pageSizes, + Expectations.AllDocumentsLessThan200ArePresent), + }; + + return RunStreamIteratorTestsAsync(container, testCases); + } + + await this.CreateIngestQueryDeleteAsync( + ConnectionModes.Gateway, + CollectionTypes.SinglePartition | CollectionTypes.MultiPartition, + CreateDocuments(DocumentCount), + ImplementationAsync); + } + private static async Task RunPartitionedParityTestsAsync(Container container, IEnumerable testCases) { IReadOnlyList feedRanges = await container.GetFeedRangesAsync(); @@ -201,7 +228,7 @@ private static async Task ContinuationTestsAsync(Container container, IEnumerabl foreach (int pageSize in testCase.PageSizes) { List results = await RunContinuationBasedQueryTestAsync(container, testCase.Query, pageSize); - testCase.ValidateResult(results); + Assert.IsTrue(testCase.ValidateResult(results)); } } } @@ -260,7 +287,55 @@ private static async Task RunTestsAsync(Container container, IEnumerable testCases) + { + foreach (TestCase testCase in testCases) + { + foreach (int pageSize in testCase.PageSizes) + { + QueryRequestOptions options = new QueryRequestOptions() + { + MaxItemCount = pageSize, + EnableDistributedQueryGatewayMode = true, + }; + + List extractedResults = new List(); + await foreach (ResponseMessage response in RunSimpleQueryAsync( + container, + testCase.Query, + options)) + { + Assert.AreEqual(System.Net.HttpStatusCode.OK, response.StatusCode); + + using (MemoryStream memoryStream = new MemoryStream()) + { + response.Content.CopyTo(memoryStream); + byte[] content = memoryStream.ToArray(); + + IJsonNavigator navigator = JsonNavigator.Create(content); + IJsonNavigatorNode rootNode = navigator.GetRootNode(); + + Assert.IsTrue(navigator.TryGetObjectProperty(rootNode, "_rid", out ObjectProperty ridProperty)); + string rid = navigator.GetStringValue(ridProperty.ValueNode); + Assert.IsTrue(rid.Length > 0); + + Assert.IsTrue(navigator.TryGetObjectProperty(rootNode, "Documents", out ObjectProperty documentsProperty)); + IEnumerable arrayItems = navigator.GetArrayItems(documentsProperty.ValueNode); + foreach (IJsonNavigatorNode node in arrayItems) + { + Assert.AreEqual(JsonNodeType.Number64, navigator.GetNodeType(node)); + + extractedResults.Add((int)Number64.ToLong(navigator.GetNumber64Value(node))); + } + } + } + + Assert.IsTrue(testCase.ValidateResult(extractedResults)); } } }