Skip to content

Commit c4ec014

Browse files
authored
Fixing regex bypass issue (#4336)
* Fixing regex bypass issue Signed-off-by: zane-neo <zaniu@amazon.com> * fix failure ITs Signed-off-by: zane-neo <zaniu@amazon.com> * Change connector private ip enabled to boolean type Signed-off-by: zane-neo <zaniu@amazon.com> * fix dependency version conflict issue Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent eec7179 commit c4ec014

File tree

15 files changed

+399
-312
lines changed

15 files changed

+399
-312
lines changed

common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java

Lines changed: 24 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,9 @@
55

66
package org.opensearch.ml.common.httpclient;
77

8-
import java.net.Inet4Address;
9-
import java.net.InetAddress;
10-
import java.net.UnknownHostException;
11-
import java.time.Duration;
12-
import java.util.Arrays;
13-
import java.util.Locale;
14-
import java.util.concurrent.atomic.AtomicBoolean;
8+
import static org.opensearch.secure_sm.AccessController.doPrivileged;
159

16-
import org.opensearch.common.util.concurrent.ThreadContextAccess;
10+
import java.time.Duration;
1711

1812
import lombok.extern.log4j.Log4j2;
1913
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
@@ -22,79 +16,27 @@
2216
@Log4j2
2317
public class MLHttpClientFactory {
2418

25-
public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) {
26-
return ThreadContextAccess
27-
.doPrivileged(
28-
() -> NettyNioAsyncHttpClient
29-
.builder()
30-
.connectionTimeout(connectionTimeout)
31-
.readTimeout(readTimeout)
32-
.maxConcurrency(maxConnections)
33-
.build()
34-
);
35-
}
36-
37-
/**
38-
* Validate the input parameters, such as protocol, host and port.
39-
* @param protocol The protocol supported in remote inference, currently only http and https are supported.
40-
* @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost.
41-
* @param port The port number of the remote inference server, port number must be in range [0, 65536].
42-
* @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536].
43-
* @throws UnknownHostException Allow to use private IP or not.
44-
*/
45-
public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
46-
throws UnknownHostException {
47-
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
48-
log.error("Remote inference protocol is not http or https: {}", protocol);
49-
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
50-
}
51-
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
52-
if (port == -1) {
53-
if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
54-
port = 80;
55-
} else {
56-
port = 443;
57-
}
58-
}
59-
if (port < 0 || port > 65536) {
60-
log.error("Remote inference port out of range: {}", port);
61-
throw new IllegalArgumentException("Port out of range: " + port);
62-
}
63-
validateIp(host, connectorPrivateIpEnabled);
64-
}
65-
66-
private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
67-
InetAddress[] addresses = InetAddress.getAllByName(hostName);
68-
if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
69-
log.error("Remote inference host name has private ip address: {}", hostName);
70-
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
71-
}
72-
}
73-
74-
private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
75-
for (InetAddress ip : ipAddress) {
76-
if (ip instanceof Inet4Address) {
77-
byte[] bytes = ip.getAddress();
78-
if (bytes.length != 4) {
79-
return true;
80-
} else {
81-
if (isPrivateIPv4(bytes)) {
82-
return true;
83-
}
84-
}
85-
}
86-
}
87-
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
88-
}
89-
90-
private static boolean isPrivateIPv4(byte[] bytes) {
91-
int first = bytes[0] & 0xff;
92-
int second = bytes[1] & 0xff;
93-
94-
// 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
95-
return (first == 10)
96-
|| (first == 172 && second >= 16 && second <= 31)
97-
|| (first == 192 && second == 168)
98-
|| (first == 169 && second == 254);
19+
public static SdkAsyncHttpClient getAsyncHttpClient(
20+
Duration connectionTimeout,
21+
Duration readTimeout,
22+
int maxConnections,
23+
boolean connectorPrivateIpEnabled
24+
) {
25+
return doPrivileged(() -> {
26+
log
27+
.debug(
28+
"Creating MLHttpClient with connectionTimeout: {}, readTimeout: {}, maxConnections: {}",
29+
connectionTimeout,
30+
readTimeout,
31+
maxConnections
32+
);
33+
SdkAsyncHttpClient delegate = NettyNioAsyncHttpClient
34+
.builder()
35+
.connectionTimeout(connectionTimeout)
36+
.readTimeout(readTimeout)
37+
.maxConcurrency(maxConnections)
38+
.build();
39+
return new MLValidatableAsyncHttpClient(delegate, connectorPrivateIpEnabled);
40+
});
9941
}
10042
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.httpclient;
7+
8+
import java.net.Inet4Address;
9+
import java.net.InetAddress;
10+
import java.net.UnknownHostException;
11+
import java.util.Arrays;
12+
import java.util.Locale;
13+
import java.util.concurrent.CompletableFuture;
14+
15+
import lombok.extern.log4j.Log4j2;
16+
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
17+
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
18+
19+
@Log4j2
20+
public class MLValidatableAsyncHttpClient implements SdkAsyncHttpClient {
21+
private final SdkAsyncHttpClient delegate;
22+
private final boolean connectorPrivateIpEnabled;
23+
24+
protected MLValidatableAsyncHttpClient(SdkAsyncHttpClient client, boolean connectorPrivateIpEnabled) {
25+
this.delegate = client;
26+
this.connectorPrivateIpEnabled = connectorPrivateIpEnabled;
27+
}
28+
29+
@Override
30+
public CompletableFuture<Void> execute(AsyncExecuteRequest request) {
31+
String protocol = request.request().protocol();
32+
String host = request.request().host();
33+
int port = request.request().port();
34+
try {
35+
validate(protocol, host, port, connectorPrivateIpEnabled);
36+
return delegate.execute(request);
37+
} catch (Exception e) {
38+
log.error("Failed to validate request!", e);
39+
throw new IllegalArgumentException(e.getMessage(), e);
40+
}
41+
}
42+
43+
@Override
44+
public void close() {
45+
delegate.close();
46+
}
47+
48+
/**
49+
* Validate the input parameters, such as protocol, host and port.
50+
* @param protocol The protocol supported in remote inference, currently only http and https are supported.
51+
* @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost.
52+
* @param port The port number of the remote inference server, port number must be in range [0, 65536].
53+
* @param connectorPrivateIpEnabled Whether a host that resolves to a private IP address is allowed; when false such hosts are rejected.
54+
* @throws UnknownHostException If the host name cannot be resolved to an IP address.
55+
*/
56+
public void validate(String protocol, String host, int port, boolean connectorPrivateIpEnabled) throws UnknownHostException {
57+
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
58+
log.error("Remote inference protocol is not http or https: {}", protocol);
59+
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
60+
}
61+
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
62+
if (port == -1) {
63+
if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
64+
port = 80;
65+
} else {
66+
port = 443;
67+
}
68+
}
69+
if (port < 0 || port > 65536) {
70+
log.error("Remote inference port out of range: {}", port);
71+
throw new IllegalArgumentException("Port out of range: " + port);
72+
}
73+
validateIp(host, connectorPrivateIpEnabled);
74+
}
75+
76+
private void validateIp(String hostName, boolean connectorPrivateIpEnabled) throws UnknownHostException {
77+
InetAddress[] addresses = InetAddress.getAllByName(hostName);
78+
if (!connectorPrivateIpEnabled && hasPrivateIpAddress(addresses)) {
79+
log.error("Remote inference host name has private ip address: {}", hostName);
80+
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
81+
}
82+
}
83+
84+
private boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
85+
for (InetAddress ip : ipAddress) {
86+
if (ip instanceof Inet4Address) {
87+
byte[] bytes = ip.getAddress();
88+
if (bytes.length != 4) {
89+
return true;
90+
} else {
91+
if (isPrivateIPv4(bytes)) {
92+
return true;
93+
}
94+
}
95+
}
96+
}
97+
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
98+
}
99+
100+
private boolean isPrivateIPv4(byte[] bytes) {
101+
int first = bytes[0] & 0xff;
102+
int second = bytes[1] & 0xff;
103+
104+
// 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x (127.x.x.x loopback is not checked here; hasPrivateIpAddress catches it via isLoopbackAddress())
105+
return (first == 10)
106+
|| (first == 172 && second >= 16 && second <= 31)
107+
|| (first == 192 && second == 168)
108+
|| (first == 169 && second == 254);
109+
}
110+
}

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import java.util.ArrayList;
2727
import java.util.List;
28-
import java.util.concurrent.atomic.AtomicBoolean;
2928

