Skip to content

Commit

Permalink
Support host-based routing using a custom HTTP adapter where it has a…
Browse files Browse the repository at this point in the history
…ccess to a local DNS. This adapter will be able to resolve hostnames to the IP addresses scouted by NMAP.

PiperOrigin-RevId: 550954701
Change-Id: I58f554aa56651c7ac7cbe03d3efa0f41e414a333
  • Loading branch information
nttran8 authored and copybara-github committed Jul 25, 2023
1 parent 573624d commit 6eb44e5
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 57 deletions.
86 changes: 86 additions & 0 deletions plugin_server/py/common/net/http/host_resolver_http_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Custom HTTP Adapter to handle host-based routing support for load balancers."""

import socket
from typing import Optional
from urllib import parse

import requests

from tsunami.plugin_server.py.common.net.http.http_header_fields import HttpHeaderFields


class HostResolverHttpAdapter(requests.adapters.HTTPAdapter):
"""Custom HTTP adapter for proper hostname resolution.
When load balancers are used, there is a chance that the hostname does not
resolve to the IP address of the vulnerable application. When the hostname
does not resolve to the given IP address, the IP address returned by NMAP is
prioritized and used in the "netloc" portion of the URL (see
parse.urlsplit()). This Adapter also adds the host header of the request
package that would have been otherwise omitted by default.
Attributes:
pool_connections: Number of connection pools to cache.
pool_max: Maximum number of connections to save in the pool.
"""

def __init__(self, pool_connections: int, pool_maxsize: int):
super().__init__(
pool_connections=pool_connections, pool_maxsize=pool_maxsize
)

def _add_host_header(
self, request: requests.PreparedRequest, hostname: str
) -> None:
"""Adds host:port as the host header."""
request.headers[HttpHeaderFields.HOST.value] = hostname

def _require_ipv6_brackets(self, ip: str) -> str:
"""Adds enclosing brackets if IPV6."""
try:
socket.inet_pton(socket.AF_INET6, ip)
return "[%s]" % ip
except OSError:
return ip

def _resolve(self,
hostname: str,
ip: Optional[str] = None,
port: Optional[int] = None) -> Optional[str]:
"""Use the hostname if it resolves to the ip, else use the ip address.
Args:
hostname: Hostname of the target network. This could be the domain name or
the IP address.
ip: Optional IP address of target network.
port: Optional port of target network.
Returns:
String of the resolved hostname.
"""
if hostname == ip or not ip or ip in socket.getaddrinfo(hostname, port):
return hostname
return ip

def send(
self,
request: requests.PreparedRequest,
ip: Optional[str] = None,
**kwargs
) -> requests.Response:
result = parse.urlparse(request.url)
self._add_host_header(request, result.netloc)
# use local dns
resolved_host = self._resolve(result.hostname, ip=ip, port=result.port)
if resolved_host != result.hostname:
resolved_host = self._require_ipv6_brackets(resolved_host)
netloc = result.netloc.lower().replace(result.hostname, resolved_host)
request.url = parse.urlunparse((
result.scheme,
netloc,
result.path,
result.params,
result.query,
result.fragment,
))
return super().send(request, **kwargs)
144 changes: 144 additions & 0 deletions plugin_server/py/common/net/http/host_resolver_http_adapter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Tests for google3.third_party.java_src.tsunami.plugin_server.py.common.net.requests_http_client."""

from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import requests

from tsunami.plugin_server.py.common.net.http.host_resolver_http_adapter import HostResolverHttpAdapter
from tsunami.plugin_server.py.common.net.http.http_header_fields import HttpHeaderFields
from tsunami.plugin_server.py.common.net.http.http_method import HttpMethod


class HostResolverHttpAdapterTest(parameterized.TestCase):

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.custom_adapter = HostResolverHttpAdapter(5, 10)

def setUp(self):
super().setUp()
self.addCleanup(mock.patch.stopall)
# Mock of requests's HTTPAdapter
response = requests.Response()
response.status_code = 200
mock.patch.object(
requests.adapters.HTTPAdapter,
'send',
return_value=response,
).start()
# Mock hostname lookup
self.mock_getaddrinfo = mock.patch('socket.getaddrinfo').start()

