Skip to content

AWS: Prevent excessive creation of auth sessions in S3V4RestSignerClient #13215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .palantir/revapi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,12 @@ acceptedBreaks:
- code: "java.field.removedWithConstant"
old: "field org.apache.iceberg.TableProperties.ROW_LINEAGE"
justification: "Removing deprecations for 1.10.0"
- code: "java.method.inheritedMovedToClass"
old: "method org.apache.iceberg.rest.auth.AuthSession org.apache.iceberg.rest.auth.AuthManager::tableSession(org.apache.iceberg.rest.RESTClient,\
\ java.util.Map<java.lang.String, java.lang.String>) @ org.apache.iceberg.rest.auth.OAuth2Manager"
new: "method org.apache.iceberg.rest.auth.OAuth2Util.AuthSession org.apache.iceberg.rest.auth.OAuth2Manager::tableSession(org.apache.iceberg.rest.RESTClient,\
\ java.util.Map<java.lang.String, java.lang.String>)"
justification: "overriding a default method is source- and binary-compatible"
- code: "java.method.removed"
old: "method boolean org.apache.iceberg.TableMetadata::rowLineageEnabled()"
justification: "Removing deprecations for 1.10.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import software.amazon.awssdk.services.s3.model.ObjectIdentifier;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
import software.amazon.awssdk.utils.IoUtils;

