Skip to content

Commit cff9dd5

Browse files
committed
Change connector private ip enabled to boolean type
Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent 6289562 commit cff9dd5

File tree

12 files changed

+30
-40
lines changed

12 files changed

+30
-40
lines changed

common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import static org.opensearch.secure_sm.AccessController.doPrivileged;
99

1010
import java.time.Duration;
11-
import java.util.concurrent.atomic.AtomicBoolean;
1211

1312
import lombok.extern.log4j.Log4j2;
1413
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
@@ -21,7 +20,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(
2120
Duration connectionTimeout,
2221
Duration readTimeout,
2322
int maxConnections,
24-
AtomicBoolean connectorPrivateIpEnabled
23+
boolean connectorPrivateIpEnabled
2524
) {
2625
return doPrivileged(() -> {
2726
log
@@ -37,7 +36,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(
3736
.readTimeout(readTimeout)
3837
.maxConcurrency(maxConnections)
3938
.build();
40-
return new ValidatingHttpClient(delegate, connectorPrivateIpEnabled);
39+
return new MLValidatableAsyncHttpClient(delegate, connectorPrivateIpEnabled);
4140
});
4241
}
4342
}

common/src/main/java/org/opensearch/ml/common/httpclient/ValidatingHttpClient.java renamed to common/src/main/java/org/opensearch/ml/common/httpclient/MLValidatableAsyncHttpClient.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,17 @@
1111
import java.util.Arrays;
1212
import java.util.Locale;
1313
import java.util.concurrent.CompletableFuture;
14-
import java.util.concurrent.atomic.AtomicBoolean;
1514

1615
import lombok.extern.log4j.Log4j2;
1716
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
1817
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
1918

2019
@Log4j2
21-
public class ValidatingHttpClient implements SdkAsyncHttpClient {
20+
public class MLValidatableAsyncHttpClient implements SdkAsyncHttpClient {
2221
private final SdkAsyncHttpClient delegate;
23-
private final AtomicBoolean connectorPrivateIpEnabled;
22+
private final boolean connectorPrivateIpEnabled;
2423

25-
protected ValidatingHttpClient(SdkAsyncHttpClient client, AtomicBoolean connectorPrivateIpEnabled) {
24+
protected MLValidatableAsyncHttpClient(SdkAsyncHttpClient client, boolean connectorPrivateIpEnabled) {
2625
this.delegate = client;
2726
this.connectorPrivateIpEnabled = connectorPrivateIpEnabled;
2827
}
@@ -54,7 +53,7 @@ public void close() {
5453
* @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536].
5554
* @throws UnknownHostException Allow to use private IP or not.
5655
*/
57-
public void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
56+
public void validate(String protocol, String host, int port, boolean connectorPrivateIpEnabled) throws UnknownHostException {
5857
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
5958
log.error("Remote inference protocol is not http or https: {}", protocol);
6059
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
@@ -74,9 +73,9 @@ public void validate(String protocol, String host, int port, AtomicBoolean conne
7473
validateIp(host, connectorPrivateIpEnabled);
7574
}
7675

77-
private void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
76+
private void validateIp(String hostName, boolean connectorPrivateIpEnabled) throws UnknownHostException {
7877
InetAddress[] addresses = InetAddress.getAllByName(hostName);
79-
if (connectorPrivateIpEnabled != null && !connectorPrivateIpEnabled.get() && hasPrivateIpAddress(addresses)) {
78+
if (!connectorPrivateIpEnabled && hasPrivateIpAddress(addresses)) {
8079
log.error("Remote inference host name has private ip address: {}", hostName);
8180
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
8281
}

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import java.util.ArrayList;
2727
import java.util.List;
28-
import java.util.concurrent.atomic.AtomicBoolean;
2928

3029
import org.opensearch.cluster.service.ClusterService;
3130
import org.opensearch.common.settings.Settings;
@@ -38,7 +37,7 @@ public class MLFeatureEnabledSetting {
3837
private volatile Boolean isAgentFrameworkEnabled;
3938

4039
private volatile Boolean isLocalModelEnabled;
41-
private volatile AtomicBoolean isConnectorPrivateIpEnabled;
40+
private volatile Boolean isConnectorPrivateIpEnabled;
4241

4342
private volatile Boolean isControllerEnabled;
4443
private volatile Boolean isBatchIngestionEnabled;
@@ -70,7 +69,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
7069
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
7170
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
7271
isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);
73-
isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings));
72+
isConnectorPrivateIpEnabled = ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings);
7473
isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings);
7574
isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings);
7675
isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings);
@@ -94,7 +93,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
9493
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it);
9594
clusterService
9695
.getClusterSettings()
97-
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it));
96+
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled = it);
9897
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it);
9998
clusterService
10099
.getClusterSettings()
@@ -145,7 +144,7 @@ public boolean isLocalModelEnabled() {
145144
return isLocalModelEnabled;
146145
}
147146

