|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.common.httpclient; |
7 | 7 |
|
8 | | -import java.net.Inet4Address; |
9 | | -import java.net.InetAddress; |
10 | | -import java.net.UnknownHostException; |
11 | | -import java.time.Duration; |
12 | | -import java.util.Arrays; |
13 | | -import java.util.Locale; |
14 | | -import java.util.concurrent.atomic.AtomicBoolean; |
| 8 | +import static org.opensearch.secure_sm.AccessController.doPrivileged; |
15 | 9 |
|
16 | | -import org.opensearch.common.util.concurrent.ThreadContextAccess; |
| 10 | +import java.time.Duration; |
17 | 11 |
|
18 | 12 | import lombok.extern.log4j.Log4j2; |
19 | 13 | import software.amazon.awssdk.http.async.SdkAsyncHttpClient; |
|
22 | 16 | @Log4j2 |
23 | 17 | public class MLHttpClientFactory { |
24 | 18 |
|
25 | | - public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) { |
26 | | - return ThreadContextAccess |
27 | | - .doPrivileged( |
28 | | - () -> NettyNioAsyncHttpClient |
29 | | - .builder() |
30 | | - .connectionTimeout(connectionTimeout) |
31 | | - .readTimeout(readTimeout) |
32 | | - .maxConcurrency(maxConnections) |
33 | | - .build() |
34 | | - ); |
35 | | - } |
36 | | - |
37 | | - /** |
38 | | - * Validate the input parameters, such as protocol, host and port. |
39 | | - * @param protocol The protocol supported in remote inference, currently only http and https are supported. |
40 | | - * @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost. |
41 | | - * @param port The port number of the remote inference server, port number must be in range [0, 65536]. |
42 | | - * @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536]. |
43 | | - * @throws UnknownHostException Allow to use private IP or not. |
44 | | - */ |
45 | | - public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled) |
46 | | - throws UnknownHostException { |
47 | | - if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) { |
48 | | - log.error("Remote inference protocol is not http or https: {}", protocol); |
49 | | - throw new IllegalArgumentException("Protocol is not http or https: " + protocol); |
50 | | - } |
51 | | - // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol. |
52 | | - if (port == -1) { |
53 | | - if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) { |
54 | | - port = 80; |
55 | | - } else { |
56 | | - port = 443; |
57 | | - } |
58 | | - } |
59 | | - if (port < 0 || port > 65536) { |
60 | | - log.error("Remote inference port out of range: {}", port); |
61 | | - throw new IllegalArgumentException("Port out of range: " + port); |
62 | | - } |
63 | | - validateIp(host, connectorPrivateIpEnabled); |
64 | | - } |
65 | | - |
66 | | - private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException { |
67 | | - InetAddress[] addresses = InetAddress.getAllByName(hostName); |
68 | | - if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) { |
69 | | - log.error("Remote inference host name has private ip address: {}", hostName); |
70 | | - throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName); |
71 | | - } |
72 | | - } |
73 | | - |
74 | | - private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) { |
75 | | - for (InetAddress ip : ipAddress) { |
76 | | - if (ip instanceof Inet4Address) { |
77 | | - byte[] bytes = ip.getAddress(); |
78 | | - if (bytes.length != 4) { |
79 | | - return true; |
80 | | - } else { |
81 | | - if (isPrivateIPv4(bytes)) { |
82 | | - return true; |
83 | | - } |
84 | | - } |
85 | | - } |
86 | | - } |
87 | | - return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress()); |
88 | | - } |
89 | | - |
90 | | - private static boolean isPrivateIPv4(byte[] bytes) { |
91 | | - int first = bytes[0] & 0xff; |
92 | | - int second = bytes[1] & 0xff; |
93 | | - |
94 | | - // 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x |
95 | | - return (first == 10) |
96 | | - || (first == 172 && second >= 16 && second <= 31) |
97 | | - || (first == 192 && second == 168) |
98 | | - || (first == 169 && second == 254); |
| 19 | + public static SdkAsyncHttpClient getAsyncHttpClient( |
| 20 | + Duration connectionTimeout, |
| 21 | + Duration readTimeout, |
| 22 | + int maxConnections, |
| 23 | + boolean connectorPrivateIpEnabled |
| 24 | + ) { |
| 25 | + return doPrivileged(() -> { |
| 26 | + log |
| 27 | + .debug( |
| 28 | + "Creating MLHttpClient with connectionTimeout: {}, readTimeout: {}, maxConnections: {}", |
| 29 | + connectionTimeout, |
| 30 | + readTimeout, |
| 31 | + maxConnections |
| 32 | + ); |
| 33 | + SdkAsyncHttpClient delegate = NettyNioAsyncHttpClient |
| 34 | + .builder() |
| 35 | + .connectionTimeout(connectionTimeout) |
| 36 | + .readTimeout(readTimeout) |
| 37 | + .maxConcurrency(maxConnections) |
| 38 | + .build(); |
| 39 | + return new MLValidatableAsyncHttpClient(delegate, connectorPrivateIpEnabled); |
| 40 | + }); |
99 | 41 | } |
100 | 42 | } |
0 commit comments