Skip to content

Commit

Permalink
reduce duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
strehle committed Nov 7, 2024
1 parent ab1f363 commit a0967e7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefa
Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false)));
}

private Consumer<AssertionToken> createDefaultAssertionElementsDecrypter() {
public static Consumer<AssertionToken> createDefaultAssertionElementsDecrypter() {
return assertionToken -> {
Assertion assertion = assertionToken.getAssertion();
RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();
Expand All @@ -465,7 +465,7 @@ private Consumer<AssertionToken> createDefaultAssertionElementsDecrypter() {
};
}

private boolean hasName(Assertion assertion) {
public static boolean hasName(Assertion assertion) {
if (assertion == null) {
return false;
}
Expand All @@ -478,7 +478,7 @@ private boolean hasName(Assertion assertion) {
return assertion.getSubject().getNameID().getValue() != null;
}

private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
public static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
MultiValueMap<String, Object> attributeMap = new LinkedMultiValueMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {
Expand All @@ -495,15 +495,15 @@ private static Map<String, List<Object>> getAssertionAttributes(Assertion assert
return new LinkedHashMap<>(attributeMap); // gh-11785
}

private static List<String> getSessionIndexes(Assertion assertion) {
public static List<String> getSessionIndexes(Assertion assertion) {
List<String> sessionIndexes = new ArrayList<>();
for (AuthnStatement statement : assertion.getAuthnStatements()) {
sessionIndexes.add(statement.getSessionIndex());
}
return sessionIndexes;
}

private static Object getXmlObjectValue(XMLObject xmlObject) {
public static Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSAny xsAny) {
return xsAny.getTextContent();
}
Expand All @@ -526,7 +526,7 @@ private static Object getXmlObjectValue(XMLObject xmlObject) {
return xmlObject;
}

private static Saml2AuthenticationException createAuthenticationException(String code, String message,
public static Saml2AuthenticationException createAuthenticationException(String code, String message,
Exception cause) {
return new Saml2AuthenticationException(new Saml2Error(code, message), cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,10 @@
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.beans.IdentityZoneManager;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSBooleanValue;
import org.opensaml.core.xml.schema.XSDateTime;
import org.opensaml.core.xml.schema.XSInteger;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.core.xml.schema.XSURI;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.opensaml.saml.saml2.core.impl.AssertionUnmarshaller;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
Expand Down Expand Up @@ -84,10 +73,8 @@
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -122,7 +109,7 @@ public final class Saml2BearerGrantAuthenticationConverter implements Authentica

private final Converter<OpenSaml4AuthenticationProvider.AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = OpenSaml4AuthenticationProvider.createDefaultAssertionSignatureValidator();

private final Consumer<OpenSaml4AuthenticationProvider.AssertionToken> assertionElementsDecrypter = createDefaultAssertionElementsDecrypter();
private final Consumer<OpenSaml4AuthenticationProvider.AssertionToken> assertionElementsDecrypter = OpenSaml4AuthenticationProvider.createDefaultAssertionElementsDecrypter();

private final Converter<OpenSaml4AuthenticationProvider.AssertionToken, Saml2ResponseValidatorResult> assertionValidator = createDefaultAssertionValidator();

Expand Down Expand Up @@ -177,8 +164,8 @@ public static Converter<OpenSaml4AuthenticationProvider.AssertionToken, Saml2Res
Assertion assertion = assertionToken.getAssertion();
Saml2AuthenticationToken token = assertionToken.getToken();
String username = assertion.getSubject().getNameID().getValue();
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
List<String> sessionIndexes = getSessionIndexes(assertion);
Map<String, List<Object>> attributes = OpenSaml4AuthenticationProvider.getAssertionAttributes(assertion);
List<String> sessionIndexes = OpenSaml4AuthenticationProvider.getSessionIndexes(assertion);
DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal(username, attributes,
sessionIndexes);
String registrationId = token.getRelyingPartyRegistration().getRegistrationId();
Expand Down Expand Up @@ -314,7 +301,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
} catch (Saml2AuthenticationException ex) {
throw ex;
} catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex);
throw OpenSaml4AuthenticationProvider.createAuthenticationException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex);
}
}

Expand All @@ -325,7 +312,7 @@ private static Assertion parseAssertion(String assertion) throws Saml2Exception,
Element element = document.getDocumentElement();
return (Assertion) assertionUnmarshaller.unmarshall(element);
} catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.INVALID_ASSERTION, ex.getMessage(), ex);
throw OpenSaml4AuthenticationProvider.createAuthenticationException(Saml2ErrorCodes.INVALID_ASSERTION, ex.getMessage(), ex);
}
}

