Skip to content

Commit 263357d

Browse files
committed
Fixing regex bypass issue
Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent 3950a87 commit 263357d

File tree

8 files changed

+381
-294
lines changed

8 files changed

+381
-294
lines changed

common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java

Lines changed: 24 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -5,96 +5,39 @@
55

66
package org.opensearch.ml.common.httpclient;
77

8-
import java.net.Inet4Address;
9-
import java.net.InetAddress;
10-
import java.net.UnknownHostException;
8+
import static org.opensearch.secure_sm.AccessController.doPrivileged;
9+
1110
import java.time.Duration;
12-
import java.util.Arrays;
13-
import java.util.Locale;
1411
import java.util.concurrent.atomic.AtomicBoolean;
1512

16-
import org.opensearch.common.util.concurrent.ThreadContextAccess;
17-
1813
import lombok.extern.log4j.Log4j2;
1914
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
2015
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
2116

2217
@Log4j2
2318
public class MLHttpClientFactory {
2419

25-
public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) {
26-
return ThreadContextAccess
27-
.doPrivileged(
28-
() -> NettyNioAsyncHttpClient
29-
.builder()
30-
.connectionTimeout(connectionTimeout)
31-
.readTimeout(readTimeout)
32-
.maxConcurrency(maxConnections)
33-
.build()
34-
);
35-
}
36-
37-
/**
38-
* Validate the input parameters, such as protocol, host and port.
39-
* @param protocol The protocol supported in remote inference, currently only http and https are supported.
40-
* @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost.
41-
* @param port The port number of the remote inference server, port number must be in range [0, 65536].
42-
* @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536].
43-
* @throws UnknownHostException Allow to use private IP or not.
44-
*/
45-
public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
46-
throws UnknownHostException {
47-
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
48-
log.error("Remote inference protocol is not http or https: {}", protocol);
49-
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
50-
}
51-
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
52-
if (port == -1) {
53-
if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
54-
port = 80;
55-
} else {
56-
port = 443;
57-
}
58-
}
59-
if (port < 0 || port > 65536) {
60-
log.error("Remote inference port out of range: {}", port);
61-
throw new IllegalArgumentException("Port out of range: " + port);
62-
}
63-
validateIp(host, connectorPrivateIpEnabled);
64-
}
65-
66-
private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
67-
InetAddress[] addresses = InetAddress.getAllByName(hostName);
68-
if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
69-
log.error("Remote inference host name has private ip address: {}", hostName);
70-
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
71-
}
72-
}
73-
74-
private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
75-
for (InetAddress ip : ipAddress) {
76-
if (ip instanceof Inet4Address) {
77-
byte[] bytes = ip.getAddress();
78-
if (bytes.length != 4) {
79-
return true;
80-
} else {
81-
if (isPrivateIPv4(bytes)) {
82-
return true;
83-
}
84-
}
85-
}
86-
}
87-
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
88-
}
89-
90-
private static boolean isPrivateIPv4(byte[] bytes) {
91-
int first = bytes[0] & 0xff;
92-
int second = bytes[1] & 0xff;
93-
94-
// 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
95-
return (first == 10)
96-
|| (first == 172 && second >= 16 && second <= 31)
97-
|| (first == 192 && second == 168)
98-
|| (first == 169 && second == 254);
20+
public static SdkAsyncHttpClient getAsyncHttpClient(
21+
Duration connectionTimeout,
22+
Duration readTimeout,
23+
int maxConnections,
24+
AtomicBoolean connectorPrivateIpEnabled
25+
) {
26+
return doPrivileged(() -> {
27+
log
28+
.debug(
29+
"Creating MLHttpClient with connectionTimeout: {}, readTimeout: {}, maxConnections: {}",
30+
connectionTimeout,
31+
readTimeout,
32+
maxConnections
33+
);
34+
SdkAsyncHttpClient delegate = NettyNioAsyncHttpClient
35+
.builder()
36+
.connectionTimeout(connectionTimeout)
37+
.readTimeout(readTimeout)
38+
.maxConcurrency(maxConnections)
39+
.build();
40+
return new ValidatingHttpClient(delegate, connectorPrivateIpEnabled);
41+
});
9942
}
10043
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.httpclient;
7+
8+
import java.net.Inet4Address;
9+
import java.net.InetAddress;
10+
import java.net.UnknownHostException;
11+
import java.util.Arrays;
12+
import java.util.Locale;
13+
import java.util.concurrent.CompletableFuture;
14+
import java.util.concurrent.atomic.AtomicBoolean;
15+
16+
import lombok.extern.log4j.Log4j2;
17+
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
18+
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
19+
20+
@Log4j2
21+
public class ValidatingHttpClient implements SdkAsyncHttpClient {
22+
private final SdkAsyncHttpClient delegate;
23+
private final AtomicBoolean connectorPrivateIpEnabled;
24+
25+
protected ValidatingHttpClient(SdkAsyncHttpClient client, AtomicBoolean connectorPrivateIpEnabled) {
26+
this.delegate = client;
27+
this.connectorPrivateIpEnabled = connectorPrivateIpEnabled;
28+
}
29+
30+
@Override
31+
public CompletableFuture<Void> execute(AsyncExecuteRequest request) {
32+
String protocol = request.request().protocol();
33+
String host = request.request().host();
34+
int port = request.request().port();
35+
try {
36+
validate(protocol, host, port, connectorPrivateIpEnabled);
37+
return delegate.execute(request);
38+
} catch (Exception e) {
39+
log.error("Failed to validate request!", e);
40+
throw new IllegalArgumentException(e.getMessage(), e);
41+
}
42+
}
43+
44+
@Override
45+
public void close() {
46+
delegate.close();
47+
}
48+
49+
/**
50+
* Validate the input parameters, such as protocol, host and port.
51+
* @param protocol The protocol supported in remote inference, currently only http and https are supported.
52+
* @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost.
53+
* @param port The port number of the remote inference server, port number must be in range [0, 65536].
54+
* @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536].
55+
* @throws UnknownHostException Allow to use private IP or not.
56+
*/
57+
public void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
58+
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
59+
log.error("Remote inference protocol is not http or https: {}", protocol);
60+
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
61+
}
62+
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
63+
if (port == -1) {
64+
if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
65+
port = 80;
66+
} else {
67+
port = 443;
68+
}
69+
}
70+
if (port < 0 || port > 65536) {
71+
log.error("Remote inference port out of range: {}", port);
72+
throw new IllegalArgumentException("Port out of range: " + port);
73+
}
74+
validateIp(host, connectorPrivateIpEnabled);
75+
}
76+
77+
private void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
78+
InetAddress[] addresses = InetAddress.getAllByName(hostName);
79+
if (connectorPrivateIpEnabled != null && !connectorPrivateIpEnabled.get() && hasPrivateIpAddress(addresses)) {
80+
log.error("Remote inference host name has private ip address: {}", hostName);
81+
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
82+
}
83+
}
84+
85+
private boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
86+
for (InetAddress ip : ipAddress) {
87+
if (ip instanceof Inet4Address) {
88+
byte[] bytes = ip.getAddress();
89+
if (bytes.length != 4) {
90+
return true;
91+
} else {
92+
if (isPrivateIPv4(bytes)) {
93+
return true;
94+
}
95+
}
96+
}
97+
}
98+
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
99+
}
100+
101+
private boolean isPrivateIPv4(byte[] bytes) {
102+
int first = bytes[0] & 0xff;
103+
int second = bytes[1] & 0xff;
104+
105+
// 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
106+
return (first == 10)
107+
|| (first == 172 && second >= 16 && second <= 31)
108+
|| (first == 192 && second == 168)
109+
|| (first == 169 && second == 254);
110+
}
111+
}

0 commit comments

Comments
 (0)