@Testcontainers
public class TestS3RestSigner {
Expand Down Expand Up @@ -122,8 +123,7 @@ public static void afterClass() throws Exception {
// there aren't other token refreshes being scheduled after every sign request and after
// TestS3RestSigner completes all tests, there should be only this single token in the queue
// that is scheduled for refresh
assertThat(validatingSigner.icebergSigner)
.extracting("authManager")
assertThat(S3V4RestSignerClient.authManager)
.extracting("refreshExecutor")
.asInstanceOf(type(ScheduledThreadPoolExecutor.class))
.satisfies(
Expand All @@ -143,6 +143,11 @@ public static void afterClass() throws Exception {
if (null != httpServer) {
httpServer.stop();
}

IoUtils.closeQuietlyV2(S3V4RestSignerClient.authManager, null);
IoUtils.closeQuietlyV2(S3V4RestSignerClient.httpClient, null);
S3V4RestSignerClient.authManager = null;
S3V4RestSignerClient.httpClient = null;
}

@BeforeEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
public abstract class S3V4RestSignerClient
extends AbstractAws4Signer<AwsS3V4SignerParams, Aws4PresignerParams> implements AutoCloseable {

static {
installShutdownHook();
}

private static final Logger LOG = LoggerFactory.getLogger(S3V4RestSignerClient.class);
public static final String S3_SIGNER_URI = "s3.signer.uri";
public static final String S3_SIGNER_ENDPOINT = "s3.signer.endpoint";
Expand All @@ -76,15 +80,25 @@ public abstract class S3V4RestSignerClient

private static final String SCOPE = "sign";

@SuppressWarnings("immutables:incompat")
private volatile AuthManager authManager;
@SuppressWarnings({"immutables:incompat", "VisibilityModifier"})
@VisibleForTesting
static volatile AuthManager authManager;

@SuppressWarnings({"immutables:incompat", "VisibilityModifier"})
@VisibleForTesting
static volatile RESTClient httpClient;

@SuppressWarnings("immutables:incompat")
private volatile AuthSession authSession;
@SuppressWarnings("ShutdownHook")
private static void installShutdownHook() {
Runtime.getRuntime()
.addShutdownHook(
new Thread(
() -> {
IoUtils.closeQuietlyV2(authManager, null);
IoUtils.closeQuietlyV2(httpClient, null);
},
"S3V4RestSignerClient-shutdown-hook"));
}

public abstract Map<String, String> properties();

Expand Down Expand Up @@ -135,6 +149,18 @@ boolean keepTokenRefreshed() {
OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT);
}

private AuthManager authManager() {
if (null == authManager) {
synchronized (S3V4RestSignerClient.class) {
if (null == authManager) {
authManager = AuthManagers.loadAuthManager("s3-signer", properties());
}
}
}

return authManager;
}

private RESTClient httpClient() {
if (null == httpClient) {
synchronized (S3V4RestSignerClient.class) {
Expand All @@ -153,32 +179,23 @@ private RESTClient httpClient() {

@VisibleForTesting
AuthSession authSession() {
if (null == authSession) {
synchronized (S3V4RestSignerClient.class) {
if (null == authSession) {
authManager = AuthManagers.loadAuthManager("s3-signer", properties());
ImmutableMap.Builder<String, String> properties =
ImmutableMap.<String, String>builder()
.putAll(properties())
.putAll(optionalOAuthParams())
.put(OAuth2Properties.OAUTH2_SERVER_URI, oauth2ServerUri())
.put(OAuth2Properties.TOKEN_REFRESH_ENABLED, String.valueOf(keepTokenRefreshed()))
.put(OAuth2Properties.SCOPE, SCOPE);
String token = token().get();
if (null != token) {
properties.put(OAuth2Properties.TOKEN, token);
}

if (credentialProvided()) {
properties.put(OAuth2Properties.CREDENTIAL, credential());
}

authSession = authManager.tableSession(httpClient(), properties.buildKeepingLast());
}
}
ImmutableMap.Builder<String, String> properties =
ImmutableMap.<String, String>builder()
.putAll(properties())
.putAll(optionalOAuthParams())
.put(OAuth2Properties.OAUTH2_SERVER_URI, oauth2ServerUri())
.put(OAuth2Properties.TOKEN_REFRESH_ENABLED, String.valueOf(keepTokenRefreshed()))
.put(OAuth2Properties.SCOPE, SCOPE);
String token = token().get();
if (null != token) {
properties.put(OAuth2Properties.TOKEN, token);
}

return authSession;
if (credentialProvided()) {
properties.put(OAuth2Properties.CREDENTIAL, credential());
}

return authManager().tableSession(httpClient(), properties.buildKeepingLast());
}

private boolean credentialProvided() {
Expand Down Expand Up @@ -283,10 +300,7 @@ public SdkHttpFullRequest sign(
}

@Override
public void close() throws Exception {
IoUtils.closeQuietlyV2(authSession, null);
IoUtils.closeQuietlyV2(authManager, null);
}
public void close() throws Exception {}

/**
* Only add body for DeleteObjectsRequest. Refer to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,19 @@
import org.apache.iceberg.rest.auth.OAuth2Util;
import org.apache.iceberg.rest.responses.OAuthTokenResponse;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
import software.amazon.awssdk.utils.IoUtils;

class TestS3V4RestSignerClient {

@BeforeAll
static void beforeAll() {
S3V4RestSignerClient.authManager = null;
S3V4RestSignerClient.httpClient = Mockito.mock(RESTClient.class);
when(S3V4RestSignerClient.httpClient.withAuthSession(Mockito.any()))
.thenReturn(S3V4RestSignerClient.httpClient);
Expand Down Expand Up @@ -69,6 +72,12 @@ static void afterAll() {
S3V4RestSignerClient.httpClient = null;
}

@AfterEach
void afterEach() {
IoUtils.closeQuietlyV2(S3V4RestSignerClient.authManager, null);
S3V4RestSignerClient.authManager = null;
}

@ParameterizedTest
@MethodSource("validOAuth2Properties")
void authSessionOAuth2(Map<String, String> properties, String expectedToken) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ default AuthSession initSession(RESTClient initClient, Map<String, String> prope
* Returns a new session targeting a table or view. This method is intended for components other
* that the catalog that need to access tables or views, such as request signer clients.
*
* <p>This method cannot return null.
* <p>This method cannot return null. By default, it returns the catalog session.
*
* <p>Implementors should cache table sessions internally, as the owning component will not cache
* them. Also, the owning component never closes table sessions; implementations should manage
* their lifecycle themselves and close them when they are no longer needed.
*/
default AuthSession tableSession(RESTClient sharedClient, Map<String, String> properties) {
return catalogSession(sharedClient, properties);
Expand Down
46 changes: 46 additions & 0 deletions core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,38 @@ public OAuth2Util.AuthSession tableSession(
(OAuth2Util.AuthSession) parent);
}

@Override
public AuthSession tableSession(RESTClient sharedClient, Map<String, String> properties) {

AuthConfig config = AuthConfig.fromProperties(properties);
Map<String, String> headers = OAuth2Util.authHeaders(config.token());
OAuth2Util.AuthSession parent = new OAuth2Util.AuthSession(headers, config);

keepRefreshed(config.keepRefreshed());

// Important: this method is invoked from standalone components; we must not assume that
// the refresh client and session cache have been initialized, because catalogSession()
// won't be called.
if (refreshClient == null) {
refreshClient = sharedClient.withAuthSession(parent);
}
if (sessionCache == null) {
sessionCache = newSessionCache(name, properties);
}

if (config.token() != null) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is basically the same logic as in catalogSession(), but with caching enabled.

return sessionCache.cachedSession(
config.token(), k -> newSessionFromAccessToken(config.token(), properties, parent));
}

if (config.credential() != null && !config.credential().isEmpty()) {
return sessionCache.cachedSession(
config.credential(), k -> newSessionFromTokenResponse(config, parent));
}

return parent;
}

@Override
public void close() {
try {
Expand Down Expand Up @@ -226,6 +258,20 @@ protected OAuth2Util.AuthSession newSessionFromTokenExchange(
refreshClient, refreshExecutor(), token, tokenType, parent);
}

protected OAuth2Util.AuthSession newSessionFromTokenResponse(
AuthConfig config, OAuth2Util.AuthSession parent) {
OAuthTokenResponse response =
OAuth2Util.fetchToken(
refreshClient,
Map.of(),
config.credential(),
config.scope(),
config.oauth2ServerUri(),
config.optionalOAuthParams());
return OAuth2Util.AuthSession.fromTokenResponse(
refreshClient, refreshExecutor(), response, System.currentTimeMillis(), parent);
}

private static void warnIfDeprecatedTokenEndpointUsed(Map<String, String> properties) {
if (usesDeprecatedTokenEndpoint(properties)) {
String credential = properties.get(OAuth2Properties.CREDENTIAL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,81 @@ void tableSessionDisallowedTableProperties() {
Mockito.verifyNoMoreInteractions(client);
}

@Test
void standaloneTableSessionEmptyProperties() {
Map<String, String> properties = Map.of();
try (OAuth2Manager manager = new OAuth2Manager("test");
OAuth2Util.AuthSession tableSession =
(OAuth2Util.AuthSession) manager.tableSession(client, properties)) {
assertThat(tableSession.headers()).isEmpty();
assertThat(manager)
.extracting("refreshExecutor")
.as("should not create refresh executor when no table credentials provided")
.isNull();
assertThat(manager)
.extracting("sessionCache")
.asInstanceOf(type(AuthSessionCache.class))
.as("should create session cache for table with token")
.satisfies(cache -> assertThat(cache.sessionCache().asMap()).isEmpty());
}
Mockito.verify(client).withAuthSession(any());
Mockito.verifyNoMoreInteractions(client);
}

@Test
void standaloneTableSessionTokenProvided() {
Map<String, String> tableProperties = Map.of(OAuth2Properties.TOKEN, "table-token");
try (OAuth2Manager manager = new OAuth2Manager("test");
OAuth2Util.AuthSession tableSession =
(OAuth2Util.AuthSession) manager.tableSession(client, tableProperties)) {
assertThat(tableSession.headers()).containsOnly(entry("Authorization", "Bearer table-token"));
assertThat(manager)
.extracting("refreshExecutor")
.as("should create refresh executor when table session created")
.isNotNull();
assertThat(manager)
.extracting("sessionCache")
.asInstanceOf(type(AuthSessionCache.class))
.as("should create session cache for table with token")
.satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1));
}
Mockito.verify(client).withAuthSession(any());
Mockito.verifyNoMoreInteractions(client);
}

@Test
void standaloneTableSessionCredentialProvided() {
Map<String, String> tableProperties = Map.of(OAuth2Properties.CREDENTIAL, "client:secret");
try (OAuth2Manager manager = new OAuth2Manager("test");
OAuth2Util.AuthSession tableSession =
(OAuth2Util.AuthSession) manager.tableSession(client, tableProperties)) {
assertThat(tableSession.headers()).containsOnly(entry("Authorization", "Bearer test"));
assertThat(manager)
.extracting("refreshExecutor")
.as("should create refresh executor when table session created")
.isNotNull();
assertThat(manager)
.extracting("sessionCache")
.asInstanceOf(type(AuthSessionCache.class))
.as("should create session cache for table with token")
.satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1));
}
Mockito.verify(client).withAuthSession(any());
Mockito.verify(client)
.postForm(
any(),
eq(
Map.of(
"grant_type", "client_credentials",
"client_id", "client",
"client_secret", "secret",
"scope", "catalog")),
eq(OAuthTokenResponse.class),
eq(Map.of()),
any());
Mockito.verifyNoMoreInteractions(client);
}

@Test
void close() {
Map<String, String> catalogProperties = Map.of();
Expand Down