Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for using ip address in discovery response. #322

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion core/src/main/java/tech/ydb/core/impl/YdbDiscovery.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import tech.ydb.core.operation.OperationBinder;
import tech.ydb.core.utils.FutureTools;
import tech.ydb.proto.discovery.DiscoveryProtos;
import tech.ydb.proto.discovery.DiscoveryProtos.EndpointInfo;
import tech.ydb.proto.discovery.v1.DiscoveryServiceGrpc;

/**
Expand Down Expand Up @@ -185,6 +186,21 @@ private void handleOk(String selfLocation, List<EndpointRecord> endpoints) {
}
}

private static String createAddress(EndpointInfo e) {
String addr;
if (e.getIpV6Count() > 0 && e.getIpV6(0) != null && !e.getIpV6(0).isEmpty()) {
addr = e.getIpV6(0);
} else if (e.getIpV4Count() > 0 && e.getIpV4(0) != null && !e.getIpV4(0).isEmpty()) {
addr = e.getIpV4(0);
} else {
addr = e.getAddress();
}

logger.debug("address {} will be used to connect to node {}", addr, e.getAddress());

return addr;
}

private void handleDiscoveryResult(Result<DiscoveryProtos.ListEndpointsResult> response, Throwable th) {
if (th != null) {
Throwable cause = FutureTools.unwrapCompletionException(th);
Expand All @@ -202,7 +218,8 @@ private void handleDiscoveryResult(Result<DiscoveryProtos.ListEndpointsResult> r
}

List<EndpointRecord> records = result.getEndpointsList().stream()
.map(e -> new EndpointRecord(e.getAddress(), e.getPort(), e.getNodeId(), e.getLocation()))
.map(e -> new EndpointRecord(createAddress(e), e.getPort(), e.getNodeId(), e.getLocation(),
e.getSslTargetNameOverride()))
.collect(Collectors.toList());

logger.debug("successfully received ListEndpoints result with {} endpoints", records.size());
Expand Down
17 changes: 14 additions & 3 deletions core/src/main/java/tech/ydb/core/impl/pool/EndpointRecord.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,35 @@ public class EndpointRecord {
private final String host;
private final String hostAndPort;
private final String locationDC;
private final String authority;
private final int port;
private final int nodeId;

public EndpointRecord(String host, int port, int nodeId, String locationDC) {
public EndpointRecord(String host, int port, int nodeId, String locationDC, String authority) {
this.host = Objects.requireNonNull(host);
this.port = port;
this.hostAndPort = host + ":" + port;
this.nodeId = nodeId;
this.locationDC = locationDC;
if (authority != null && !authority.isEmpty()) {
this.authority = authority;
} else {
this.authority = null;
}
}

public EndpointRecord(String host, int port) {
this(host, port, 0, null);
this(host, port, 0, null, null);
}

public String getHost() {
return host;
}

public String getAuthority() {
return authority;
}

public int getPort() {
return port;
}
Expand All @@ -46,6 +56,7 @@ public String getLocation() {

@Override
public String toString() {
return "Endpoint{host=" + host + ", port=" + port + ", node=" + nodeId + ", location=" + locationDC + "}";
return "Endpoint{host=" + host + ", port=" + port + ", node=" + nodeId +
", location=" + locationDC + ", overrideAuthority=" + authority + "}";
}
}
3 changes: 2 additions & 1 deletion core/src/main/java/tech/ydb/core/impl/pool/GrpcChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public GrpcChannel(EndpointRecord endpoint, ManagedChannelFactory factory) {
try {
logger.debug("Creating grpc channel with {}", endpoint);
this.endpoint = endpoint;
this.channel = factory.newManagedChannel(endpoint.getHost(), endpoint.getPort());
this.channel = factory.newManagedChannel(endpoint.getHost(), endpoint.getPort(),
endpoint.getAuthority());
this.connectTimeoutMs = factory.getConnectTimeoutMs();
this.readyWatcher = new ReadyWatcher();
this.readyWatcher.checkState();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ interface Builder {
ManagedChannelFactory buildFactory(GrpcTransportBuilder builder);
}

ManagedChannel newManagedChannel(String host, int port);
ManagedChannel newManagedChannel(String host, int port, String authority);

long getConnectTimeoutMs();
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ public long getConnectTimeoutMs() {

@SuppressWarnings("deprecation")
@Override
public ManagedChannel newManagedChannel(String host, int port) {
public ManagedChannel newManagedChannel(String host, int port, String sslHostOverride) {
NettyChannelBuilder channelBuilder = NettyChannelBuilder
.forAddress(host, port);

if (useTLS) {
channelBuilder
.negotiationType(NegotiationType.TLS)
.sslContext(createSslContext());
if (sslHostOverride != null) {
channelBuilder.overrideAuthority(sslHostOverride);
}
} else {
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ public long getConnectTimeoutMs() {

@SuppressWarnings("deprecation")
@Override
public ManagedChannel newManagedChannel(String host, int port) {
public ManagedChannel newManagedChannel(String host, int port, String sslHostOverride) {
NettyChannelBuilder channelBuilder = NettyChannelBuilder
.forAddress(host, port);

if (useTLS) {
channelBuilder
.negotiationType(NegotiationType.TLS)
.sslContext(createSslContext());
if (sslHostOverride != null) {
channelBuilder.overrideAuthority(sslHostOverride);
}
} else {
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void setUp() throws InterruptedException {
Mockito.when(channel.shutdownNow()).thenReturn(channel);
Mockito.when(channel.awaitTermination(Mockito.anyLong(), Mockito.any())).thenReturn(true);

Mockito.when(channelFactory.newManagedChannel(Mockito.any(), Mockito.anyInt())).thenReturn(channel);
Mockito.when(channelFactory.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull())).thenReturn(channel);
}

private <T extends Throwable> T checkFutureException(CompletableFuture<Boolean> f, String message, Class<T> clazz) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ public void setUp() throws InterruptedException {
Mockito.when(transportChannel.shutdownNow()).thenReturn(transportChannel);
Mockito.when(transportChannel.awaitTermination(Mockito.anyLong(), Mockito.any())).thenReturn(true);

Mockito.when(channelFactory.newManagedChannel(Mockito.eq("mocked"), Mockito.eq(2136)))
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("mocked"), Mockito.eq(2136), Mockito.isNull()))
.thenReturn(discoveryChannel);
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("node"), Mockito.eq(2136)))
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("node"), Mockito.eq(2136), Mockito.isNull()))
.thenReturn(transportChannel);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void defaultParams() {
channelStaticMock.verify(FOR_ADDRESS, times(0));

Assert.assertEquals(30_000l, factory.getConnectTimeoutMs());
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

channelStaticMock.verify(FOR_ADDRESS, times(1));

Expand All @@ -100,7 +100,7 @@ public void defaultSslFactory() {
channelStaticMock.verify(FOR_ADDRESS, times(0));

Assert.assertEquals(60000l, factory.getConnectTimeoutMs());
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

channelStaticMock.verify(FOR_ADDRESS, times(1));

Expand All @@ -124,7 +124,7 @@ public void customChannelInitializer() {

channelStaticMock.verify(FOR_ADDRESS, times(0));

Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

channelStaticMock.verify(FOR_ADDRESS, times(1));

Expand All @@ -150,7 +150,7 @@ public void customSslFactory() throws CertificateException, IOException {
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);

Assert.assertEquals(4000l, factory.getConnectTimeoutMs());
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

} finally {
selfSignedCert.delete();
Expand All @@ -176,7 +176,7 @@ public void invalidSslCert() {
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);

RuntimeException ex = Assert.assertThrows(RuntimeException.class,
() -> factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
() -> factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

Assert.assertEquals("cannot create ssl context", ex.getMessage());
Assert.assertNotNull(ex.getCause());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ public void nodePessimizationTest() {
check(pool.getEndpoint(2)).hostname("n2.ydb.tech").nodeID(2).port(12342);

// Pessimize unknown nodes - nothing is changed
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12341, 2, null));
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12342, 2, null));
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12341, 2, null, null));
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12342, 2, null, null));
pool.pessimizeEndpoint(null);
check(pool).records(5).knownNodes(5).needToReDiscovery(false).bestEndpointsCount(4);

Expand Down Expand Up @@ -553,6 +553,6 @@ private static List<EndpointRecord> list(EndpointRecord... records) {
}

private static EndpointRecord endpoint(int nodeID, String hostname, int port, String location) {
return new EndpointRecord(hostname, port, nodeID, location);
return new EndpointRecord(hostname, port, nodeID, location, null);
}
}
20 changes: 10 additions & 10 deletions core/src/test/java/tech/ydb/core/impl/pool/GrpcChannelPoolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class GrpcChannelPoolTest {
@Before
public void setUp() {
Mockito.when(factoryMock.getConnectTimeoutMs()).thenReturn(500l); // timeout for ready watcher
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt()))
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull()))
.then((args) -> ManagedChannelMock.good());
}

Expand All @@ -34,8 +34,8 @@ public void tearDown() throws Exception {

@Test
public void simpleTest() {
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null);
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null, null);

GrpcChannelPool pool = new GrpcChannelPool(factoryMock, scheduler);
Assert.assertEquals(0, pool.getChannels().size());
Expand Down Expand Up @@ -66,9 +66,9 @@ public void simpleTest() {

@Test
public void removeChannels() {
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null);
EndpointRecord e3 = new EndpointRecord("host1", 1236, 12, null);
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null, null);
EndpointRecord e3 = new EndpointRecord("host1", 1236, 12, null, null);

GrpcChannelPool pool = new GrpcChannelPool(factoryMock, scheduler);
Assert.assertEquals(0, pool.getChannels().size());
Expand Down Expand Up @@ -121,13 +121,13 @@ public void removeChannels() {

@Test
public void badShutdownTest() {
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt())).thenReturn(
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull())).thenReturn(
ManagedChannelMock.good(), ManagedChannelMock.good(),
ManagedChannelMock.wrongShutdown(), ManagedChannelMock.wrongShutdown());

EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null);
EndpointRecord e3 = new EndpointRecord("host1", 1236, 12, null);
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null, null);
EndpointRecord e3 = new EndpointRecord("host1", 1236, 12, null, null);

GrpcChannelPool pool = new GrpcChannelPool(factoryMock, scheduler);
Assert.assertEquals(0, pool.getChannels().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void setUp() {

@Test
public void goodChannels() {
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt()))
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull()))
.thenReturn(ManagedChannelMock.good(), ManagedChannelMock.good());

EndpointRecord endpoint = new EndpointRecord("host1", 1234);
Expand All @@ -52,7 +52,7 @@ public void slowChannels() {
ConnectivityState.READY,
};

Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt()))
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull()))
.thenReturn(new ManagedChannelMock(ConnectivityState.IDLE).nextStates(states))
.thenReturn(new ManagedChannelMock(ConnectivityState.IDLE).nextStates(states));

Expand All @@ -74,7 +74,7 @@ public void badChannels() {
ConnectivityState.SHUTDOWN,
};

Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt()))
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull()))
.thenReturn(new ManagedChannelMock(ConnectivityState.IDLE).nextStates(states))
.thenReturn(new ManagedChannelMock(ConnectivityState.IDLE).nextStates(states));

Expand All @@ -84,7 +84,7 @@ public void badChannels() {
Assert.assertEquals(endpoint, channel.getEndpoint());

RuntimeException ex1 = Assert.assertThrows(RuntimeException.class, channel::getReadyChannel);
Assert.assertEquals("Channel Endpoint{host=host1, port=1234, node=0, location=null} connecting problem",
Assert.assertEquals("Channel Endpoint{host=host1, port=1234, node=0, location=null, overrideAuthority=null} connecting problem",
ex1.getMessage());

channel.shutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE

public static ManagedChannelFactory.Builder MOCKED = (GrpcTransportBuilder builder) -> new ManagedChannelFactory() {
@Override
public ManagedChannel newManagedChannel(String host, int port) {
public ManagedChannel newManagedChannel(String host, int port, String authority) {
return good();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void fixedLocalDcTest() {

@Test
public void detectLocalDCfallbackTest() {
List<EndpointRecord> single = Collections.singletonList(new EndpointRecord("localhost", 8080, 0, "DC1"));
List<EndpointRecord> single = Collections.singletonList(new EndpointRecord("localhost", 8080, 0, "DC1", null));
PriorityPicker ignoreSelftLocation = PriorityPicker.from(BalancingSettings.detectLocalDs(), "DC1", single);

Assert.assertEquals(0, ignoreSelftLocation.getEndpointPriority("DC1"));
Expand All @@ -73,7 +73,7 @@ public void detectLocalDCTest() {
final int port = serverSocket.getLocalPort();

List<EndpointRecord> records = Arrays.asList("DC1", "DC1", "DC2", "DC2", "DC2", "DC3")
.stream().map(location -> new EndpointRecord("localhost", port, 1, location))
.stream().map(location -> new EndpointRecord("localhost", port, 1, location, null))
.collect(Collectors.toList());

String localDC = PriorityPicker.detectLocalDC(records, testTicker);
Expand Down