Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SigV4 interceptor to use latest aws sdk version #2884

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions gremlin-console/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ limitations under the License.
<artifactId>tinkergraph-gremlin</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.12.6.1</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
Expand Down
34 changes: 21 additions & 13 deletions gremlin-driver/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ limitations under the License.
</parent>
<artifactId>gremlin-driver</artifactId>
<name>Apache TinkerPop :: Gremlin Driver</name>
<properties>
<awssdk.version>2.29.3</awssdk.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.tinkerpop</groupId>
Expand All @@ -51,19 +54,24 @@ limitations under the License.
<optional>true</optional>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-core</artifactId>
<version>1.12.720</version>
<exclusions>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
<exclusion>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
</exclusion>
</exclusions>
<groupId>software.amazon.awssdk</groupId>
<artifactId>http-auth-aws</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>auth</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>utils</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>protocol-core</artifactId>
<version>${awssdk.version}</version>
</dependency>
<!-- TinkerGraph is an optional dependency that is only required if doing deserialization of Graph instances -->
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
*/
package org.apache.tinkerpop.gremlin.driver.auth;

import com.amazonaws.auth.AWSCredentialsProvider;
import org.apache.tinkerpop.gremlin.driver.RequestInterceptor;
import org.apache.tinkerpop.gremlin.driver.Settings;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;