148-
public AtomicBoolean isConnectorPrivateIpEnabled() {
147+
public boolean isConnectorPrivateIpEnabled() {
149148
return isConnectorPrivateIpEnabled;
150149
}
151150

common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import static org.junit.Assert.assertNotNull;
99

1010
import java.time.Duration;
11-
import java.util.concurrent.atomic.AtomicBoolean;
1211

1312
import org.junit.Test;
1413

@@ -18,8 +17,7 @@ public class MLHttpClientFactoryTests {
1817

1918
@Test
2019
public void test_getSdkAsyncHttpClient_success() {
21-
SdkAsyncHttpClient client = MLHttpClientFactory
22-
.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, new AtomicBoolean(false));
20+
SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false);
2321
assertNotNull(client);
2422
}
2523

common/src/test/java/org/opensearch/ml/common/httpclient/ValidatingHttpClientTests.java renamed to common/src/test/java/org/opensearch/ml/common/httpclient/MLValidatableAsyncHttpClientTests.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,23 @@
1010
import static org.junit.Assert.assertThrows;
1111
import static org.mockito.Mockito.mock;
1212

13-
import java.util.concurrent.atomic.AtomicBoolean;
14-
1513
import org.junit.Rule;
1614
import org.junit.Test;
1715
import org.junit.rules.ExpectedException;
1816

1917
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
2018

