Skip to content

Commit

Permalink
Update SigV4 interceptor to use latest aws sdk version (#2884)
Browse files Browse the repository at this point in the history
  • Loading branch information
kenhuuu authored Nov 4, 2024
2 parents d0533ae + cfd6889 commit 6662336
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 115 deletions.
6 changes: 0 additions & 6 deletions gremlin-console/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ limitations under the License.
<artifactId>tinkergraph-gremlin</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.12.6.1</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
Expand Down
34 changes: 21 additions & 13 deletions gremlin-driver/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ limitations under the License.
</parent>
<artifactId>gremlin-driver</artifactId>
<name>Apache TinkerPop :: Gremlin Driver</name>
<properties>
<awssdk.version>2.29.3</awssdk.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.tinkerpop</groupId>
Expand All @@ -51,19 +54,24 @@ limitations under the License.
<optional>true</optional>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-core</artifactId>
<version>1.12.720</version>
<exclusions>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
<exclusion>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
</exclusion>
</exclusions>
<groupId>software.amazon.awssdk</groupId>
<artifactId>http-auth-aws</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>auth</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>utils</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>protocol-core</artifactId>
<version>${awssdk.version}</version>
</dependency>
<!-- TinkerGraph is an optional dependency that is only required if doing deserialization of Graph instances -->
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
*/
package org.apache.tinkerpop.gremlin.driver.auth;

import com.amazonaws.auth.AWSCredentialsProvider;
import org.apache.tinkerpop.gremlin.driver.RequestInterceptor;
import org.apache.tinkerpop.gremlin.driver.Settings;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;

public interface Auth extends RequestInterceptor {
String AUTH_BASIC = "basic";
Expand All @@ -34,7 +34,7 @@ static Auth sigv4(final String regionName, final String serviceName) {
return new Sigv4(regionName, serviceName);
}

static Auth sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider, final String serviceName) {
static Auth sigv4(final String regionName, final AwsCredentialsProvider awsCredentialsProvider, final String serviceName) {
return new Sigv4(regionName, awsCredentialsProvider, serviceName);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,126 +18,150 @@
*/
package org.apache.tinkerpop.gremlin.driver.auth;

import com.amazonaws.DefaultRequest;
import com.amazonaws.SignableRequest;
import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.http.HttpMethodName;
import com.amazonaws.util.SdkHttpUtils;
import com.amazonaws.util.StringUtils;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import org.apache.http.entity.StringEntity;
import org.apache.tinkerpop.gremlin.driver.HttpRequest;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import static com.amazonaws.auth.internal.SignerConstants.AUTHORIZATION;
import static com.amazonaws.auth.internal.SignerConstants.HOST;
import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_DATE;
import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_SECURITY_TOKEN;
import java.util.Set;
import org.apache.tinkerpop.gremlin.driver.HttpRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.http.auth.spi.signer.SignedRequest;
import software.amazon.awssdk.utils.http.SdkHttpUtils;

import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.AUTHORIZATION;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.HOST;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_CONTENT_SHA256;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_DATE;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_SECURITY_TOKEN;

/**
* A {@link org.apache.tinkerpop.gremlin.driver.RequestInterceptor} that provides headers required for SigV4. Because
* the signing process requires final header and body data, this interceptor should almost always be last.
*/
public class Sigv4 implements Auth {
private final AWSCredentialsProvider awsCredentialsProvider;
private final AWS4Signer aws4Signer;
private static final Logger logger = LoggerFactory.getLogger(Sigv4.class);
private final AwsCredentialsProvider awsCredentialsProvider;
private final AwsV4HttpSigner aws4Signer;
private final String serviceName;
private final String regionName;

public Sigv4(final String regionName, final String serviceName) {
this(regionName, new DefaultAWSCredentialsProviderChain(), serviceName);
this(regionName, DefaultCredentialsProvider.create(), serviceName);
}

public Sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider, final String serviceName) {
public Sigv4(final String regionName, final AwsCredentialsProvider awsCredentialsProvider, final String serviceName) {
this.awsCredentialsProvider = awsCredentialsProvider;

aws4Signer = new AWS4Signer();
aws4Signer.setRegionName(regionName);
aws4Signer.setServiceName(serviceName);
aws4Signer = AwsV4HttpSigner.create();
this.regionName = regionName;
this.serviceName = serviceName;
}

@Override
public HttpRequest apply(final HttpRequest httpRequest) {
try {
final ContentStreamProvider content = toContentStream(httpRequest);
// Convert Http request into an AWS SDK signable request
final SignableRequest<?> awsSignableRequest = toSignableRequest(httpRequest);
final SdkHttpRequest awsSignableRequest = toSignableRequest(httpRequest);
final AwsCredentials credentials = awsCredentialsProvider.resolveCredentials();

// Sign the AWS SDK signable request (which internally adds some HTTP headers)
final AWSCredentials credentials = awsCredentialsProvider.getCredentials();
aws4Signer.sign(awsSignableRequest, credentials);

// extract session token if temporary credentials are provided
String sessionToken = "";
if ((credentials instanceof BasicSessionCredentials)) {
sessionToken = ((BasicSessionCredentials) credentials).getSessionToken();
}
final SignedRequest signed = aws4Signer.sign(r -> r.identity(credentials)
.request(awsSignableRequest)
.payload(content)
.putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, this.serviceName)
.putProperty(AwsV4HttpSigner.REGION_NAME, this.regionName));

final Map<String, String> headers = httpRequest.headers();
headers.remove(HttpRequest.Headers.HOST);
headers.put(HOST, awsSignableRequest.getHeaders().get(HOST));
headers.put(X_AMZ_DATE, awsSignableRequest.getHeaders().get(X_AMZ_DATE));
headers.put(AUTHORIZATION, awsSignableRequest.getHeaders().get(AUTHORIZATION));

if (!sessionToken.isEmpty()) {
headers.put(X_AMZ_SECURITY_TOKEN, sessionToken);
}
setSignedHeaders(headers, signed);
setSessionToken(headers, credentials);
} catch (final Exception ex) {
logger.error("Error signing HTTP request: {}", ex.getMessage(), ex);
throw new AuthenticationException(ex);
}
return httpRequest;
}

private SignableRequest<?> toSignableRequest(final HttpRequest request) throws IOException {
private void setSessionToken(final Map<String, String> headers, final AwsCredentials credentials) {
// extract session token if temporary credentials are provided
if ((credentials instanceof AwsSessionCredentials)) {
final String sessionToken = ((AwsSessionCredentials) credentials).sessionToken();
if (sessionToken != null && !sessionToken.isEmpty()) {
headers.put(X_AMZ_SECURITY_TOKEN, sessionToken);
}
}
}

private void setSignedHeaders(final Map<String, String> headers, final SignedRequest signed) {
headers.remove(HttpRequest.Headers.HOST);
headers.put(HOST, signed.request().host());
final Map<String, List<String>> signedHeaders = signed.request().headers();
headers.put(X_AMZ_DATE, getSingleHeaderValue(signedHeaders, X_AMZ_DATE));
headers.put(AUTHORIZATION, getSingleHeaderValue(signedHeaders, AUTHORIZATION));
headers.put(X_AMZ_CONTENT_SHA256, getSingleHeaderValue(signedHeaders, X_AMZ_CONTENT_SHA256));
}

private String getSingleHeaderValue(final Map<String, List<String>> headers, final String headerName) {
final Set<String> headerValues = new HashSet<>(headers.containsKey(headerName) ? headers.get(headerName) : Collections.emptySet());
if (headerValues.size() != 1) {
throw new IllegalArgumentException(String.format("Expected 1 header %s but found %d", headerName, headerValues.size()));
}
return headerValues.iterator().next();
}

private ContentStreamProvider toContentStream(final HttpRequest httpRequest) {
// carry over the entity (or an empty entity, if no entity is provided)
if (!(httpRequest.getBody() instanceof byte[])) {
throw new IllegalArgumentException("Expected byte[] in HttpRequest body but got " + httpRequest.getBody().getClass());
}
final byte[] body = (byte[]) httpRequest.getBody();
return (body.length != 0) ? ContentStreamProvider.fromByteArray(body) : ContentStreamProvider.fromUtf8String("");
}

private SdkHttpRequest toSignableRequest(final HttpRequest request) {

// make sure the request contains the minimal required set of information
checkNotNull(request.getUri(), "The request URI must not be null");
checkNotNull(request.getMethod(), "The request method must not be null");

// convert the headers to the internal API format
final Map<String, String> headers = request.headers();
final Map<String, String> headersInternal = new HashMap<>();
final Map<String, List<String>> headersInternal = new HashMap<>();

// we don't want to add the Host header as the Signer always adds the host header.
for (Map.Entry<String, String> header : headers.entrySet()) {
// Skip adding the Host header as the signing process will add one.
if (!header.getKey().equalsIgnoreCase(HttpRequest.Headers.HOST)) {
headersInternal.put(header.getKey(), header.getValue());
headersInternal.put(header.getKey(), Collections.singletonList(header.getValue()));
}
}

// convert the parameters to the internal API format
final URI uri = request.getUri();
final Map<String, List<String>> parametersInternal = extractParametersFromQueryString(uri.getQuery());

// carry over the entity (or an empty entity, if no entity is provided)
if (!(request.getBody() instanceof byte[])) {
throw new IllegalArgumentException("Expected byte[] in HttpRequest body but got " + request.getBody().getClass());
}

final byte[] body = (byte[]) request.getBody();
final InputStream content = (body.length != 0) ? new ByteArrayInputStream(body) : new StringEntity("").getContent();
final URI endpointUri = URI.create(uri.getScheme() + "://" + uri.getHost());

return convertToSignableRequest(
request.getMethod(),
endpointUri,
uri.getPath(),
headersInternal,
parametersInternal,
content);
// create the HTTP AWS SdkHttpRequest and carry over information
return SdkHttpRequest.builder()
.uri(endpointUri)
.encodedPath(uri.getPath())
.method(SdkHttpMethod.fromValue(request.getMethod()))
.headers(headersInternal)
.rawQueryParameters(parametersInternal)
.build();
}

private HashMap<String, List<String>> extractParametersFromQueryString(final String queryStr) {
Expand Down Expand Up @@ -177,26 +201,6 @@ private HashMap<String, List<String>> extractParametersFromQueryString(final Str
return parameters;
}

private SignableRequest<?> convertToSignableRequest(
final String httpMethodName,
final URI httpEndpointUri,
final String resourcePath,
final Map<String, String> httpHeaders,
final Map<String, List<String>> httpParameters,
final InputStream httpContent) {

// create the HTTP AWS SDK Signable Request and carry over information
final DefaultRequest<?> awsRequest = new DefaultRequest<>(aws4Signer.getServiceName());
awsRequest.setHttpMethod(HttpMethodName.fromValue(httpMethodName));
awsRequest.setEndpoint(httpEndpointUri);
awsRequest.setResourcePath(resourcePath);
awsRequest.setHeaders(httpHeaders);
awsRequest.setParameters(httpParameters);
awsRequest.setContent(httpContent);

return awsRequest;
}

private void checkNotNull(final Object obj, final String errMsg) {
if (obj == null) {
throw new IllegalArgumentException(errMsg);
Expand Down
Loading

0 comments on commit 6662336

Please sign in to comment.