Expand All @@ -340,7 +327,7 @@ private void process(Saml2AuthenticationToken token, Assertion assertion) {
}
result = result.concat(this.assertionValidator.convert(assertionToken));

if (!hasName(assertion)) {
if (!OpenSaml4AuthenticationProvider.hasName(assertion)) {
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
"Assertion [" + assertion.getID() + "] is missing a subject");
result = result.concat(error);
Expand All @@ -354,87 +341,10 @@ private void process(Saml2AuthenticationToken token, Assertion assertion) {
log.debug("Found {} validation errors in SAML assertion [{}}]", errors.size(), assertion.getID());
}
Saml2Error first = errors.iterator().next();
throw createAuthenticationException(first.getErrorCode(), first.getDescription(), null);
throw OpenSaml4AuthenticationProvider.createAuthenticationException(first.getErrorCode(), first.getDescription(), null);
} else {
log.debug("Successfully processed SAML Assertion [{}]", assertion.getID());
}
}

private Consumer<OpenSaml4AuthenticationProvider.AssertionToken> createDefaultAssertionElementsDecrypter() {
return assertionToken -> {
Assertion assertion = assertionToken.getAssertion();
RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();
try {
OpenSamlDecryptionUtils.decryptAssertionElements(assertion, registration);
} catch (Exception ex) {
throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
}
};
}

private boolean hasName(Assertion assertion) {
if (assertion == null) {
return false;
}
if (assertion.getSubject() == null) {
return false;
}
if (assertion.getSubject().getNameID() == null) {
return false;
}
return assertion.getSubject().getNameID().getValue() != null;
}

private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
MultiValueMap<String, Object> attributeMap = new LinkedMultiValueMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {
List<Object> attributeValues = new ArrayList<>();
for (XMLObject xmlObject : attribute.getAttributeValues()) {
Object attributeValue = getXmlObjectValue(xmlObject);
if (attributeValue != null) {
attributeValues.add(attributeValue);
}
}
attributeMap.addAll(attribute.getName(), attributeValues);
}
}
return new LinkedHashMap<>(attributeMap); // gh-11785
}

private static List<String> getSessionIndexes(Assertion assertion) {
List<String> sessionIndexes = new ArrayList<>();
for (AuthnStatement statement : assertion.getAuthnStatements()) {
sessionIndexes.add(statement.getSessionIndex());
}
return sessionIndexes;
}

private static Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSAny xsAny) {
return xsAny.getTextContent();
}
if (xmlObject instanceof XSString xsstring) {
return xsstring.getValue();
}
if (xmlObject instanceof XSInteger xsInteger) {
return xsInteger.getValue();
}
if (xmlObject instanceof XSURI xsUri) {
return xsUri.getURI();
}
if (xmlObject instanceof XSBoolean xsBoolean) {
XSBooleanValue xsBooleanValue = xsBoolean.getValue();
return (xsBooleanValue != null) ? xsBooleanValue.getValue() : null;
}
if (xmlObject instanceof XSDateTime xsDateTime) {
return xsDateTime.getValue();
}
return xmlObject;
}

private static Saml2AuthenticationException createAuthenticationException(String code, String message,
Exception cause) {
return new Saml2AuthenticationException(new Saml2Error(code, message), cause);
}
}

0 comments on commit a0967e7

Please sign in to comment.