21-
public class ValidatingHttpClientTests {
19+
public class MLValidatableAsyncHttpClientTests {
2220
private static final String TEST_HOST = "api.openai.com";
2321
private static final String HTTP = "http";
2422
private static final String HTTPS = "https";
25-
private static final AtomicBoolean PRIVATE_IP_DISABLED = new AtomicBoolean(false);
26-
private static final AtomicBoolean PRIVATE_IP_ENABLED = new AtomicBoolean(true);
23+
private static final boolean PRIVATE_IP_DISABLED = false;
24+
private static final boolean PRIVATE_IP_ENABLED = true;
2725

28-
private final ValidatingHttpClient validatingHttpClient = new ValidatingHttpClient(mock(SdkAsyncHttpClient.class), PRIVATE_IP_DISABLED);
26+
private final MLValidatableAsyncHttpClient validatingHttpClient = new MLValidatableAsyncHttpClient(
27+
mock(SdkAsyncHttpClient.class),
28+
PRIVATE_IP_DISABLED
29+
);
2930

3031
@Rule
3132
public ExpectedException expectedException = ExpectedException.none();

common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public void testDefaults_allFeaturesEnabled() {
8181
assertTrue(setting.isRemoteInferenceEnabled());
8282
assertTrue(setting.isAgentFrameworkEnabled());
8383
assertTrue(setting.isLocalModelEnabled());
84-
assertTrue(setting.isConnectorPrivateIpEnabled().get());
84+
assertTrue(setting.isConnectorPrivateIpEnabled());
8585
assertTrue(setting.isControllerEnabled());
8686
assertTrue(setting.isOfflineBatchIngestionEnabled());
8787
assertTrue(setting.isOfflineBatchInferenceEnabled());
@@ -122,7 +122,7 @@ public void testDefaults_someFeaturesDisabled() {
122122
assertFalse(setting.isRemoteInferenceEnabled());
123123
assertFalse(setting.isAgentFrameworkEnabled());
124124
assertFalse(setting.isLocalModelEnabled());
125-
assertFalse(setting.isConnectorPrivateIpEnabled().get());
125+
assertFalse(setting.isConnectorPrivateIpEnabled());
126126
assertFalse(setting.isControllerEnabled());
127127
assertFalse(setting.isOfflineBatchIngestionEnabled());
128128
assertFalse(setting.isOfflineBatchInferenceEnabled());

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import java.util.Locale;
1919
import java.util.Map;
2020
import java.util.concurrent.CompletableFuture;
21-
import java.util.concurrent.atomic.AtomicBoolean;
2221
import java.util.concurrent.atomic.AtomicReference;
2322

2423
import org.apache.commons.text.StringEscapeUtils;
@@ -81,7 +80,7 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
8180
private StreamTransportService streamTransportService;
8281

8382
@Setter
84-
private AtomicBoolean connectorPrivateIpEnabled;
83+
private boolean connectorPrivateIpEnabled;
8584

8685
public AwsConnectorExecutor(Connector connector) {
8786
super.initialize(connector);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import java.util.Locale;
1818
import java.util.Map;
1919
import java.util.concurrent.CompletableFuture;
20-
import java.util.concurrent.atomic.AtomicBoolean;
2120
import java.util.concurrent.atomic.AtomicReference;
2221

2322
import org.apache.commons.text.StringEscapeUtils;
@@ -74,7 +73,7 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor {
7473
@Getter
7574
private MLGuard mlGuard;
7675
@Setter
77-
private volatile AtomicBoolean connectorPrivateIpEnabled;
76+
private volatile boolean connectorPrivateIpEnabled;
7877

7978
private final AtomicReference<SdkAsyncHttpClient> httpClientRef = new AtomicReference<>();
8079

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import java.util.Locale;
2020
import java.util.Map;
2121
import java.util.Optional;
22-
import java.util.concurrent.atomic.AtomicBoolean;
2322

2423
import org.apache.logging.log4j.Logger;
2524
import org.opensearch.ExceptionsHelper;
@@ -183,7 +182,7 @@ default void setScriptService(ScriptService scriptService) {}
183182

184183
default void setClient(Client client) {}
185184

186-
default void setConnectorPrivateIpEnabled(AtomicBoolean connectorPrivateIpEnabled) {}
185+
default void setConnectorPrivateIpEnabled(boolean connectorPrivateIpEnabled) {}
187186

188187
default void setXContentRegistry(NamedXContentRegistry xContentRegistry) {}
189188

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.util.Map;
1212
import java.util.concurrent.CompletableFuture;
1313
import java.util.concurrent.CompletionStage;
14-
import java.util.concurrent.atomic.AtomicBoolean;
1514

1615
import org.opensearch.cluster.service.ClusterService;
1716
import org.opensearch.common.settings.Settings;
@@ -127,7 +126,7 @@ public CompletionStage<Boolean> initModelAsync(MLModel model, Map<String, Object
127126
this.connectorExecutor.setRateLimiter((TokenBucket) params.get(RATE_LIMITER));
128127
this.connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
129128
this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
130-
this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED));
129+
this.connectorExecutor.setConnectorPrivateIpEnabled((boolean) params.getOrDefault(CONNECTOR_PRIVATE_IP_ENABLED, false));
131130
return CompletableFuture.completedStage(true);
132131
}).exceptionally(e -> {
133132
log.error("Failed to init remote model.", e);

0 commit comments

Comments
 (0)