@parameterized.named_parameters(
('with_hostname', 'vuln-app.com'),
('with_ipv4', '199.21.82.88'),
(
'with_ipv6',
'[2001:0db8:85a3:0000:0000:8a2e:0370:7334]',
),
)
def test_send_dispatches_with_host_header(self, host):
url = 'http://{}:8080/send'.format(host)
request = self._prepare_request(url)

self.custom_adapter.send(request)

requests.adapters.HTTPAdapter.send.assert_called_with(request)
self.assertEqual(
request.headers.get(HttpHeaderFields.HOST.value), '{}:8080'.format(host)
)

def test_send_without_target_ip_dispatches_default_hostname(self):
url = 'http://vuln-app.com:8080/send'
request = self._prepare_request(url)

self.custom_adapter.send(request)

requests.adapters.HTTPAdapter.send.assert_called_with(request)
self.assertEqual(request.url, url)

def test_send_when_hostname_resolves_to_ip_uses_default_hostname(self):
url = 'http://vuln-app.com:8080/send'
ip = '199.21.82.88'
request = self._prepare_request(url)

self.mock_getaddrinfo.return_value = [ip]
self.custom_adapter.send(request, ip=ip)

requests.adapters.HTTPAdapter.send.assert_called_once_with(request)
self.assertEqual(
request.headers.get(HttpHeaderFields.HOST.value), 'vuln-app.com:8080'
)
self.assertEqual(request.url, url)

def test_send_when_hostname_is_the_ip_uses_default_hostname(self):
ip = '2001:0db8:85a3:0000:0000:8a2e:0370:7334'
url = 'http://[{}]:8080/send'.format(ip)
request = self._prepare_request(url)

self.custom_adapter.send(request, ip=ip)

requests.adapters.HTTPAdapter.send.assert_called_once_with(request)
self.assertEqual(
request.headers.get(HttpHeaderFields.HOST.value),
'[{}]:8080'.format(ip))
self.assertEqual(request.url, url)

def test_send_when_hostname_is_case_insensitive(self):
url = 'http://vuln-APP.com:8080/send'
ip = '199.21.82.88'
request = self._prepare_request(url)

self.custom_adapter.send(request, ip=ip)

requests.adapters.HTTPAdapter.send.assert_called_once_with(request)
self.assertEqual(
request.headers.get(HttpHeaderFields.HOST.value), 'vuln-APP.com:8080'
)
self.assertEqual(request.url, 'http://199.21.82.88:8080/send')

def test_send_when_hostname_does_not_resolve_to_ipv4_uses_ipv4(self):
url = 'http://vuln-app.com:8080/send'
ip = '199.21.82.88'
request = self._prepare_request(url)

self.mock_getaddrinfo.return_value = ['1.1.1.1']
self.custom_adapter.send(request, ip=ip)

requests.adapters.HTTPAdapter.send.assert_called_once_with(request)
self.assertEqual(
request.headers.get(HttpHeaderFields.HOST.value), 'vuln-app.com:8080'
)
self.assertEqual(request.url, 'http://199.21.82.88:8080/send')

def test_send_when_hostname_does_not_resolve_to_ipv6_uses_ipv6(self):
url = 'http://vuln-app.com:8080/send'
ip = '2001:0db8:85a3:0000:0000:8a2e:0370:7334'
request = self._prepare_request(url)

self.mock_getaddrinfo.return_value = ['1.1.1.1']
self.custom_adapter.send(request, ip=ip)

requests.adapters.HTTPAdapter.send.assert_called_once_with(request)
self.assertEqual(
request.headers.get(HttpHeaderFields.HOST.value), 'vuln-app.com:8080'
)
self.assertEqual(
request.url,
'http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080/send',
)

def _prepare_request(self, url):
request = requests.Request(
method=HttpMethod.GET, url=url, data=b'HTML content'
)
request = request.prepare()
request.url = url
return request


if __name__ == '__main__':
absltest.main()
42 changes: 19 additions & 23 deletions plugin_server/py/common/net/http/requests_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import concurrent.futures
import functools
from typing import Optional

from absl import logging
import requests
from tsunami.plugin_server.py.common.data import network_service_utils

from tsunami.plugin_server.py.common.data.network_service_utils import NetworkService
from tsunami.plugin_server.py.common.net.http.host_resolver_http_adapter import HostResolverHttpAdapter
from tsunami.plugin_server.py.common.net.http.http_client import Builder
from tsunami.plugin_server.py.common.net.http.http_client import HttpClient

