Skip to content

Commit

Permalink
feat(gax): add API key authentication to ClientSettings (#3137)
Browse files Browse the repository at this point in the history
Allow gax client libraries to authenticate using API key via setApiKey
method exposed from ClientSettings. Also added deduping to GRPC calls
for api key headers.

Tested using LanguageServiceSettings

cc @westarle
  • Loading branch information
ldetmer authored Oct 2, 2024
1 parent e08906c commit df08956
Show file tree
Hide file tree
Showing 12 changed files with 717 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.internal.EnvironmentProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.ApiKeyCredentials;
import com.google.auth.Credentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -63,6 +64,8 @@
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
Expand Down Expand Up @@ -123,6 +126,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final Boolean allowNonDefaultServiceAccount;
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private final MtlsProvider mtlsProvider;
@VisibleForTesting final Map<String, String> headersWithDuplicatesRemoved = new HashMap<>();

@Nullable
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
Expand Down Expand Up @@ -408,7 +412,8 @@ ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSec

private ManagedChannel createSingleChannel() throws IOException {
GrpcHeaderInterceptor headerInterceptor =
new GrpcHeaderInterceptor(headerProvider.getHeaders());
new GrpcHeaderInterceptor(headersWithDuplicatesRemoved);

GrpcMetadataHandlerInterceptor metadataHandlerInterceptor =
new GrpcMetadataHandlerInterceptor();

Expand Down Expand Up @@ -496,6 +501,28 @@ private ManagedChannel createSingleChannel() throws IOException {
return managedChannel;
}

/* Remove provided headers that will also get set by {@link com.google.auth.ApiKeyCredentials}. They will be added as part of the grpc call when performing auth
* {@link io.grpc.auth.GoogleAuthLibraryCallCredentials#applyRequestMetadata}. GRPC does not dedup headers {@link https://github.com/grpc/grpc-java/blob/a140e1bb0cfa662bcdb7823d73320eb8d49046f1/api/src/main/java/io/grpc/Metadata.java#L504} so we must before initiating the call.
*
* Note: This is specific for ApiKeyCredentials as duplicate API key headers causes a failure on the back end. At this time we are not sure of the behavior for other credentials.
*/
private void removeApiKeyCredentialDuplicateHeaders() {
if (headerProvider != null) {
headersWithDuplicatesRemoved.putAll(headerProvider.getHeaders());
}
if (credentials != null && credentials instanceof ApiKeyCredentials) {
try {
Map<String, List<String>> credentialRequestMetatData = credentials.getRequestMetadata();
if (credentialRequestMetatData != null) {
headersWithDuplicatesRemoved.keySet().removeAll(credentialRequestMetatData.keySet());
}
} catch (IOException e) {
// unreachable, there is no scenario that getRequestMetatData for ApiKeyCredentials will
// throw an IOException
}
}
}

/**
* Marked as Internal Api and intended for internal use. DirectPath must be enabled via the
* settings and a few other configurations/settings must also be valid for the request to go
Expand Down Expand Up @@ -883,7 +910,10 @@ public Builder setDirectPathServiceConfig(Map<String, ?> serviceConfig) {
}

public InstantiatingGrpcChannelProvider build() {
return new InstantiatingGrpcChannelProvider(this);
InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider =
new InstantiatingGrpcChannelProvider(this);
instantiatingGrpcChannelProvider.removeApiKeyCredentialDuplicateHeaders();
return instantiatingGrpcChannelProvider;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,21 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;

import com.google.api.core.ApiFunction;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.internal.EnvironmentProvider;
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.ApiKeyCredentials;
import com.google.auth.Credentials;
import com.google.auth.http.AuthHttpConstants;
import com.google.auth.oauth2.CloudShellCredentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -79,6 +83,8 @@

class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest {
private static final String DEFAULT_ENDPOINT = "test.googleapis.com:443";
private static final String API_KEY_HEADER_VALUE = "fake_api_key_2";
private static final String API_KEY_AUTH_HEADER_KEY = "x-goog-api-key";
private static String originalOSName;
private ComputeEngineCredentials computeEngineCredentials;

Expand Down Expand Up @@ -877,6 +883,103 @@ public void canUseDirectPath_nonGDUUniverseDomain() {
Truth.assertThat(provider.canUseDirectPath()).isFalse();
}

@Test
void providerInitializedWithNonConflictingHeaders_retainsHeaders() {
InstantiatingGrpcChannelProvider.Builder builder =
InstantiatingGrpcChannelProvider.newBuilder()
.setHeaderProvider(getHeaderProviderWithApiKeyHeader())
.setEndpoint("test.random.com:443");

InstantiatingGrpcChannelProvider provider = builder.build();

assertEquals(1, provider.headersWithDuplicatesRemoved.size());
assertEquals(
API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY));
}

@Test
void providersInitializedWithConflictingApiKeyCredentialHeaders_removesDuplicates() {
String correctApiKey = "fake_api_key";
ApiKeyCredentials apiKeyCredentials = ApiKeyCredentials.create(correctApiKey);
InstantiatingGrpcChannelProvider.Builder builder =
InstantiatingGrpcChannelProvider.newBuilder()
.setCredentials(apiKeyCredentials)
.setHeaderProvider(getHeaderProviderWithApiKeyHeader())
.setEndpoint("test.random.com:443");

InstantiatingGrpcChannelProvider provider = builder.build();

assertEquals(0, provider.headersWithDuplicatesRemoved.size());
assertNull(provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY));
}

@Test
void providersInitializedWithConflictingNonApiKeyCredentialHeaders_doesNotRemoveDuplicates() {
String authProvidedHeader = "Bearer token";
Map<String, String> header = new HashMap<>();
header.put(AuthHttpConstants.AUTHORIZATION, authProvidedHeader);
InstantiatingGrpcChannelProvider.Builder builder =
InstantiatingGrpcChannelProvider.newBuilder()
.setCredentials(computeEngineCredentials)
.setHeaderProvider(FixedHeaderProvider.create(header))
.setEndpoint("test.random.com:443");

InstantiatingGrpcChannelProvider provider = builder.build();

assertEquals(1, provider.headersWithDuplicatesRemoved.size());
assertEquals(
authProvidedHeader,
provider.headersWithDuplicatesRemoved.get(AuthHttpConstants.AUTHORIZATION));
}

@Test
void buildProvider_handlesNullHeaderProvider() {
InstantiatingGrpcChannelProvider.Builder builder =
InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("test.random.com:443");

InstantiatingGrpcChannelProvider provider = builder.build();

assertEquals(0, provider.headersWithDuplicatesRemoved.size());
}

@Test
void buildProvider_handlesNullCredentialsMetadataRequest() throws IOException {
Credentials credentials = Mockito.mock(Credentials.class);
Mockito.when(credentials.getRequestMetadata()).thenReturn(null);
InstantiatingGrpcChannelProvider.Builder builder =
InstantiatingGrpcChannelProvider.newBuilder()
.setHeaderProvider(getHeaderProviderWithApiKeyHeader())
.setEndpoint("test.random.com:443");

InstantiatingGrpcChannelProvider provider = builder.build();

assertEquals(1, provider.headersWithDuplicatesRemoved.size());
assertEquals(
API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY));
}

@Test
void buildProvider_handlesErrorRetrievingCredentialsMetadataRequest() throws IOException {
Credentials credentials = Mockito.mock(Credentials.class);
Mockito.when(credentials.getRequestMetadata())
.thenThrow(new IOException("Error getting request metadata"));
InstantiatingGrpcChannelProvider.Builder builder =
InstantiatingGrpcChannelProvider.newBuilder()
.setHeaderProvider(getHeaderProviderWithApiKeyHeader())
.setEndpoint("test.random.com:443");
InstantiatingGrpcChannelProvider provider = builder.build();

assertEquals(1, provider.headersWithDuplicatesRemoved.size());
assertEquals(
API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY));
}

private FixedHeaderProvider getHeaderProviderWithApiKeyHeader() {
Map<String, String> header = new HashMap<>();
header.put(API_KEY_AUTH_HEADER_KEY, API_KEY_HEADER_VALUE);
return FixedHeaderProvider.create(header);
}

private static class FakeLogHandler extends Handler {

List<LogRecord> records = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
import com.google.api.gax.rpc.internal.QuotaProjectIdHidingCredentials;
import com.google.api.gax.tracing.ApiTracerFactory;
import com.google.api.gax.tracing.BaseApiTracerFactory;
import com.google.auth.ApiKeyCredentials;
import com.google.auth.Credentials;
import com.google.auth.oauth2.GdchCredentials;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
Expand Down Expand Up @@ -175,9 +177,9 @@ public static ClientContext create(StubSettings settings) throws IOException {
// A valid EndpointContext should have been created in the StubSettings
EndpointContext endpointContext = settings.getEndpointContext();
String endpoint = endpointContext.resolvedEndpoint();

Credentials credentials = getCredentials(settings);
// check if need to adjust credentials/endpoint/endpointContext for GDC-H
String settingsGdchApiAudience = settings.getGdchApiAudience();
Credentials credentials = settings.getCredentialsProvider().getCredentials();
boolean usingGDCH = credentials instanceof GdchCredentials;
if (usingGDCH) {
// Can only determine if the GDC-H is being used via the Credentials. The Credentials object
Expand All @@ -187,22 +189,7 @@ public static ClientContext create(StubSettings settings) throws IOException {
// Resolve the new endpoint with the GDC-H flow
endpoint = endpointContext.resolvedEndpoint();
// We recompute the GdchCredentials with the audience
String audienceString;
if (!Strings.isNullOrEmpty(settingsGdchApiAudience)) {
audienceString = settingsGdchApiAudience;
} else if (!Strings.isNullOrEmpty(endpoint)) {
audienceString = endpoint;
} else {
throw new IllegalArgumentException("Could not infer GDCH api audience from settings");
}

URI gdchAudienceUri;
try {
gdchAudienceUri = URI.create(audienceString);
} catch (IllegalArgumentException ex) { // thrown when passing a malformed uri string
throw new IllegalArgumentException("The GDC-H API audience string is not a valid URI", ex);
}
credentials = ((GdchCredentials) credentials).createWithGdchAudience(gdchAudienceUri);
credentials = getGdchCredentials(settingsGdchApiAudience, endpoint, credentials);
} else if (!Strings.isNullOrEmpty(settingsGdchApiAudience)) {
throw new IllegalArgumentException(
"GDC-H API audience can only be set when using GdchCredentials");
Expand Down Expand Up @@ -291,6 +278,43 @@ public static ClientContext create(StubSettings settings) throws IOException {
.build();
}

/** Determines which credentials to use. API key overrides credentials provided by provider. */
private static Credentials getCredentials(StubSettings settings) throws IOException {
Credentials credentials;
if (settings.getApiKey() != null) {
// if API key exists it becomes the default credential
credentials = ApiKeyCredentials.create(settings.getApiKey());
} else {
credentials = settings.getCredentialsProvider().getCredentials();
}
return credentials;
}

/**
* Constructs a new {@link com.google.auth.Credentials} object based on credentials provided with
* a GDC-H audience
*/
@VisibleForTesting
static GdchCredentials getGdchCredentials(
String settingsGdchApiAudience, String endpoint, Credentials credentials) throws IOException {
String audienceString;
if (!Strings.isNullOrEmpty(settingsGdchApiAudience)) {
audienceString = settingsGdchApiAudience;
} else if (!Strings.isNullOrEmpty(endpoint)) {
audienceString = endpoint;
} else {
throw new IllegalArgumentException("Could not infer GDCH api audience from settings");
}

URI gdchAudienceUri;
try {
gdchAudienceUri = URI.create(audienceString);
} catch (IllegalArgumentException ex) { // thrown when passing a malformed uri string
throw new IllegalArgumentException("The GDC-H API audience string is not a valid URI", ex);
}
return ((GdchCredentials) credentials).createWithGdchAudience(gdchAudienceUri);
}

/**
* Getting a header map from HeaderProvider and InternalHeaderProvider from settings with Quota
* Project Id.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ public final WatchdogProvider getWatchdogProvider() {
return stubSettings.getStreamWatchdogProvider();
}

/** Gets the API Key that should be used for authentication. */
public final String getApiKey() {
return stubSettings.getApiKey();
}

/** This method is obsolete. Use {@link #getWatchdogCheckIntervalDuration()} instead. */
@Nonnull
@ObsoleteApi("Use getWatchdogCheckIntervalDuration() instead")
Expand Down Expand Up @@ -144,6 +149,7 @@ public String toString() {
.add("watchdogProvider", getWatchdogProvider())
.add("watchdogCheckInterval", getWatchdogCheckInterval())
.add("gdchApiAudience", getGdchApiAudience())
.add("apiKey", getApiKey())
.toString();
}

Expand Down Expand Up @@ -302,6 +308,21 @@ public B setGdchApiAudience(@Nullable String gdchApiAudience) {
return self();
}

/**
* Sets the API key. The API key will get translated to an {@link
* com.google.auth.ApiKeyCredentials} and stored in {@link ClientContext}.
*
* <p>API Key authorization is not supported for every product. Please check the documentation
* for each product to confirm if it is supported.
*
* <p>Note: If you set an API key and {@link CredentialsProvider} in the same ClientSettings the
* API key will override any credentials provided.
*/
public B setApiKey(String apiKey) {
stubSettings.setApiKey(apiKey);
return self();
}

/**
* Gets the ExecutorProvider that was previously set on this Builder. This ExecutorProvider is
* to use for running asynchronous API call logic (such as retries and long-running operations),
Expand Down Expand Up @@ -364,6 +385,11 @@ public WatchdogProvider getWatchdogProvider() {
return stubSettings.getStreamWatchdogProvider();
}

/** Gets the API Key that was previously set on this Builder. */
public String getApiKey() {
return stubSettings.getApiKey();
}

/** This method is obsolete. Use {@link #getWatchdogCheckIntervalDuration()} instead */
@Nullable
@ObsoleteApi("Use getWatchdogCheckIntervalDuration() instead")
Expand Down Expand Up @@ -405,6 +431,7 @@ public String toString() {
.add("watchdogProvider", getWatchdogProvider())
.add("watchdogCheckInterval", getWatchdogCheckIntervalDuration())
.add("gdchApiAudience", getGdchApiAudience())
.add("apiKey", getApiKey())
.toString();
}
}
Expand Down
Loading

0 comments on commit df08956

Please sign in to comment.