Skip to content

Commit

Permalink
Add additional headers to the client request.
Browse files Browse the repository at this point in the history
Add ClientRequestFilter.java interface in Presto-spi.
Improve Request Headers in the Authentication Filter Class.
  • Loading branch information
SthuthiGhosh9400 committed Jan 16, 2025
1 parent b511e7a commit d5df908
Show file tree
Hide file tree
Showing 17 changed files with 517 additions and 4 deletions.
1 change: 1 addition & 0 deletions presto-docs/src/main/sphinx/develop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ This guide is intended for Presto contributors and plugin developers.
develop/serialized-page
develop/presto-console
develop/presto-authenticator
develop/client-request-filter
26 changes: 26 additions & 0 deletions presto-docs/src/main/sphinx/develop/client-request-filter.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

======================
Client Request Filter
======================

Presto allows operators to customize the headers used to process queries. Some example use cases include customized authentication workflows, or enriching query attributes such as the query source. Use the Client Request Filter plugin to control header customization during query execution.

Implementation
--------------

The ``ClientRequestFilterFactory`` is responsible for creating instances of ``ClientRequestFilter``. It also defines
the name of the filter.

The ``ClientRequestFilter`` interface provides two methods: ``getExtraHeaders()``, which allows the runtime to quickly check if it needs to apply a more expensive call to enrich the headers, and ``getHeaderNames()``, which returns a list of header names used as the header names in client requests.

The implementation of ``ClientRequestFilterFactory`` must be wrapped as a plugin and installed on the Presto cluster.

After installing a plugin that implements ``ClientRequestFilterFactory`` on the coordinator, the ``AuthenticationFilter`` class passes the ``principal`` object to the request filter, which returns the header values as a map.

Presto uses the request filter to determine whether a header is present in the blocklist. The blocklist includes headers such as ``X-Presto-Transaction-Id``, ``X-Presto-Started-Transaction-Id``, ``X-Presto-Clear-Transaction-Id``, and ``X-Presto-Trace-Token``, which are not allowed to be overridden.

For a complete list of headers, see the `Java source`_.

Note: The `Java source`_ includes these blocklist headers that are not eligible for overriding. The other headers not mentioned here can be overridden.

.. _Java source: https://github.com/prestodb/presto/blob/master/presto-client/src/main/java/com/facebook/presto/client/PrestoHeaders.java
6 changes: 6 additions & 0 deletions presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,12 @@
<groupId>io.netty</groupId>
<artifactId>netty-transport</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto;

import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.ClientRequestFilterFactory;
import com.facebook.presto.spi.PrestoException;
import com.google.common.collect.ImmutableList;

import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS;
import static com.facebook.presto.spi.StandardErrorCode.CONFIGURATION_INVALID;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static com.google.common.collect.ImmutableList.toImmutableList;