from tsunami.plugin_server.py.common.net.http.http_header_fields import HttpHeaderFields
from tsunami.plugin_server.py.common.net.http.http_headers import HttpHeaders
from tsunami.plugin_server.py.common.net.http.http_request import HttpRequest
Expand Down Expand Up @@ -44,14 +45,14 @@ class RequestsHttpClient(HttpClient):

def __init__(
self,
session: requests.Session,
allow_redirects: Optional[bool],
log_id: Optional[str],
max_workers: Optional[int],
session: Optional[requests.Session],
timeout_sec: Optional[float],
verify_ssl: Optional[bool],
):
self.session = session or requests.Session()
self.session = session
self.allow_redirects = allow_redirects
self.log_id = log_id
self.max_workers = max_workers
Expand All @@ -70,11 +71,11 @@ def send(self,
req = self._prepare_request(http_request)
resp = self.session.send(
request=req,
# TODO(b/288615444) handle host-based routing support for load balancers
# proxies=self._get_proxies(network_service),
ip=self._get_ip(network_service),
verify=self.verify_ssl,
timeout=self.timeout_sec,
allow_redirects=self.allow_redirects)
allow_redirects=self.allow_redirects,
)
return self._parse_response(resp)

def send_async(
Expand All @@ -86,7 +87,7 @@ def send_async(
http_request.method, http_request.url)
req = self._prepare_request(http_request)
loop = asyncio.get_event_loop()
future = asyncio.ensure_future(self._prepare_future(network_service, req))
future = asyncio.ensure_future(self._prepare_future(req, network_service))
loop.run_until_complete(future)
res = future.result()
return self._parse_response(res)
Expand All @@ -102,17 +103,6 @@ def _build_response_headers(self, headers: dict[str, str]) -> HttpHeaders:
headers_builder.add_header(field, headers[field])
return headers_builder.build()

def _get_proxies(
self, network_service: Optional[NetworkService] = None
) -> dict[str, str]:
if not network_service:
return {}
uri = network_service_utils.build_web_uri_authority(network_service)
return {
'http': uri,
'https': uri,
}

def _parse_response(self, res: requests.Response) -> HttpResponse:
response_header = self._build_response_headers(res.headers)
status = HttpStatus.from_code(res.status_code)
Expand All @@ -127,8 +117,8 @@ def _parse_response(self, res: requests.Response) -> HttpResponse:

async def _prepare_future(
self,
network_service: Optional[NetworkService],
req: requests.PreparedRequest,
network_service: Optional[NetworkService],
):
"""Prepare async request to include configuration."""
loop = asyncio.get_event_loop()
Expand All @@ -137,7 +127,7 @@ async def _prepare_future(
functools.partial(
self.session.send,
request=req,
proxies=self._get_proxies(network_service),
ip=self._get_ip(network_service),
verify=self.verify_ssl,
timeout=self.timeout_sec,
allow_redirects=self.allow_redirects,
Expand Down Expand Up @@ -180,6 +170,11 @@ def _serialize_request_headers(self, headers: HttpHeaders) -> dict[str, str]:
HttpHeaderFields.USER_AGENT.value] = self.TSUNAMI_USER_AGENT
return serialized_headers

def _get_ip(self, network_service: Optional[NetworkService]) -> Optional[str]:
if not network_service:
return None
return network_service.network_endpoint.ip_address.address


class RequestsHttpClientBuilder(Builder):
"""Base builder for implementations of RequestsHttpClient."""
Expand Down Expand Up @@ -233,8 +228,9 @@ def set_verify_ssl(self, verify_ssl: bool) -> 'RequestsHttpClientBuilder':

def build(self) -> RequestsHttpClient:
session = requests.Session()
adapter = requests.adapters.HTTPAdapter(
pool_maxsize=self.pool_maxsize, pool_connections=self.pool_connections
adapter = HostResolverHttpAdapter(
pool_maxsize=self.pool_maxsize,
pool_connections=self.pool_connections,
)
session.mount('http://', adapter)
session.mount('https://', adapter)
Expand Down
Loading

0 comments on commit 6eb44e5

Please sign in to comment.