From 8fac2bbca05ddee8ecb4d6a5a68e5e3bbd8f667a Mon Sep 17 00:00:00 2001 From: Annie Liang <64233642+xinlian12@users.noreply.github.com> Date: Fri, 14 May 2021 12:20:43 -0700 Subject: [PATCH] ThroughputControl- Discard response out of cycle (#21369) * No ru tracking if response come back out of cycle Co-authored-by: annie-mac --- .../DocumentServiceRequestContext.java | 3 +- .../ThroughputControlTrackingUnit.java | 113 ++++++++++++++++++ .../ThroughputRequestThrottler.java | 71 +++++++++-- .../GlobalThroughputRequestController.java | 3 +- .../PkRangesThroughputRequestController.java | 2 +- .../directconnectivity/ReflectionUtils.java | 9 ++ .../ThroughputRequestThrottlerTests.java | 58 ++++++++- ...lobalThroughputRequestControllerTests.java | 4 + ...angesThroughputRequestControllerTests.java | 3 + 9 files changed, 254 insertions(+), 12 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/ThroughputControlTrackingUnit.java diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/DocumentServiceRequestContext.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/DocumentServiceRequestContext.java index d268589d3c80..493e32ed54c6 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/DocumentServiceRequestContext.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/DocumentServiceRequestContext.java @@ -38,6 +38,7 @@ public class DocumentServiceRequestContext implements Cloneable { public volatile PartitionKeyInternal effectivePartitionKey; public volatile CosmosDiagnostics cosmosDiagnostics; public volatile String resourcePhysicalAddress; + public volatile String throughputControlCycleId; public DocumentServiceRequestContext() { } @@ -99,7 +100,7 @@ public DocumentServiceRequestContext clone() { context.performedBackgroundAddressRefresh = this.performedBackgroundAddressRefresh; context.cosmosDiagnostics = this.cosmosDiagnostics; context.resourcePhysicalAddress = this.resourcePhysicalAddress; - + context.throughputControlCycleId = this.throughputControlCycleId; return context; } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/ThroughputControlTrackingUnit.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/ThroughputControlTrackingUnit.java new file mode 100644 index 000000000000..020f56c8f172 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/ThroughputControlTrackingUnit.java @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.implementation.throughputControl; + +import com.azure.cosmos.implementation.OperationType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +public class ThroughputControlTrackingUnit { + + private static final Logger logger = LoggerFactory.getLogger(ThroughputControlTrackingUnit.class); + + private final OperationType operationType; + private final AtomicInteger rejectedRequests; + private final AtomicInteger passedRequests; + private final AtomicReference successRuUsage; + private final AtomicInteger successResponse; + private final AtomicInteger failedResponse; + private final AtomicInteger outOfCycleResponse; + private String throughputControlCycleId; + + public ThroughputControlTrackingUnit(OperationType operationType, String throughputControlCycleId) { + this.operationType = operationType; + + this.rejectedRequests = new AtomicInteger(0); + this.passedRequests = new AtomicInteger(0); + this.successRuUsage = new AtomicReference<>(0d); + this.successResponse = new AtomicInteger(0); + this.failedResponse = new AtomicInteger(0); + this.outOfCycleResponse = new AtomicInteger(0); + this.throughputControlCycleId = throughputControlCycleId; + } + + public void reset(String newCycleId) { + if (this.rejectedRequests.get() > 0 + || this.passedRequests.get() > 0 + || this.successResponse.get() > 0 + || this.failedResponse.get() > 0) { + + double sAvgRuPerRequest = 0.0; + if (this.successResponse.get() != 0) { + sAvgRuPerRequest = successRuUsage.get() / this.successResponse.get(); + } + + logger.debug( + "[CycleId: {}, operationType: {}, rejectedCnt: {}, passedCnt: {}, sAvgRu: {}, successCnt: {}, failedCnt: {}, outOfCycleCnt: {}]", + this.throughputControlCycleId, + this.operationType.toString(), + this.rejectedRequests.get(), + this.passedRequests.get(), + sAvgRuPerRequest, + this.successResponse.get(), + this.failedResponse.get(), + this.outOfCycleResponse.get()); + } + + this.rejectedRequests.set(0); + this.passedRequests.set(0); + this.successRuUsage.set(0d); + this.successResponse.set(0); + this.failedResponse.set(0); + this.outOfCycleResponse.set(0); + this.throughputControlCycleId = newCycleId; + } + + public void increasePassedRequest(){ + this.passedRequests.incrementAndGet(); + } + + public void increaseRejectedRequest(){ + this.rejectedRequests.incrementAndGet(); + } + + public void increaseSuccessResponse() { + this.successResponse.incrementAndGet(); + } + + public void increaseFailedResponse() { this.failedResponse.incrementAndGet(); } + + public void increaseOutOfCycleResponse() { this.outOfCycleResponse.incrementAndGet(); } + + public void trackRRuUsage(double ruUsage) { + this.successRuUsage.getAndAccumulate(ruUsage, (available, newRuUsage) -> available + newRuUsage); + } + + public int getRejectedRequests() { + return rejectedRequests.get(); + } + + public int getPassedRequests() { + return passedRequests.get(); + } + + public double getSuccessRuUsage() { + return successRuUsage.get(); + } + + public int getSuccessResponse() { + return successResponse.get(); + } + + public int getFailedResponse() { + return failedResponse.get(); + } + + public int getOutOfCycleResponse() { + return outOfCycleResponse.get(); + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/ThroughputRequestThrottler.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/ThroughputRequestThrottler.java index 98bd6d575e12..3a1197bc55e2 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/ThroughputRequestThrottler.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/ThroughputRequestThrottler.java @@ -6,16 +6,20 @@ import com.azure.cosmos.BridgeInternal; import com.azure.cosmos.CosmosException; import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.implementation.OperationType; import com.azure.cosmos.implementation.RequestRateTooLargeException; import com.azure.cosmos.implementation.RxDocumentServiceRequest; import com.azure.cosmos.implementation.RxDocumentServiceResponse; import com.azure.cosmos.implementation.Utils; +import com.azure.cosmos.implementation.apachecommons.lang.StringUtils; import com.azure.cosmos.implementation.directconnectivity.StoreResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Exceptions; import reactor.core.publisher.Mono; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReentrantReadWriteLock; @@ -29,13 +33,20 @@ public class ThroughputRequestThrottler { private final AtomicReference scheduledThroughput; private final ReentrantReadWriteLock.WriteLock throughputWriteLock; private final ReentrantReadWriteLock.ReadLock throughputReadLock; + private final ConcurrentHashMap trackingDictionary; + private final String pkRangeId; + private String cycleId; - public ThroughputRequestThrottler(double scheduledThroughput) { + public ThroughputRequestThrottler(double scheduledThroughput, String pkRangeId) { this.availableThroughput = new AtomicReference<>(scheduledThroughput); this.scheduledThroughput = new AtomicReference<>(scheduledThroughput); ReentrantReadWriteLock throughputReadWriteLock = new ReentrantReadWriteLock(); this.throughputWriteLock = throughputReadWriteLock.writeLock(); this.throughputReadLock = throughputReadWriteLock.readLock(); + + this.trackingDictionary = new ConcurrentHashMap<>(); + this.cycleId = UUID.randomUUID().toString(); + this.pkRangeId = pkRangeId; } public double renewThroughputUsageCycle(double scheduledThroughput) { @@ -45,6 +56,17 @@ public double renewThroughputUsageCycle(double scheduledThroughput) { this.scheduledThroughput.set(scheduledThroughput); this.updateAvailableThroughput(); + if (throughputUsagePercentage > 0) { + logger.debug( + "[CycleId: {}, pkRangeId: {}, ruUsagePercentage: {}]", + this.cycleId, this.pkRangeId, throughputUsagePercentage); + } + + String newCycleId = UUID.randomUUID().toString(); + for (ThroughputControlTrackingUnit trackingUnit : this.trackingDictionary.values()) { + trackingUnit.reset(newCycleId); + } + this.cycleId = newCycleId; return throughputUsagePercentage; } finally { this.throughputWriteLock.unlock(); @@ -60,15 +82,30 @@ private void updateAvailableThroughput() { public Mono processRequest(RxDocumentServiceRequest request, Mono originalRequestMono) { try { this.throughputReadLock.lock(); + ThroughputControlTrackingUnit trackingUnit = + this.trackingDictionary.compute(request.getOperationType(), ((key, value) -> { + if (value == null) { + value = new ThroughputControlTrackingUnit(request.getOperationType(), this.cycleId); + } + return value; + })); + if (this.availableThroughput.get() > 0) { + if (StringUtils.isEmpty(request.requestContext.throughputControlCycleId)) { + request.requestContext.throughputControlCycleId = this.cycleId; + } + + trackingUnit.increasePassedRequest(); return originalRequestMono - .doOnSuccess(response -> this.trackRequestCharge(response)) - .doOnError(throwable -> this.trackRequestCharge(throwable)); + .doOnSuccess(response -> this.trackRequestCharge(request, response)) + .doOnError(throwable -> this.trackRequestCharge(request, throwable)); } else { + trackingUnit.increaseRejectedRequest(); + // there is no enough throughput left, block request RequestRateTooLargeException requestRateTooLargeException = new RequestRateTooLargeException(); - int backoffTimeInMilliSeconds = (int)Math.floor(Math.abs(this.availableThroughput.get() * 1000 / this.scheduledThroughput.get())); + int backoffTimeInMilliSeconds = (int)Math.ceil(Math.abs(this.availableThroughput.get() / this.scheduledThroughput.get())) * 1000; requestRateTooLargeException.getResponseHeaders().put( HttpConstants.HttpHeaders.RETRY_AFTER_IN_MILLISECONDS, @@ -87,14 +124,14 @@ public Mono processRequest(RxDocumentServiceRequest request, Mono orig } finally { this.throughputReadLock.unlock(); } - } - private void trackRequestCharge (T response) { + private void trackRequestCharge (RxDocumentServiceRequest request, T response) { try { // Read lock is enough here. this.throughputReadLock.lock(); double requestCharge = 0; + boolean failedRequest = false; if (response instanceof StoreResponse) { requestCharge = ((StoreResponse)response).getRequestCharge(); } else if (response instanceof RxDocumentServiceResponse) { @@ -103,13 +140,31 @@ private void trackRequestCharge (T response) { CosmosException cosmosException = Utils.as(Exceptions.unwrap((Throwable) response), CosmosException.class); if (cosmosException != null) { requestCharge = cosmosException.getRequestCharge(); + failedRequest = true; + } + } + + ThroughputControlTrackingUnit trackingUnit = trackingDictionary.get(request.getOperationType()); + if (trackingUnit != null) { + if (failedRequest) { + trackingUnit.increaseFailedResponse(); + } else { + trackingUnit.increaseSuccessResponse(); + trackingUnit.trackRRuUsage(requestCharge); + } + } + + // If the response comes back in a different cycle, discard it. + if (StringUtils.equals(this.cycleId, request.requestContext.throughputControlCycleId)) { + this.availableThroughput.getAndAccumulate(requestCharge, (available, consumed) -> available - consumed); + } else { + if (trackingUnit != null) { + trackingUnit.increaseOutOfCycleResponse(); } } - this.availableThroughput.getAndAccumulate(requestCharge, (available, consumed) -> available - consumed); } finally { this.throughputReadLock.unlock(); } - } public double getAvailableThroughput() { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/controller/request/GlobalThroughputRequestController.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/controller/request/GlobalThroughputRequestController.java index c57f263fffde..21c9adfa52f8 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/controller/request/GlobalThroughputRequestController.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/controller/request/GlobalThroughputRequestController.java @@ -4,6 +4,7 @@ package com.azure.cosmos.implementation.throughputControl.controller.request; import com.azure.cosmos.implementation.RxDocumentServiceRequest; +import com.azure.cosmos.implementation.apachecommons.lang.StringUtils; import com.azure.cosmos.implementation.throughputControl.ThroughputRequestThrottler; import reactor.core.publisher.Mono; @@ -15,7 +16,7 @@ public class GlobalThroughputRequestController implements IThroughputRequestCont public GlobalThroughputRequestController(double initialScheduledThroughput) { this.scheduledThroughput = new AtomicReference<>(initialScheduledThroughput); - this.requestThrottler = new ThroughputRequestThrottler(this.scheduledThroughput.get()); + this.requestThrottler = new ThroughputRequestThrottler(this.scheduledThroughput.get(), StringUtils.EMPTY); } @Override diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/controller/request/PkRangesThroughputRequestController.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/controller/request/PkRangesThroughputRequestController.java index 61f800148b6d..9d57c793e249 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/controller/request/PkRangesThroughputRequestController.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/throughputControl/controller/request/PkRangesThroughputRequestController.java @@ -87,7 +87,7 @@ private void createRequestThrottlers() { for (PartitionKeyRange pkRange : pkRanges) { requestThrottlerMap.compute(pkRange.getId(), (pkRangeId, requestThrottler) -> { if (requestThrottler == null) { - requestThrottler = new ThroughputRequestThrottler(throughputPerPkRange); + requestThrottler = new ThroughputRequestThrottler(throughputPerPkRange, pkRangeId); } return requestThrottler; diff --git a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/ReflectionUtils.java b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/ReflectionUtils.java index 7d0c71851049..07b15848a3c0 100644 --- a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/ReflectionUtils.java +++ b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/ReflectionUtils.java @@ -13,6 +13,7 @@ import com.azure.cosmos.implementation.ConnectionPolicy; import com.azure.cosmos.implementation.DocumentCollection; import com.azure.cosmos.implementation.GlobalEndpointManager; +import com.azure.cosmos.implementation.OperationType; import com.azure.cosmos.implementation.RetryContext; import com.azure.cosmos.implementation.RxDocumentClientImpl; import com.azure.cosmos.implementation.RxStoreModel; @@ -28,6 +29,8 @@ import com.azure.cosmos.implementation.directconnectivity.rntbd.RntbdEndpoint; import com.azure.cosmos.implementation.http.HttpClient; import com.azure.cosmos.implementation.routing.CollectionRoutingMap; +import com.azure.cosmos.implementation.throughputControl.ThroughputControlTests; +import com.azure.cosmos.implementation.throughputControl.ThroughputControlTrackingUnit; import com.azure.cosmos.implementation.throughputControl.ThroughputRequestThrottler; import com.azure.cosmos.implementation.throughputControl.controller.request.GlobalThroughputRequestController; import com.azure.cosmos.implementation.throughputControl.controller.request.PkRangesThroughputRequestController; @@ -299,4 +302,10 @@ public static AsyncCache getRoutingMapAsyncCache(R public static AtomicBoolean isInitialized(CosmosAsyncContainer cosmosAsyncContainer) { return get(AtomicBoolean.class, cosmosAsyncContainer, "isInitialized"); } + + @SuppressWarnings("unchecked") + public static ConcurrentHashMap getThroughputControlTrackingDictionary( + ThroughputRequestThrottler requestThrottler) { + return get(ConcurrentHashMap.class, requestThrottler, "trackingDictionary"); + } } diff --git a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/ThroughputRequestThrottlerTests.java b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/ThroughputRequestThrottlerTests.java index cb775b796314..f00ec4c37460 100644 --- a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/ThroughputRequestThrottlerTests.java +++ b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/ThroughputRequestThrottlerTests.java @@ -3,15 +3,22 @@ package com.azure.cosmos.implementation.throughputControl; +import com.azure.cosmos.implementation.DocumentServiceRequestContext; import com.azure.cosmos.implementation.NotFoundException; +import com.azure.cosmos.implementation.OperationType; import com.azure.cosmos.implementation.RequestRateTooLargeException; import com.azure.cosmos.implementation.RxDocumentServiceRequest; +import com.azure.cosmos.implementation.apachecommons.lang.StringUtils; +import com.azure.cosmos.implementation.directconnectivity.ReflectionUtils; import com.azure.cosmos.implementation.directconnectivity.StoreResponse; import org.mockito.Mockito; import org.testng.annotations.Test; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + import static org.assertj.core.api.Assertions.assertThat; public class ThroughputRequestThrottlerTests { @@ -24,10 +31,13 @@ public void processRequest() { double availableThroughput = scheduledThroughput; RxDocumentServiceRequest requestMock = Mockito.mock(RxDocumentServiceRequest.class); + Mockito.doReturn(OperationType.Read).when(requestMock).getOperationType(); + requestMock.requestContext = new DocumentServiceRequestContext(); + StoreResponse responseMock = Mockito.mock(StoreResponse.class); Mockito.doReturn(requestChargePerRequest).when(responseMock).getRequestCharge(); - ThroughputRequestThrottler requestThrottler = new ThroughputRequestThrottler(scheduledThroughput); + ThroughputRequestThrottler requestThrottler = new ThroughputRequestThrottler(scheduledThroughput, StringUtils.EMPTY); // Request1: pass through TestPublisher requestPublisher1 = TestPublisher.create(); @@ -40,6 +50,7 @@ public void processRequest() { this.assertRequestThrottlerState(requestThrottler, availableThroughput, scheduledThroughput); // Request2: will get throttled since there is no available throughput + requestMock.requestContext.throughputControlCycleId = StringUtils.EMPTY; TestPublisher requestPublisher2 = TestPublisher.create(); StepVerifier.create(requestThrottler.processRequest(requestMock, requestPublisher2.mono())) .verifyError(RequestRateTooLargeException.class); @@ -52,6 +63,7 @@ public void processRequest() { assertThat(requestThrottler.getAvailableThroughput()).isEqualTo(availableThroughput); // Request 3: will get throttled since there is no available throughput + requestMock.requestContext.throughputControlCycleId = StringUtils.EMPTY; TestPublisher requestPublisher3 = TestPublisher.create(); StepVerifier.create(requestThrottler.processRequest(requestMock, requestPublisher3.mono())) .verifyError(RequestRateTooLargeException.class); @@ -64,6 +76,7 @@ public void processRequest() { assertThat(requestThrottler.getAvailableThroughput()).isEqualTo(availableThroughput); // Request 4: will pass the request, and record the charge from exception + requestMock.requestContext.throughputControlCycleId = StringUtils.EMPTY; NotFoundException notFoundException = Mockito.mock(NotFoundException.class); Mockito.doReturn(requestChargePerRequest).when(notFoundException).getRequestCharge(); TestPublisher requestPublisher4 = TestPublisher.create(); @@ -75,6 +88,49 @@ public void processRequest() { this.assertRequestThrottlerState(requestThrottler, availableThroughput, scheduledThroughput); } + @Test(groups = "unit") + public void responseOutOfCycle() { + double requestChargePerRequest = 2.0; + double scheduledThroughput = 1.0; + double availableThroughput = scheduledThroughput; + OperationType operationType = OperationType.Read; + + RxDocumentServiceRequest requestMock = Mockito.mock(RxDocumentServiceRequest.class); + Mockito.doReturn(operationType).when(requestMock).getOperationType(); + requestMock.requestContext = new DocumentServiceRequestContext(); + + StoreResponse responseMock = Mockito.mock(StoreResponse.class); + Mockito.doReturn(requestChargePerRequest).when(responseMock).getRequestCharge(); + + ThroughputRequestThrottler requestThrottler = new ThroughputRequestThrottler(scheduledThroughput, StringUtils.EMPTY); + + // Request1: pass through + TestPublisher requestPublisher1 = TestPublisher.create(); + StepVerifier.create(requestThrottler.processRequest(requestMock, requestPublisher1.mono())) + .then(() -> { + requestMock.requestContext.throughputControlCycleId = UUID.randomUUID().toString(); + requestPublisher1.emit(responseMock); + }) + .expectNext(responseMock) + .verifyComplete(); + + // verify no throughput will be deducted from available throughput because the response came back during a different throughput cycle + this.assertRequestThrottlerState(requestThrottler, availableThroughput, scheduledThroughput); + + ConcurrentHashMap trackingUnitDictionary = + ReflectionUtils.getThroughputControlTrackingDictionary(requestThrottler); + assertThat(trackingUnitDictionary).isNotNull(); + assertThat(trackingUnitDictionary.size()).isEqualTo(1); + ThroughputControlTrackingUnit readOperationTrackingUnit = trackingUnitDictionary.get(operationType); + assertThat(readOperationTrackingUnit).isNotNull(); + assertThat(readOperationTrackingUnit.getRejectedRequests()).isEqualTo(0); + assertThat(readOperationTrackingUnit.getPassedRequests()).isEqualTo(1); + assertThat(readOperationTrackingUnit.getSuccessRuUsage()).isEqualTo(requestChargePerRequest); + assertThat(readOperationTrackingUnit.getSuccessResponse()).isEqualTo(1); + assertThat(readOperationTrackingUnit.getFailedResponse()).isEqualTo(0); + assertThat(readOperationTrackingUnit.getOutOfCycleResponse()).isEqualTo(1); + } + private void assertRequestThrottlerState( ThroughputRequestThrottler requestThrottler, double expectedAvailableThroughput, diff --git a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/controller/GlobalThroughputRequestControllerTests.java b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/controller/GlobalThroughputRequestControllerTests.java index a551e94f6af3..527ab40f6211 100644 --- a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/controller/GlobalThroughputRequestControllerTests.java +++ b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/controller/GlobalThroughputRequestControllerTests.java @@ -3,6 +3,8 @@ package com.azure.cosmos.implementation.throughputControl.controller; +import com.azure.cosmos.implementation.DocumentServiceRequestContext; +import com.azure.cosmos.implementation.OperationType; import com.azure.cosmos.implementation.RxDocumentServiceRequest; import com.azure.cosmos.implementation.directconnectivity.ReflectionUtils; import com.azure.cosmos.implementation.directconnectivity.StoreResponse; @@ -49,6 +51,8 @@ public void processRequest() { // First request: Can find the matching region request throttler in request controller RxDocumentServiceRequest request1Mock = Mockito.mock(RxDocumentServiceRequest.class); + Mockito.doReturn(OperationType.Read).when(request1Mock).getOperationType(); + request1Mock.requestContext = new DocumentServiceRequestContext(); TestPublisher request1MonoPublisher = TestPublisher.create(); Mono request1Mono = request1MonoPublisher.mono(); diff --git a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/controller/PkRangesThroughputRequestControllerTests.java b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/controller/PkRangesThroughputRequestControllerTests.java index 0337d76c21dd..9812d32a3451 100644 --- a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/controller/PkRangesThroughputRequestControllerTests.java +++ b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/throughputControl/controller/PkRangesThroughputRequestControllerTests.java @@ -4,6 +4,7 @@ package com.azure.cosmos.implementation.throughputControl.controller; import com.azure.cosmos.implementation.DocumentServiceRequestContext; +import com.azure.cosmos.implementation.OperationType; import com.azure.cosmos.implementation.PartitionKeyRange; import com.azure.cosmos.implementation.RxDocumentServiceRequest; import com.azure.cosmos.implementation.Utils; @@ -30,6 +31,7 @@ import java.util.concurrent.ConcurrentHashMap; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.OPTIONAL; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -144,6 +146,7 @@ public void renewThroughputUsageCycle() { private RxDocumentServiceRequest createMockRequest(PartitionKeyRange resolvedPkRange) { RxDocumentServiceRequest requestMock = Mockito.mock(RxDocumentServiceRequest.class); + Mockito.doReturn(OperationType.Read).when(requestMock).getOperationType(); DocumentServiceRequestContext requestContextMock = Mockito.mock(DocumentServiceRequestContext.class); requestContextMock.resolvedPartitionKeyRange = resolvedPkRange; requestMock.requestContext = requestContextMock;