Skip to content

Commit

Permalink
Fix rare private ip address bypass SSRF issue (#1070)
Browse files Browse the repository at this point in the history
* Change connector access control creation allow empty list

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix rare private ip address bypass SSRF issue

Signed-off-by: zane-neo <zaniu@amazon.com>

---------

Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo authored Jul 11, 2023
1 parent 83bbdae commit e3cb2e3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.apache.http.protocol.HttpContext;
import org.apache.logging.log4j.util.Strings;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
Expand All @@ -42,10 +43,7 @@ public int resolve(HttpHost host) throws UnsupportedSchemeException {
}
});

builder.setDnsResolver(hostName -> {
validateIp(hostName);
return InetAddress.getAllByName(hostName);
});
builder.setDnsResolver(MLHttpClientFactory::validateIp);

builder.setRedirectStrategy(new LaxRedirectStrategy() {
@Override
Expand Down Expand Up @@ -79,15 +77,51 @@ protected static void validateSchemaAndPort(HttpHost host) {
}
}

protected static void validateIp(String hostName) throws UnknownHostException {
protected static InetAddress[] validateIp(String hostName) throws UnknownHostException {
InetAddress[] addresses = InetAddress.getAllByName(hostName);
if (hasPrivateIpAddress(addresses)) {
log.error("Remote inference host name has private ip address: " + hostName);
throw new IllegalArgumentException(hostName);
}
return addresses;
}

private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
for (InetAddress ip : ipAddress) {
if (ip instanceof Inet4Address) {
byte[] bytes = ip.getAddress();
if (bytes.length != 4) {
return true;
} else {
int firstOctets = bytes[0] & 0xff;
int firstInOctal = parseWithOctal(String.valueOf(firstOctets));
int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16);
if (firstInOctal == 127 || firstInHex == 127) {
return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1;
} else if (firstInOctal == 10 || firstInHex == 10) {
return true;
} else if (firstInOctal == 172 || firstInHex == 172) {
int secondOctets = bytes[1] & 0xff;
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
return (secondInOctal >= 16 && secondInOctal <= 32) || (secondInHex >= 16 && secondInHex <= 32);
} else if (firstInOctal == 192 || firstInHex == 192) {
int secondOctets = bytes[1] & 0xff;
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
return secondInOctal == 168 || secondInHex == 168;
}
}
}
}
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
}

private static int parseWithOctal(String input) {
try {
return Integer.parseInt(input, 8);
} catch (NumberFormatException e) {
return Integer.parseInt(input);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.net.InetAddress;
import java.net.UnknownHostException;

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;

public class MLHttpClientFactoryTests {

Expand Down Expand Up @@ -43,6 +45,45 @@ public void test_validateIp_privateIp_throwException() throws UnknownHostExcepti
MLHttpClientFactory.validateIp("localhost");
}

@Test
public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostException {
try {
MLHttpClientFactory.validateIp("0254.020.00.01");
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validateIp("172.1048577");
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validateIp("2886729729");
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validateIp("192.11010049");
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validateIp("3232300545");
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validateIp("0:0:0:0:0:ffff:127.0.0.1");
} catch (IllegalArgumentException e) {
assertNotNull(e);
}
}

@Test
public void test_validateSchemaAndPort_success() {
HttpHost httpHost = new HttpHost("api.openai.com", 8080, "https");
Expand Down

0 comments on commit e3cb2e3

Please sign in to comment.