|
43 | 43 | import org.elasticsearch.common.util.concurrent.EsExecutors;
|
44 | 44 | import org.elasticsearch.common.util.concurrent.ThreadContext;
|
45 | 45 | import org.elasticsearch.common.util.set.Sets;
|
| 46 | +import org.elasticsearch.core.Nullable; |
46 | 47 | import org.elasticsearch.xcontent.NamedXContentRegistry;
|
47 | 48 | import org.elasticsearch.xcontent.XContentBuilder;
|
48 | 49 | import org.elasticsearch.env.Environment;
|
@@ -597,7 +598,7 @@ Collection<Object> createComponents(Client client, ThreadPool threadPool, Cluste
|
597 | 598 | extensionComponents
|
598 | 599 | );
|
599 | 600 | if (providers != null && providers.isEmpty() == false) {
|
600 |
| - customRoleProviders.put(extension.toString(), providers); |
| 601 | + customRoleProviders.put(extension.extensionName(), providers); |
601 | 602 | }
|
602 | 603 | }
|
603 | 604 |
|
@@ -695,37 +696,15 @@ auditTrailService, failureHandler, threadPool, anonymousUser, getAuthorizationEn
|
695 | 696 | }
|
696 | 697 |
|
697 | 698 | private AuthorizationEngine getAuthorizationEngine() {
|
698 |
| - AuthorizationEngine authorizationEngine = null; |
699 |
| - String extensionName = null; |
700 |
| - for (SecurityExtension extension : securityExtensions) { |
701 |
| - final AuthorizationEngine extensionEngine = extension.getAuthorizationEngine(settings); |
702 |
| - if (extensionEngine != null && authorizationEngine != null) { |
703 |
| - throw new IllegalStateException("Extensions [" + extensionName + "] and [" + extension.toString() + "] " |
704 |
| - + "both set an authorization engine"); |
705 |
| - } |
706 |
| - authorizationEngine = extensionEngine; |
707 |
| - extensionName = extension.toString(); |
708 |
| - } |
709 |
| - |
710 |
| - if (authorizationEngine != null) { |
711 |
| - logger.debug("Using authorization engine from extension [" + extensionName + "]"); |
712 |
| - } |
713 |
| - return authorizationEngine; |
| 699 | + return findValueFromExtensions("authorization engine", extension -> extension.getAuthorizationEngine(settings)); |
714 | 700 | }
|
715 | 701 |
|
716 | 702 | private AuthenticationFailureHandler createAuthenticationFailureHandler(final Realms realms,
|
717 | 703 | final SecurityExtension.SecurityComponents components) {
|
718 |
| - AuthenticationFailureHandler failureHandler = null; |
719 |
| - String extensionName = null; |
720 |
| - for (SecurityExtension extension : securityExtensions) { |
721 |
| - AuthenticationFailureHandler extensionFailureHandler = extension.getAuthenticationFailureHandler(components); |
722 |
| - if (extensionFailureHandler != null && failureHandler != null) { |
723 |
| - throw new IllegalStateException("Extensions [" + extensionName + "] and [" + extension.toString() + "] " |
724 |
| - + "both set an authentication failure handler"); |
725 |
| - } |
726 |
| - failureHandler = extensionFailureHandler; |
727 |
| - extensionName = extension.toString(); |
728 |
| - } |
| 704 | + AuthenticationFailureHandler failureHandler = findValueFromExtensions( |
| 705 | + "authentication failure handler", |
| 706 | + extension -> extension.getAuthenticationFailureHandler(components) |
| 707 | + ); |
729 | 708 | if (failureHandler == null) {
|
730 | 709 | logger.debug("Using default authentication failure handler");
|
731 | 710 | Supplier<Map<String, List<String>>> headersSupplier = () -> {
|
@@ -762,12 +741,48 @@ private AuthenticationFailureHandler createAuthenticationFailureHandler(final Re
|
762 | 741 | getLicenseState().addListener(() -> {
|
763 | 742 | finalDefaultFailureHandler.setHeaders(headersSupplier.get());
|
764 | 743 | });
|
765 |
| - } else { |
766 |
| - logger.debug("Using authentication failure handler from extension [" + extensionName + "]"); |
767 | 744 | }
|
768 | 745 | return failureHandler;
|
769 | 746 | }
|
770 | 747 |
|
| 748 | + /** |
| 749 | + * Calls the provided function for each configured extension and return the value that was generated by the extensions. |
| 750 | + * If multiple extensions provide a value, throws {@link IllegalStateException}. |
| 751 | + * If no extensions provide a value (or if there are no extensions) returns {@code null}. |
| 752 | + */ |
| 753 | + @Nullable |
| 754 | + private <T> T findValueFromExtensions(String valueType, Function<SecurityExtension, T> method) { |
| 755 | + T foundValue = null; |
| 756 | + String fromExtension = null; |
| 757 | + for (SecurityExtension extension : securityExtensions) { |
| 758 | + final T extensionValue = method.apply(extension); |
| 759 | + if (extensionValue == null) { |
| 760 | + continue; |
| 761 | + } |
| 762 | + if (foundValue == null) { |
| 763 | + foundValue = extensionValue; |
| 764 | + fromExtension = extension.extensionName(); |
| 765 | + } else { |
| 766 | + throw new IllegalStateException( |
| 767 | + "Extensions [" |
| 768 | + + fromExtension |
| 769 | + + "] and [" |
| 770 | + + extension.extensionName() |
| 771 | + + "] " |
| 772 | + + " both attempted to provide a value for [" |
| 773 | + + valueType |
| 774 | + + "]" |
| 775 | + ); |
| 776 | + } |
| 777 | + } |
| 778 | + if (foundValue == null) { |
| 779 | + return null; |
| 780 | + } else { |
| 781 | + logger.debug("Using [{}] [{}] from extension [{}]", valueType, foundValue, fromExtension); |
| 782 | + return foundValue; |
| 783 | + } |
| 784 | + } |
| 785 | + |
771 | 786 | @Override
|
772 | 787 | public Settings additionalSettings() {
|
773 | 788 | return additionalSettings(settings, enabled, transportClientMode);
|
|
0 commit comments