public interface Auth extends RequestInterceptor {
String AUTH_BASIC = "basic";
Expand All @@ -34,7 +34,7 @@ static Auth sigv4(final String regionName, final String serviceName) {
return new Sigv4(regionName, serviceName);
}

static Auth sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider, final String serviceName) {
static Auth sigv4(final String regionName, final AwsCredentialsProvider awsCredentialsProvider, final String serviceName) {
return new Sigv4(regionName, awsCredentialsProvider, serviceName);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,126 +18,150 @@
*/
package org.apache.tinkerpop.gremlin.driver.auth;

import com.amazonaws.DefaultRequest;
import com.amazonaws.SignableRequest;
import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.http.HttpMethodName;
import com.amazonaws.util.SdkHttpUtils;
import com.amazonaws.util.StringUtils;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import org.apache.http.entity.StringEntity;
import org.apache.tinkerpop.gremlin.driver.HttpRequest;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import static com.amazonaws.auth.internal.SignerConstants.AUTHORIZATION;
import static com.amazonaws.auth.internal.SignerConstants.HOST;
import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_DATE;
import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_SECURITY_TOKEN;
import java.util.Set;
import org.apache.tinkerpop.gremlin.driver.HttpRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.http.auth.spi.signer.SignedRequest;
import software.amazon.awssdk.utils.http.SdkHttpUtils;

import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.AUTHORIZATION;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.HOST;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_CONTENT_SHA256;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_DATE;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_SECURITY_TOKEN;

/**
* A {@link org.apache.tinkerpop.gremlin.driver.RequestInterceptor} that provides headers required for SigV4. Because
* the signing process requires final header and body data, this interceptor should almost always be last.
*/
public class Sigv4 implements Auth {
private final AWSCredentialsProvider awsCredentialsProvider;
private final AWS4Signer aws4Signer;
private static final Logger logger = LoggerFactory.getLogger(Sigv4.class);
private final AwsCredentialsProvider awsCredentialsProvider;
private final AwsV4HttpSigner aws4Signer;
private final String serviceName;
private final String regionName;

public Sigv4(final String regionName, final String serviceName) {
this(regionName, new DefaultAWSCredentialsProviderChain(), serviceName);
this(regionName, DefaultCredentialsProvider.create(), serviceName);
}

public Sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider, final String serviceName) {
public Sigv4(final String regionName, final AwsCredentialsProvider awsCredentialsProvider, final String serviceName) {
this.awsCredentialsProvider = awsCredentialsProvider;

aws4Signer = new AWS4Signer();
aws4Signer.setRegionName(regionName);
aws4Signer.setServiceName(serviceName);
aws4Signer = AwsV4HttpSigner.create();
this.regionName = regionName;
this.serviceName = serviceName;
}

@Override
public HttpRequest apply(final HttpRequest httpRequest) {
try {
final ContentStreamProvider content = toContentStream(httpRequest);
// Convert Http request into an AWS SDK signable request
final SignableRequest<?> awsSignableRequest = toSignableRequest(httpRequest);
final SdkHttpRequest awsSignableRequest = toSignableRequest(httpRequest);
final AwsCredentials credentials = awsCredentialsProvider.resolveCredentials();

// Sign the AWS SDK signable request (which internally adds some HTTP headers)
final AWSCredentials credentials = awsCredentialsProvider.getCredentials();
aws4Signer.sign(awsSignableRequest, credentials);

// extract session token if temporary credentials are provided
String sessionToken = "";
if ((credentials instanceof BasicSessionCredentials)) {
sessionToken = ((BasicSessionCredentials) credentials).getSessionToken();
}
final SignedRequest signed = aws4Signer.sign(r -> r.identity(credentials)
.request(awsSignableRequest)
.payload(content)
.putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, this.serviceName)
.putProperty(AwsV4HttpSigner.REGION_NAME, this.regionName));

final Map<String, String> headers = httpRequest.headers();
headers.remove(HttpRequest.Headers.HOST);
headers.put(HOST, awsSignableRequest.getHeaders().get(HOST));
headers.put(X_AMZ_DATE, awsSignableRequest.getHeaders().get(X_AMZ_DATE));
headers.put(AUTHORIZATION, awsSignableRequest.getHeaders().get(AUTHORIZATION));

if (!sessionToken.isEmpty()) {
headers.put(X_AMZ_SECURITY_TOKEN, sessionToken);
}
setSignedHeaders(headers, signed);
setSessionToken(headers, credentials);
} catch (final Exception ex) {
logger.error("Error signing HTTP request: {}", ex.getMessage(), ex);
throw new AuthenticationException(ex);
}
return httpRequest;
}

private SignableRequest<?> toSignableRequest(final HttpRequest request) throws IOException {
private void setSessionToken(final Map<String, String> headers, final AwsCredentials credentials) {
// extract session token if temporary credentials are provided
if ((credentials instanceof AwsSessionCredentials)) {
final String sessionToken = ((AwsSessionCredentials) credentials).sessionToken();
if (sessionToken != null && !sessionToken.isEmpty()) {
headers.put(X_AMZ_SECURITY_TOKEN, sessionToken);
}
}
}

private void setSignedHeaders(final Map<String, String> headers, final SignedRequest signed) {
headers.remove(HttpRequest.Headers.HOST);
headers.put(HOST, signed.request().host());
final Map<String, List<String>> signedHeaders = signed.request().headers();
headers.put(X_AMZ_DATE, getSingleHeaderValue(signedHeaders, X_AMZ_DATE));
headers.put(AUTHORIZATION, getSingleHeaderValue(signedHeaders, AUTHORIZATION));
headers.put(X_AMZ_CONTENT_SHA256, getSingleHeaderValue(signedHeaders, X_AMZ_CONTENT_SHA256));
}

private String getSingleHeaderValue(final Map<String, List<String>> headers, final String headerName) {
final Set<String> headerValues = new HashSet<>(headers.containsKey(headerName) ? headers.get(headerName) : Collections.emptySet());
if (headerValues.size() != 1) {
throw new IllegalArgumentException(String.format("Expected 1 header %s but found %d", headerName, headerValues.size()));
}
return headerValues.iterator().next();
}

private ContentStreamProvider toContentStream(final HttpRequest httpRequest) {
// carry over the entity (or an empty entity, if no entity is provided)
if (!(httpRequest.getBody() instanceof byte[])) {
throw new IllegalArgumentException("Expected byte[] in HttpRequest body but got " + httpRequest.getBody().getClass());
}
final byte[] body = (byte[]) httpRequest.getBody();
return (body.length != 0) ? ContentStreamProvider.fromByteArray(body) : ContentStreamProvider.fromUtf8String("");
}

private SdkHttpRequest toSignableRequest(final HttpRequest request) {

// make sure the request contains the minimal required set of information
checkNotNull(request.getUri(), "The request URI must not be null");
checkNotNull(request.getMethod(), "The request method must not be null");

// convert the headers to the internal API format
final Map<String, String> headers = request.headers();
final Map<String, String> headersInternal = new HashMap<>();
final Map<String, List<String>> headersInternal = new HashMap<>();

// we don't want to add the Host header as the Signer always adds the host header.
for (Map.Entry<String, String> header : headers.entrySet()) {
// Skip adding the Host header as the signing process will add one.
if (!header.getKey().equalsIgnoreCase(HttpRequest.Headers.HOST)) {
headersInternal.put(header.getKey(), header.getValue());
headersInternal.put(header.getKey(), Collections.singletonList(header.getValue()));
}
}

// convert the parameters to the internal API format
final URI uri = request.getUri();
final Map<String, List<String>> parametersInternal = extractParametersFromQueryString(uri.getQuery());

// carry over the entity (or an empty entity, if no entity is provided)
if (!(request.getBody() instanceof byte[])) {
throw new IllegalArgumentException("Expected byte[] in HttpRequest body but got " + request.getBody().getClass());
}

final byte[] body = (byte[]) request.getBody();
final InputStream content = (body.length != 0) ? new ByteArrayInputStream(body) : new StringEntity("").getContent();
final URI endpointUri = URI.create(uri.getScheme() + "://" + uri.getHost());

return convertToSignableRequest(
request.getMethod(),
endpointUri,
uri.getPath(),
headersInternal,
parametersInternal,
content);
// create the HTTP AWS SdkHttpRequest and carry over information
return SdkHttpRequest.builder()
.uri(endpointUri)
.encodedPath(uri.getPath())
.method(SdkHttpMethod.fromValue(request.getMethod()))
.headers(headersInternal)
.rawQueryParameters(parametersInternal)
.build();
}

private HashMap<String, List<String>> extractParametersFromQueryString(final String queryStr) {
Expand Down Expand Up @@ -177,26 +201,6 @@ private HashMap<String, List<String>> extractParametersFromQueryString(final Str
return parameters;
}

private SignableRequest<?> convertToSignableRequest(
final String httpMethodName,
final URI httpEndpointUri,
final String resourcePath,
final Map<String, String> httpHeaders,
final Map<String, List<String>> httpParameters,
final InputStream httpContent) {

// create the HTTP AWS SDK Signable Request and carry over information
final DefaultRequest<?> awsRequest = new DefaultRequest<>(aws4Signer.getServiceName());
awsRequest.setHttpMethod(HttpMethodName.fromValue(httpMethodName));
awsRequest.setEndpoint(httpEndpointUri);
awsRequest.setResourcePath(resourcePath);
awsRequest.setHeaders(httpHeaders);
awsRequest.setParameters(httpParameters);
awsRequest.setContent(httpContent);

return awsRequest;
}

private void checkNotNull(final Object obj, final String errMsg) {
if (obj == null) {
throw new IllegalArgumentException(errMsg);
Expand Down
Loading
Loading