3029
import org.opensearch.cluster.service.ClusterService;
3130
import org.opensearch.common.settings.Settings;
@@ -38,7 +37,7 @@ public class MLFeatureEnabledSetting {
3837
private volatile Boolean isAgentFrameworkEnabled;
3938

4039
private volatile Boolean isLocalModelEnabled;
41-
private volatile AtomicBoolean isConnectorPrivateIpEnabled;
40+
private volatile Boolean isConnectorPrivateIpEnabled;
4241

4342
private volatile Boolean isControllerEnabled;
4443
private volatile Boolean isBatchIngestionEnabled;
@@ -70,7 +69,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
7069
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
7170
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
7271
isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);
73-
isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings));
72+
isConnectorPrivateIpEnabled = ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings);
7473
isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings);
7574
isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings);
7675
isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings);
@@ -94,7 +93,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
9493
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it);
9594
clusterService
9695
.getClusterSettings()
97-
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it));
96+
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled = it);
9897
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it);
9998
clusterService
10099
.getClusterSettings()
@@ -145,7 +144,7 @@ public boolean isLocalModelEnabled() {
145144
return isLocalModelEnabled;
146145
}
147146

148-
public AtomicBoolean isConnectorPrivateIpEnabled() {
147+
public boolean isConnectorPrivateIpEnabled() {
149148
return isConnectorPrivateIpEnabled;
150149
}
151150

0 commit comments

Comments
 (0)