@ThreadSafe
public class ClientRequestFilterManager
{
private Map<String, ClientRequestFilterFactory> factories = new ConcurrentHashMap<>();

@GuardedBy("this")
private volatile List<ClientRequestFilter> filters = ImmutableList.of();
private final AtomicBoolean loaded = new AtomicBoolean();

public void registerClientRequestFilterFactory(ClientRequestFilterFactory factory)
{
if (loaded.get()) {
throw new PrestoException(INVALID_ARGUMENTS, "Cannot register factories after filters are loaded.");
}

String name = factory.getName();
if (factories.putIfAbsent(name, factory) != null) {
throw new PrestoException(ALREADY_EXISTS, "A factory with the name '" + name + "' is already registered.");
}
}

public void loadClientRequestFilters()
{
if (!loaded.compareAndSet(false, true)) {
throw new PrestoException(CONFIGURATION_INVALID, "loadClientRequestFilters can only be called once.");
}

filters = factories.values().stream()
.map(factory -> factory.create())
.collect(toImmutableList());
factories = null;
}

public List<ClientRequestFilter> getClientRequestFilters()
{
return filters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto;

import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.Scopes;

public class ClientRequestFilterModule
implements Module
{
@Override
public void configure(Binder binder)
{
binder.bind(ClientRequestFilterManager.class).in(Scopes.SINGLETON);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.airlift.log.Logger;
import com.facebook.airlift.node.NodeInfo;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.common.block.BlockEncoding;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.type.ParametricType;
Expand All @@ -28,6 +29,7 @@
import com.facebook.presto.security.AccessControlManager;
import com.facebook.presto.server.security.PasswordAuthenticatorManager;
import com.facebook.presto.server.security.PrestoAuthenticatorManager;
import com.facebook.presto.spi.ClientRequestFilterFactory;
import com.facebook.presto.spi.CoordinatorPlugin;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.analyzer.AnalyzerProvider;
Expand Down Expand Up @@ -137,6 +139,7 @@ public class PluginManager
private final AnalyzerProviderManager analyzerProviderManager;
private final QueryPreparerProviderManager queryPreparerProviderManager;
private final NodeStatusNotificationManager nodeStatusNotificationManager;
private final ClientRequestFilterManager clientRequestFilterManager;
private final PlanCheckerProviderManager planCheckerProviderManager;

@Inject
Expand All @@ -161,6 +164,7 @@ public PluginManager(
HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager,
TracerProviderManager tracerProviderManager,
NodeStatusNotificationManager nodeStatusNotificationManager,
ClientRequestFilterManager clientRequestFilterManager,
PlanCheckerProviderManager planCheckerProviderManager)
{
requireNonNull(nodeInfo, "nodeInfo is null");
Expand Down Expand Up @@ -194,6 +198,7 @@ public PluginManager(
this.analyzerProviderManager = requireNonNull(analyzerProviderManager, "analyzerProviderManager is null");
this.queryPreparerProviderManager = requireNonNull(queryPreparerProviderManager, "queryPreparerProviderManager is null");
this.nodeStatusNotificationManager = requireNonNull(nodeStatusNotificationManager, "nodeStatusNotificationManager is null");
this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
this.planCheckerProviderManager = requireNonNull(planCheckerProviderManager, "planCheckerProviderManager is null");
}

Expand Down Expand Up @@ -364,6 +369,11 @@ public void installPlugin(Plugin plugin)
log.info("Registering node status notification provider %s", nodeStatusNotificationProviderFactory.getName());
nodeStatusNotificationManager.addNodeStatusNotificationProviderFactory(nodeStatusNotificationProviderFactory);
}

for (ClientRequestFilterFactory clientRequestFilterFactory : plugin.getClientRequestFilterFactories()) {
log.info("Registering client request filter factory");
clientRequestFilterManager.registerClientRequestFilterFactory(clientRequestFilterFactory);
}
}

public void installCoordinatorPlugin(CoordinatorPlugin plugin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import com.facebook.airlift.tracetoken.TraceTokenModule;
import com.facebook.drift.server.DriftServer;
import com.facebook.drift.transport.netty.server.DriftNettyServerTransport;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.ClientRequestFilterModule;
import com.facebook.presto.dispatcher.QueryPrerequisitesManager;
import com.facebook.presto.dispatcher.QueryPrerequisitesManagerModule;
import com.facebook.presto.eventlistener.EventListenerManager;
Expand Down Expand Up @@ -117,6 +119,7 @@ public void run()
new NodeModule(),
new DiscoveryModule(),
new HttpServerModule(),
new ClientRequestFilterModule(),
new JsonModule(),
installModuleIf(
FeaturesConfig.class,
Expand Down Expand Up @@ -192,6 +195,7 @@ public void run()
PluginNodeManager pluginNodeManager = new PluginNodeManager(nodeManager, nodeInfo.getEnvironment());
planCheckerProviderManager.loadPlanCheckerProviders(pluginNodeManager);

injector.getInstance(ClientRequestFilterManager.class).loadClientRequestFilters();
startAssociatedProcesses(injector);

injector.getInstance(Announcer.class).start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@

import com.facebook.airlift.http.server.AuthenticationException;
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.PrestoException;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.net.HttpHeaders;

import javax.inject.Inject;
Expand All @@ -35,14 +40,20 @@
import java.io.InputStream;
import java.io.PrintWriter;
import java.security.Principal;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT;
import static com.google.common.io.ByteStreams.copy;
import static com.google.common.io.ByteStreams.nullOutputStream;
import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE;
import static com.google.common.net.MediaType.PLAIN_TEXT_UTF_8;
import static java.util.Collections.enumeration;
import static java.util.Collections.list;
import static java.util.Objects.requireNonNull;
import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED;

Expand All @@ -52,12 +63,15 @@ public class AuthenticationFilter
private static final String HTTPS_PROTOCOL = "https";
private final List<Authenticator> authenticators;
private final boolean allowForwardedHttps;
private final ClientRequestFilterManager clientRequestFilterManager;
private final List<String> headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token");

@Inject
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig)
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager)
{
this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null"));
this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps();
this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
}

@Override
Expand Down Expand Up @@ -95,9 +109,9 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
e.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
continue;
}

// authentication succeeded
nextFilter.doFilter(withPrincipal(request, principal), response);
HttpServletRequest wrappedRequest = mergeExtraHeaders(request, principal);
nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response);
return;
}

