Skip to content

Commit

Permalink
Connectivity: Fixes address resolution calls when using EnableTcpConn…
Browse files Browse the repository at this point in the history
…ectionEndpointRediscovery (#1758)

* Removing cache

* Adding tests

* trace message

Co-authored-by: j82w <j82w@users.noreply.github.com>
  • Loading branch information
ealsur and j82w authored Aug 7, 2020
1 parent 1a8e087 commit c1c15ce
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 15 deletions.
17 changes: 4 additions & 13 deletions Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ public async Task<PartitionAddressInformation> TryGetAddressesAsync(
}
}

public async Task TryUpdateAddressAsync(
public Task TryRemoveAddressesAsync(
ServerKey serverKey,
CancellationToken cancellationToken)
{
Expand All @@ -263,29 +263,20 @@ public async Task TryUpdateAddressAsync(
{
foreach (PartitionKeyRangeIdentity pkRangeId in pkRangeIds)
{
DefaultTrace.TraceInformation("Refresh addresses for collectionRid :{0}, pkRangeId: {1}, serviceEndpoint: {2}",
DefaultTrace.TraceInformation("Remove addresses for collectionRid :{0}, pkRangeId: {1}, serviceEndpoint: {2}",
pkRangeId.CollectionRid,
pkRangeId.PartitionKeyRangeId,
this.serviceEndpoint);

tasks.Add(this.serverPartitionAddressCache.GetAsync(
pkRangeId,
null,
() => this.GetAddressesForRangeIdAsync(
null,
pkRangeId.CollectionRid,
pkRangeId.PartitionKeyRangeId,
forceRefresh: true),
cancellationToken,
forceRefresh: true));
tasks.Add(this.serverPartitionAddressCache.RemoveAsync(pkRangeId));
}

// remove the server key from the map since we are updating the addresses
HashSet<PartitionKeyRangeIdentity> ignorePkRanges;
this.serverPartitionAddressToPkRangeIdMap.TryRemove(serverKey, out ignorePkRanges);
}

await Task.WhenAll(tasks);
return Task.WhenAll(tasks);
}

public async Task<PartitionAddressInformation> UpdateAsync(
Expand Down
4 changes: 2 additions & 2 deletions Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ public async Task UpdateAsync(

foreach (KeyValuePair<Uri, EndpointCache> addressCache in this.addressCacheByEndpoint)
{
// since we don't know which address cache contains the pkRanges mapped to this node, we do a tryUpdate on all AddressCaches of all regions
tasks.Add(addressCache.Value.AddressCache.TryUpdateAddressAsync(serverKey, cancellationToken));
// since we don't know which address cache contains the pkRanges mapped to this node, we do a tryRemove on all AddressCaches of all regions
tasks.Add(addressCache.Value.AddressCache.TryRemoveAddressesAsync(serverKey, cancellationToken));
}

await Task.WhenAll(tasks);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos
{
using System;
using System.Net.Http;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Cosmos.Routing;
using System.Threading.Tasks;
using System.Threading;
using System.Net;
using System.Text;
using System.Collections.ObjectModel;
using System.Collections.Generic;
using System.Linq;

/// <summary>
/// Tests for <see cref="GatewayAddressCache"/>.
/// </summary>
[TestClass]
public class GatewayAddressCacheTests
{
private const string DatabaseAccountApiEndpoint = "https://endpoint.azure.com";
private Mock<IAuthorizationTokenProvider> mockTokenProvider;
private Mock<IServiceConfigurationReader> mockServiceConfigReader;
private int targetReplicaSetSize = 4;
private PartitionKeyRangeIdentity testPartitionKeyRangeIdentity;
private ServiceIdentity serviceIdentity;
private Uri serviceName;

public GatewayAddressCacheTests()
{
this.mockTokenProvider = new Mock<IAuthorizationTokenProvider>();
string payload;
this.mockTokenProvider.Setup(foo => foo.GetUserAuthorizationToken(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<Documents.Collections.INameValueCollection>(), It.IsAny<AuthorizationTokenType>(), out payload))
.Returns("token!");
this.mockServiceConfigReader = new Mock<IServiceConfigurationReader>();
this.mockServiceConfigReader.Setup(foo => foo.SystemReplicationPolicy).Returns(new ReplicationPolicy() { MaxReplicaSetSize = this.targetReplicaSetSize });
this.mockServiceConfigReader.Setup(foo => foo.UserReplicationPolicy).Returns(new ReplicationPolicy() { MaxReplicaSetSize = this.targetReplicaSetSize });
this.testPartitionKeyRangeIdentity = new PartitionKeyRangeIdentity("YxM9ANCZIwABAAAAAAAAAA==", "YxM9ANCZIwABAAAAAAAAAA==");
this.serviceName = new Uri(GatewayAddressCacheTests.DatabaseAccountApiEndpoint);
this.serviceIdentity = new ServiceIdentity("federation1", this.serviceName, false);
}

[TestMethod]
public void TestGatewayAddressCacheAutoRefreshOnSuboptimalPartition()
{
FakeMessageHandler messageHandler = new FakeMessageHandler();
HttpClient httpClient = new HttpClient(messageHandler);
httpClient.Timeout = TimeSpan.FromSeconds(120);
GatewayAddressCache cache = new GatewayAddressCache(
new Uri(GatewayAddressCacheTests.DatabaseAccountApiEndpoint),
Documents.Client.Protocol.Https,
this.mockTokenProvider.Object,
this.mockServiceConfigReader.Object,
httpClient,
suboptimalPartitionForceRefreshIntervalInSeconds: 2);

int initialAddressesCount = cache.TryGetAddressesAsync(
DocumentServiceRequest.Create(OperationType.Invalid, ResourceType.Address, AuthorizationTokenType.Invalid),
this.testPartitionKeyRangeIdentity,
this.serviceIdentity,
false,
CancellationToken.None).Result.AllAddresses.Count();

Assert.IsTrue(initialAddressesCount < this.targetReplicaSetSize);

Task.Delay(3000).Wait();

int finalAddressCount = cache.TryGetAddressesAsync(
DocumentServiceRequest.Create(OperationType.Invalid, ResourceType.Address, AuthorizationTokenType.Invalid),
this.testPartitionKeyRangeIdentity,
this.serviceIdentity,
false,
CancellationToken.None).Result.AllAddresses.Count();

Assert.IsTrue(finalAddressCount == this.targetReplicaSetSize);
}

[TestMethod]
public async Task TestGatewayAddressCacheUpdateOnConnectionResetAsync()
{
FakeMessageHandler messageHandler = new FakeMessageHandler();
HttpClient httpClient = new HttpClient(messageHandler);
httpClient.Timeout = TimeSpan.FromSeconds(120);
GatewayAddressCache cache = new GatewayAddressCache(
new Uri(GatewayAddressCacheTests.DatabaseAccountApiEndpoint),
Documents.Client.Protocol.Https,
this.mockTokenProvider.Object,
this.mockServiceConfigReader.Object,
httpClient,
suboptimalPartitionForceRefreshIntervalInSeconds: 2,
enableTcpConnectionEndpointRediscovery: true);

PartitionAddressInformation addresses = cache.TryGetAddressesAsync(
DocumentServiceRequest.Create(OperationType.Invalid, ResourceType.Address, AuthorizationTokenType.Invalid),
this.testPartitionKeyRangeIdentity,
this.serviceIdentity,
false,
CancellationToken.None).Result;

Assert.IsNotNull(addresses.AllAddresses.Select(address => address.PhysicalUri == "https://blabla.com"));

// call updateAddress
await cache.TryRemoveAddressesAsync(new Documents.Rntbd.ServerKey(new Uri("https://blabla.com")), CancellationToken.None);

// check if the addresss is updated
addresses = cache.TryGetAddressesAsync(
DocumentServiceRequest.Create(OperationType.Invalid, ResourceType.Address, AuthorizationTokenType.Invalid),
this.testPartitionKeyRangeIdentity,
this.serviceIdentity,
false,
CancellationToken.None).Result;

Assert.IsNotNull(addresses.AllAddresses.Select(address => address.PhysicalUri == "https://blabla5.com"));
}

[TestMethod]
[Timeout(2000)]
public void GlobalAddressResolverUpdateAsyncSynchronizationTest()
{
SynchronizationContext prevContext = SynchronizationContext.Current;
try
{
TestSynchronizationContext syncContext = new TestSynchronizationContext();
SynchronizationContext.SetSynchronizationContext(syncContext);
syncContext.Post(_ =>
{
UserAgentContainer container = new UserAgentContainer();
FakeMessageHandler messageHandler = new FakeMessageHandler();

AccountProperties databaseAccount = new AccountProperties();
Mock<IDocumentClientInternal> mockDocumentClient = new Mock<IDocumentClientInternal>();
mockDocumentClient.Setup(owner => owner.ServiceEndpoint).Returns(new Uri("https://blabla.com/"));
mockDocumentClient.Setup(owner => owner.GetDatabaseAccountInternalAsync(It.IsAny<Uri>(), It.IsAny<CancellationToken>())).ReturnsAsync(databaseAccount);

GlobalEndpointManager globalEndpointManager = new GlobalEndpointManager(mockDocumentClient.Object, new ConnectionPolicy());

ConnectionPolicy connectionPolicy = new ConnectionPolicy
{
RequestTimeout = TimeSpan.FromSeconds(10)
};

GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver(globalEndpointManager, Documents.Client.Protocol.Tcp, this.mockTokenProvider.Object, null, null, this.mockServiceConfigReader.Object, connectionPolicy, new HttpClient(messageHandler));

ConnectionStateListener connectionStateListener = new ConnectionStateListener(globalAddressResolver);
connectionStateListener.OnConnectionEvent(ConnectionEvent.ReadEof, DateTime.Now, new Documents.Rntbd.ServerKey(new Uri("https://endpoint.azure.com:4040/")));

}, state: null);
}
finally
{
SynchronizationContext.SetSynchronizationContext(prevContext);
}
}

private class FakeMessageHandler : HttpMessageHandler
{
private bool returnFullReplicaSet;
private bool returnUpdatedAddresses;

public FakeMessageHandler()
{
this.returnFullReplicaSet = false;
this.returnUpdatedAddresses = false;
}

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
List<Address> addresses = new List<Address>()
{
new Address() { IsPrimary = true, PhysicalUri = "https://blabla.com", Protocol = RuntimeConstants.Protocols.HTTPS, PartitionKeyRangeId = "YxM9ANCZIwABAAAAAAAAAA==" },
new Address() { IsPrimary = false, PhysicalUri = "https://blabla3.com", Protocol = RuntimeConstants.Protocols.HTTPS, PartitionKeyRangeId = "YxM9ANCZIwABAAAAAAAAAA==" },
new Address() { IsPrimary = false, PhysicalUri = "https://blabla2.com", Protocol = RuntimeConstants.Protocols.HTTPS, PartitionKeyRangeId = "YxM9ANCZIwABAAAAAAAAAA==" },
};

if (this.returnFullReplicaSet)
{
addresses.Add(new Address() { IsPrimary = false, PhysicalUri = "https://blabla4.com", Protocol = RuntimeConstants.Protocols.HTTPS, PartitionKeyRangeId = "YxM9ANCZIwABAAAAAAAAAA==" });
this.returnFullReplicaSet = false;
}
else
{
this.returnFullReplicaSet = true;
}

if (this.returnUpdatedAddresses)
{
addresses.RemoveAll(address => address.IsPrimary == true);
addresses.Add(new Address() { IsPrimary = true, PhysicalUri = "https://blabla5.com", Protocol = RuntimeConstants.Protocols.HTTPS, PartitionKeyRangeId = "YxM9ANCZIwABAAAAAAAAAA==" });
this.returnUpdatedAddresses = false;
}
else
{
this.returnUpdatedAddresses = true;
}

FeedResource<Address> addressFeedResource = new FeedResource<Address>()
{
Id = "YxM9ANCZIwABAAAAAAAAAA==",
SelfLink = "dbs/YxM9AA==/colls/YxM9ANCZIwA=/docs/YxM9ANCZIwABAAAAAAAAAA==/",
Timestamp = DateTime.Now,
InnerCollection = new Collection<Address>(addresses),
};

StringBuilder feedResourceString = new StringBuilder();
addressFeedResource.SaveTo(feedResourceString);

StringContent content = new StringContent(feedResourceString.ToString());
HttpResponseMessage responseMessage = new HttpResponseMessage()
{
StatusCode = HttpStatusCode.OK,
Content = content,
};

return Task.FromResult<HttpResponseMessage>(responseMessage);
}
}

public class TestSynchronizationContext : SynchronizationContext
{
private object locker = new object();
public override void Post(SendOrPostCallback d, object state)
{
lock (this.locker)
{
d(state);
}
}
}
}
}

0 comments on commit c1c15ce

Please sign in to comment.