Skip to content

Extract class to store Authentication in context #52032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.node.Node;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.Authentication.AuthenticationType;
import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer;
import org.elasticsearch.xpack.core.security.user.User;

import java.io.IOException;
Expand All @@ -29,17 +30,12 @@ public class SecurityContext {
private final Logger logger = LogManager.getLogger(SecurityContext.class);

private final ThreadContext threadContext;
private final UserSettings userSettings;
private final AuthenticationContextSerializer authenticationSerializer;
private final String nodeName;

/**
* Creates a new security context.
* If cryptoService is null, security is disabled and {@link UserSettings#getUser()}
* and {@link UserSettings#getAuthentication()} will always return null.
*/
public SecurityContext(Settings settings, ThreadContext threadContext) {
this.threadContext = threadContext;
this.userSettings = new UserSettings(threadContext);
this.authenticationSerializer = new AuthenticationContextSerializer();
this.nodeName = Node.NODE_NAME_SETTING.get(settings);
}

Expand All @@ -52,13 +48,17 @@ public User getUser() {
/** Returns the authentication information, or null if the current request has no authentication info. */
public Authentication getAuthentication() {
try {
return Authentication.readFromContext(threadContext);
return authenticationSerializer.readFromContext(threadContext);
} catch (IOException e) {
logger.error("failed to read authentication", e);
throw new UncheckedIOException(e);
}
}

public ThreadContext getThreadContext() {
return threadContext;
}

/**
* Sets the user forcefully to the provided user. There must not be an existing user in the ThreadContext otherwise an exception
* will be thrown. This method is package private for testing.
Expand Down Expand Up @@ -103,7 +103,7 @@ public void executeAsUser(User user, Consumer<StoredContext> consumer, Version v
*/
public void executeAfterRewritingAuthentication(Consumer<StoredContext> consumer, Version version) {
final StoredContext original = threadContext.newStoredContext(true);
final Authentication authentication = Objects.requireNonNull(userSettings.getAuthentication());
final Authentication authentication = getAuthentication();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
setAuthentication(new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(),
authentication.getLookedUpBy(), version, authentication.getAuthenticationType(), authentication.getMetadata()));
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer;
import org.elasticsearch.xpack.core.security.user.InternalUserSerializationHelper;
import org.elasticsearch.xpack.core.security.user.User;

Expand Down Expand Up @@ -92,59 +93,12 @@ public Map<String, Object> getMetadata() {
return metadata;
}

public static Authentication readFromContext(ThreadContext ctx) throws IOException, IllegalArgumentException {
Authentication authentication = ctx.getTransient(AuthenticationField.AUTHENTICATION_KEY);
if (authentication != null) {
assert ctx.getHeader(AuthenticationField.AUTHENTICATION_KEY) != null;
return authentication;
}

String authenticationHeader = ctx.getHeader(AuthenticationField.AUTHENTICATION_KEY);
if (authenticationHeader == null) {
return null;
}
return deserializeHeaderAndPutInContext(authenticationHeader, ctx);
}

public static Authentication getAuthentication(ThreadContext context) {
return context.getTransient(AuthenticationField.AUTHENTICATION_KEY);
}

static Authentication deserializeHeaderAndPutInContext(String header, ThreadContext ctx)
throws IOException, IllegalArgumentException {
assert ctx.getTransient(AuthenticationField.AUTHENTICATION_KEY) == null;

Authentication authentication = decode(header);
ctx.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication);
return authentication;
}

public static Authentication decode(String header) throws IOException {
byte[] bytes = Base64.getDecoder().decode(header);
StreamInput input = StreamInput.wrap(bytes);
Version version = Version.readVersion(input);
input.setVersion(version);
return new Authentication(input);
}

/**
* Writes the authentication to the context. There must not be an existing authentication in the context and if there is an
* {@link IllegalStateException} will be thrown
*/
public void writeToContext(ThreadContext ctx) throws IOException, IllegalArgumentException {
ensureContextDoesNotContainAuthentication(ctx);
String header = encode();
ctx.putTransient(AuthenticationField.AUTHENTICATION_KEY, this);
ctx.putHeader(AuthenticationField.AUTHENTICATION_KEY, header);
}

void ensureContextDoesNotContainAuthentication(ThreadContext ctx) {
if (ctx.getTransient(AuthenticationField.AUTHENTICATION_KEY) != null) {
if (ctx.getHeader(AuthenticationField.AUTHENTICATION_KEY) == null) {
throw new IllegalStateException("authentication present as a transient but not a header");
}
throw new IllegalStateException("authentication is already present in the context");
}
new AuthenticationContextSerializer().writeToContext(this, ctx);
}

public String encode() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.core.security.authc.support;

import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.AuthenticationField;

import java.io.IOException;
import java.util.Base64;

/**
* A class from reading/writing {@link org.elasticsearch.xpack.core.security.authc.Authentication} objects to/from a
* {@link org.elasticsearch.common.util.concurrent.ThreadContext} under a specified key
*/
public class AuthenticationContextSerializer {

private final String contextKey;

public AuthenticationContextSerializer() {
this(AuthenticationField.AUTHENTICATION_KEY);
}

public AuthenticationContextSerializer(String contextKey) {
this.contextKey = contextKey;
}

@Nullable
public Authentication readFromContext(ThreadContext ctx) throws IOException {
Authentication authentication = ctx.getTransient(contextKey);
if (authentication != null) {
assert ctx.getHeader(contextKey) != null;
return authentication;
}

String authenticationHeader = ctx.getHeader(contextKey);
if (authenticationHeader == null) {
return null;
}
return deserializeHeaderAndPutInContext(authenticationHeader, ctx);
}

Authentication deserializeHeaderAndPutInContext(String headerValue, ThreadContext ctx)
throws IOException, IllegalArgumentException {
assert ctx.getTransient(contextKey) == null;

Authentication authentication = decode(headerValue);
ctx.putTransient(contextKey, authentication);
return authentication;
}

public static Authentication decode(String header) throws IOException {
byte[] bytes = Base64.getDecoder().decode(header);
StreamInput input = StreamInput.wrap(bytes);
Version version = Version.readVersion(input);
input.setVersion(version);
return new Authentication(input);
}

public Authentication getAuthentication(ThreadContext context) {
return context.getTransient(contextKey);
}

/**
* Writes the authentication to the context. There must not be an existing authentication in the context and if there is an
* {@link IllegalStateException} will be thrown
*/
public void writeToContext(Authentication authentication, ThreadContext ctx) throws IOException {
ensureContextDoesNotContainAuthentication(ctx);
String header = authentication.encode();
ctx.putTransient(contextKey, authentication);
ctx.putHeader(contextKey, header);
}

void ensureContextDoesNotContainAuthentication(ThreadContext ctx) {
if (ctx.getTransient(contextKey) != null) {
if (ctx.getHeader(contextKey) == null) {
throw new IllegalStateException("authentication present as a transient ([" + contextKey + "]) but not a header");
}
throw new IllegalStateException("authentication ([" + contextKey + "]) is already present in the context");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
import org.elasticsearch.index.shard.ShardUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField;
import org.elasticsearch.xpack.core.security.authz.permission.DocumentPermissions;
import org.elasticsearch.xpack.core.security.support.Exceptions;
import org.elasticsearch.xpack.core.security.user.User;

import java.io.IOException;
import java.util.Objects;
import java.util.function.Function;

/**
Expand All @@ -45,16 +46,16 @@ public class SecurityIndexReaderWrapper implements CheckedFunction<DirectoryRead
private final Function<ShardId, QueryShardContext> queryShardContextProvider;
private final DocumentSubsetBitsetCache bitsetCache;
private final XPackLicenseState licenseState;
private final ThreadContext threadContext;
private final SecurityContext securityContext;
private final ScriptService scriptService;

public SecurityIndexReaderWrapper(Function<ShardId, QueryShardContext> queryShardContextProvider,
DocumentSubsetBitsetCache bitsetCache, ThreadContext threadContext, XPackLicenseState licenseState,
ScriptService scriptService) {
DocumentSubsetBitsetCache bitsetCache, SecurityContext securityContext,
XPackLicenseState licenseState, ScriptService scriptService) {
this.scriptService = scriptService;
this.queryShardContextProvider = queryShardContextProvider;
this.bitsetCache = bitsetCache;
this.threadContext = threadContext;
this.securityContext = securityContext;
this.licenseState = licenseState;
}

Expand Down Expand Up @@ -95,16 +96,16 @@ public DirectoryReader apply(final DirectoryReader reader) {
}

protected IndicesAccessControl getIndicesAccessControl() {
final ThreadContext threadContext = securityContext.getThreadContext();
IndicesAccessControl indicesAccessControl = threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY);
if (indicesAccessControl == null) {
throw Exceptions.authorizationError("no indices permissions found");
}
return indicesAccessControl;
}

protected User getUser(){
Authentication authentication = Authentication.getAuthentication(threadContext);
return authentication.getUser();
protected User getUser() {
return Objects.requireNonNull(securityContext.getUser());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer;
import org.elasticsearch.xpack.core.watcher.actions.ActionWrapperResult;
import org.elasticsearch.xpack.core.watcher.condition.Condition;
import org.elasticsearch.xpack.core.watcher.history.WatchRecord;
Expand Down Expand Up @@ -262,7 +263,7 @@ public static String getUsernameFromWatch(Watch watch) throws IOException {
if (watch != null && watch.status() != null && watch.status().getHeaders() != null) {
String header = watch.status().getHeaders().get(AuthenticationField.AUTHENTICATION_KEY);
if (header != null) {
Authentication auth = Authentication.decode(header);
Authentication auth = AuthenticationContextSerializer.decode(header);
return auth.getUser().principal();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.test.AbstractBuilderTestCase;
import org.elasticsearch.test.IndexSettingsModule;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer;
import org.elasticsearch.xpack.core.security.authz.permission.DocumentPermissions;
import org.elasticsearch.xpack.core.security.authz.permission.FieldPermissions;
import org.elasticsearch.xpack.core.security.user.User;
Expand Down Expand Up @@ -69,10 +70,14 @@ public void testDLS() throws Exception {
when(mapperService.simpleMatchToFullName(anyString()))
.then(invocationOnMock -> Collections.singletonList((String) invocationOnMock.getArguments()[0]));

ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final SecurityContext securityContext = new SecurityContext(Settings.EMPTY, threadContext);

final Authentication authentication = mock(Authentication.class);
when(authentication.getUser()).thenReturn(mock(User.class));
threadContext.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication);
when(authentication.encode()).thenReturn(randomAlphaOfLength(24)); // don't care as long as it's not null
new AuthenticationContextSerializer().writeToContext(authentication, threadContext);

IndexSettings indexSettings = IndexSettingsModule.newIndexSettings(shardId.getIndex(), Settings.EMPTY);
Client client = mock(Client.class);
when(client.settings()).thenReturn(Settings.EMPTY);
Expand Down Expand Up @@ -135,7 +140,7 @@ null, null, mapperService, null, null, xContentRegistry(), writableRegistry(),
FieldPermissions(),
DocumentPermissions.filteredBy(singleton(new BytesArray(termQuery))));
SecurityIndexReaderWrapper wrapper = new SecurityIndexReaderWrapper(s -> queryShardContext,
bitsetCache, threadContext, licenseState, scriptService) {
bitsetCache, securityContext, licenseState, scriptService) {

@Override
protected IndicesAccessControl getIndicesAccessControl() {
Expand Down Expand Up @@ -173,10 +178,13 @@ public void testDLSWithLimitedPermissions() throws Exception {
when(mapperService.simpleMatchToFullName(anyString()))
.then(invocationOnMock -> Collections.singletonList((String) invocationOnMock.getArguments()[0]));

ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final SecurityContext securityContext = new SecurityContext(Settings.EMPTY, threadContext);
final Authentication authentication = mock(Authentication.class);
when(authentication.getUser()).thenReturn(mock(User.class));
threadContext.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication);
when(authentication.encode()).thenReturn(randomAlphaOfLength(24)); // don't care as long as it's not null
new AuthenticationContextSerializer().writeToContext(authentication, threadContext);

final boolean noFilteredIndexPermissions = randomBoolean();
boolean restrictiveLimitedIndexPermissions = false;
if (noFilteredIndexPermissions == false) {
Expand Down Expand Up @@ -208,7 +216,7 @@ null, null, mapperService, null, null, xContentRegistry(), writableRegistry(),
XPackLicenseState licenseState = mock(XPackLicenseState.class);
when(licenseState.isDocumentAndFieldLevelSecurityAllowed()).thenReturn(true);
SecurityIndexReaderWrapper wrapper = new SecurityIndexReaderWrapper(s -> queryShardContext,
bitsetCache, threadContext, licenseState, scriptService) {
bitsetCache, securityContext, licenseState, scriptService) {

@Override
protected IndicesAccessControl getIndicesAccessControl() {
Expand Down
Loading