Expand Down Expand Up @@ -126,6 +140,47 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
}
}

public HttpServletRequest mergeExtraHeaders(HttpServletRequest request, Principal principal)
{
List<ClientRequestFilter> clientRequestFilters = clientRequestFilterManager.getClientRequestFilters();

if (clientRequestFilters.isEmpty()) {
return request;
}

ImmutableMap.Builder<String, String> extraHeadersMapBuilder = ImmutableMap.builder();
Set<String> addedHeaders = new HashSet<>();

for (ClientRequestFilter requestFilter : clientRequestFilters) {
boolean headersPresent = requestFilter.getExtraHeaderKeys().stream()
.allMatch(headerName -> request.getHeader(headerName) != null);

if (!headersPresent) {
Map<String, String> extraHeaderValueMap = requestFilter.getExtraHeaders(principal);

if (!extraHeaderValueMap.isEmpty()) {
for (Map.Entry<String, String> extraHeaderEntry : extraHeaderValueMap.entrySet()) {
String headerKey = extraHeaderEntry.getKey();
if (headersBlockList.contains(headerKey)) {
throw new PrestoException(HEADER_MODIFICATION_ATTEMPT,
"Modification attempt detected: The header " + headerKey + " is not allowed to be modified. The following headers cannot be modified: " +
String.join(", ", headersBlockList));
}
if (addedHeaders.contains(headerKey)) {
throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Header conflict detected: " + headerKey + " already added by another filter.");
}
if (request.getHeader(headerKey) == null && requestFilter.getExtraHeaderKeys().contains(headerKey)) {
extraHeadersMapBuilder.put(headerKey, extraHeaderEntry.getValue());
addedHeaders.add(headerKey);
}
}
}
}
}

return new ModifiedHttpServletRequest(request, extraHeadersMapBuilder.build());
}

private boolean doesRequestSupportAuthentication(HttpServletRequest request)
{
if (authenticators.isEmpty()) {
Expand Down Expand Up @@ -166,4 +221,43 @@ private static void skipRequestBody(HttpServletRequest request)
copy(inputStream, nullOutputStream());
}
}

public static class ModifiedHttpServletRequest
extends HttpServletRequestWrapper
{
private final Map<String, String> customHeaders;

public ModifiedHttpServletRequest(HttpServletRequest request, Map<String, String> headers)
{
super(request);
this.customHeaders = ImmutableMap.copyOf(requireNonNull(headers, "headers is null"));
}

@Override
public String getHeader(String name)
{
if (customHeaders.containsKey(name)) {
return customHeaders.get(name);
}
return super.getHeader(name);
}

@Override
public Enumeration<String> getHeaderNames()
{
return enumeration(ImmutableSet.<String>builder()
.addAll(customHeaders.keySet())
.addAll(list(super.getHeaderNames()))
.build());
}

@Override
public Enumeration<String> getHeaders(String name)
{
if (customHeaders.containsKey(name)) {
return enumeration(ImmutableList.of(customHeaders.get(name)));
}
return super.getHeaders(name);
}
}
}
Loading

0 comments on commit d5df908

Please sign in to comment.