Skip to content

Commit 2543dab

Browse files
committed
Add extensionName() to security extension
Extension loading code needs to know how to refer to an extension at runtime. It previously used "toString()", but there was no contract that required that this method be implemented in a meaningful way. A new extensionName() method is added which defaults to the class name of the extension, but can be customized by implementations Backport of: elastic#79329
1 parent 3bd8055 commit 2543dab

File tree

3 files changed

+54
-30
lines changed

3 files changed

+54
-30
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,8 @@ default AuthenticationFailureHandler getAuthenticationFailureHandler(SecurityCom
115115
default AuthorizationEngine getAuthorizationEngine(Settings settings) {
116116
return null;
117117
}
118+
119+
default String extensionName() {
120+
return getClass().getName();
121+
}
118122
}

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.elasticsearch.common.util.concurrent.EsExecutors;
4444
import org.elasticsearch.common.util.concurrent.ThreadContext;
4545
import org.elasticsearch.common.util.set.Sets;
46+
import org.elasticsearch.core.Nullable;
4647
import org.elasticsearch.xcontent.NamedXContentRegistry;
4748
import org.elasticsearch.xcontent.XContentBuilder;
4849
import org.elasticsearch.env.Environment;
@@ -597,7 +598,7 @@ Collection<Object> createComponents(Client client, ThreadPool threadPool, Cluste
597598
extensionComponents
598599
);
599600
if (providers != null && providers.isEmpty() == false) {
600-
customRoleProviders.put(extension.toString(), providers);
601+
customRoleProviders.put(extension.extensionName(), providers);
601602
}
602603
}
603604

@@ -695,37 +696,15 @@ auditTrailService, failureHandler, threadPool, anonymousUser, getAuthorizationEn
695696
}
696697

697698
private AuthorizationEngine getAuthorizationEngine() {
698-
AuthorizationEngine authorizationEngine = null;
699-
String extensionName = null;
700-
for (SecurityExtension extension : securityExtensions) {
701-
final AuthorizationEngine extensionEngine = extension.getAuthorizationEngine(settings);
702-
if (extensionEngine != null && authorizationEngine != null) {
703-
throw new IllegalStateException("Extensions [" + extensionName + "] and [" + extension.toString() + "] "
704-
+ "both set an authorization engine");
705-
}
706-
authorizationEngine = extensionEngine;
707-
extensionName = extension.toString();
708-
}
709-
710-
if (authorizationEngine != null) {
711-
logger.debug("Using authorization engine from extension [" + extensionName + "]");
712-
}
713-
return authorizationEngine;
699+
return findValueFromExtensions("authorization engine", extension -> extension.getAuthorizationEngine(settings));
714700
}
715701

716702
private AuthenticationFailureHandler createAuthenticationFailureHandler(final Realms realms,
717703
final SecurityExtension.SecurityComponents components) {
718-
AuthenticationFailureHandler failureHandler = null;
719-
String extensionName = null;
720-
for (SecurityExtension extension : securityExtensions) {
721-
AuthenticationFailureHandler extensionFailureHandler = extension.getAuthenticationFailureHandler(components);
722-
if (extensionFailureHandler != null && failureHandler != null) {
723-
throw new IllegalStateException("Extensions [" + extensionName + "] and [" + extension.toString() + "] "
724-
+ "both set an authentication failure handler");
725-
}
726-
failureHandler = extensionFailureHandler;
727-
extensionName = extension.toString();
728-
}
704+
AuthenticationFailureHandler failureHandler = findValueFromExtensions(
705+
"authentication failure handler",
706+
extension -> extension.getAuthenticationFailureHandler(components)
707+
);
729708
if (failureHandler == null) {
730709
logger.debug("Using default authentication failure handler");
731710
Supplier<Map<String, List<String>>> headersSupplier = () -> {
@@ -762,12 +741,48 @@ private AuthenticationFailureHandler createAuthenticationFailureHandler(final Re
762741
getLicenseState().addListener(() -> {
763742
finalDefaultFailureHandler.setHeaders(headersSupplier.get());
764743
});
765-
} else {
766-
logger.debug("Using authentication failure handler from extension [" + extensionName + "]");
767744
}
768745
return failureHandler;
769746
}
770747

748+
/**
749+
* Calls the provided function for each configured extension and return the value that was generated by the extensions.
750+
* If multiple extensions provide a value, throws {@link IllegalStateException}.
751+
* If no extensions provide a value (or if there are no extensions) returns {@code null}.
752+
*/
753+
@Nullable
754+
private <T> T findValueFromExtensions(String valueType, Function<SecurityExtension, T> method) {
755+
T foundValue = null;
756+
String fromExtension = null;
757+
for (SecurityExtension extension : securityExtensions) {
758+
final T extensionValue = method.apply(extension);
759+
if (extensionValue == null) {
760+
continue;
761+
}
762+
if (foundValue == null) {
763+
foundValue = extensionValue;
764+
fromExtension = extension.extensionName();
765+
} else {
766+
throw new IllegalStateException(
767+
"Extensions ["
768+
+ fromExtension
769+
+ "] and ["
770+
+ extension.extensionName()
771+
+ "] "
772+
+ " both attempted to provide a value for ["
773+
+ valueType
774+
+ "]"
775+
);
776+
}
777+
}
778+
if (foundValue == null) {
779+
return null;
780+
} else {
781+
logger.debug("Using [{}] [{}] from extension [{}]", valueType, foundValue, fromExtension);
782+
return foundValue;
783+
}
784+
}
785+
771786
@Override
772787
public Settings additionalSettings() {
773788
return additionalSettings(settings, enabled, transportClientMode);

x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ public class ExampleSecurityExtension implements SecurityExtension {
4242
});
4343
}
4444

45+
@Override
46+
public String extensionName() {
47+
return "example";
48+
}
49+
4550
@Override
4651
public Map<String, Realm.Factory> getRealms(SecurityComponents components) {
4752
final Map<String, Realm.Factory> map = new HashMap<>();

0 commit comments

Comments
 (0)