Skip to content

Commit

Permalink
KEYCLOAK-18146 Search for clients by client attribute when doing saml…
Browse files Browse the repository at this point in the history
… artifact resolution
  • Loading branch information
mhajas authored and hmlnarik committed May 27, 2021
1 parent 2cb59e2 commit 4dcb695
Show file tree
Hide file tree
Showing 26 changed files with 364 additions and 52 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2020 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.connections.jpa.updater.liquibase.custom;

import liquibase.exception.CustomChangeException;
import liquibase.statement.core.InsertStatement;
import liquibase.structure.core.Table;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.HashMap;
import java.util.Map;

import static org.keycloak.protocol.saml.util.ArtifactBindingUtils.computeArtifactBindingIdentifierString;

public class JpaUpdate14_0_0_MigrateSamlArtifactAttribute extends CustomKeycloakTask {

private static final String SAML_ARTIFACT_BINDING_IDENTIFIER = "saml.artifact.binding.identifier";

private final Map<String, String> clientIds = new HashMap<>();

@Override
protected void generateStatementsImpl() throws CustomChangeException {
extractClientsData("SELECT C.ID, C.CLIENT_ID FROM " + getTableName("CLIENT") + " C " +
"LEFT JOIN " + getTableName("CLIENT_ATTRIBUTES") + " CA " +
"ON C.ID = CA.CLIENT_ID AND CA.NAME='" + SAML_ARTIFACT_BINDING_IDENTIFIER + "' " +
"WHERE C.PROTOCOL='saml' AND CA.NAME IS NULL");

for (Map.Entry<String, String> clientPair : clientIds.entrySet()) {
String id = clientPair.getKey();

String clientId = clientPair.getValue();
String samlIdentifier = computeArtifactBindingIdentifierString(clientId);

statements.add(
new InsertStatement(null, null, database.correctObjectName("CLIENT_ATTRIBUTES", Table.class))
.addColumnValue("CLIENT_ID", id)
.addColumnValue("NAME", SAML_ARTIFACT_BINDING_IDENTIFIER)
.addColumnValue("VALUE", samlIdentifier)
);
}
}

private void extractClientsData(String sql) throws CustomChangeException {
try (PreparedStatement statement = jdbcConnection.prepareStatement(sql);
ResultSet rs = statement.executeQuery()) {

while (rs.next()) {
String id = rs.getString(1);
String clientId = rs.getString(2);

if (id == null || id.trim().isEmpty()
|| clientId == null || clientId.trim().isEmpty()) {
continue;
}

clientIds.put(id, clientId);
}

} catch (Exception e) {
throw new CustomChangeException(getTaskId() + ": Exception when extracting data from previous version", e);
}
}

@Override
protected String getTaskId() {
return "Migrate Saml attributes (14.0.0)";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.keycloak.models.ClientProviderFactory;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.protocol.saml.SamlConfigAttributes;

import javax.persistence.EntityManager;
import java.util.Arrays;
Expand All @@ -38,7 +39,8 @@ public class JpaClientProviderFactory implements ClientProviderFactory {
private Set<String> clientSearchableAttributes = null;

private static final List<String> REQUIRED_SEARCHABLE_ATTRIBUTES = Arrays.asList(
"saml_idp_initiated_sso_url_name"
"saml_idp_initiated_sso_url_name",
SamlConfigAttributes.SAML_ARTIFACT_BINDING_IDENTIFIER
);

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,8 @@
</createIndex>
</changeSet>

<changeSet author="keycloak" id="KEYCLOAK-18146-add-saml-art-binding-identifier">
<customChange class="org.keycloak.connections.jpa.updater.liquibase.custom.JpaUpdate14_0_0_MigrateSamlArtifactAttribute"/>
</changeSet>

</databaseChangeLog>
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@
import org.keycloak.util.JsonSerialization;
import org.keycloak.validation.ValidationUtil;

import static org.keycloak.protocol.saml.util.ArtifactBindingUtils.computeArtifactBindingIdentifierString;

public class RepresentationToModel {

private static Logger logger = Logger.getLogger(RepresentationToModel.class);
Expand Down Expand Up @@ -1407,6 +1409,11 @@ private static ClientModel createClient(KeycloakSession session, RealmModel real
}
}

if ("saml".equals(resourceRep.getProtocol())
&& (resourceRep.getAttributes() == null
|| !resourceRep.getAttributes().containsKey("saml.artifact.binding.identifier"))) {
client.setAttribute("saml.artifact.binding.identifier", computeArtifactBindingIdentifierString(resourceRep.getClientId()));
}

if (resourceRep.getAuthenticationFlowBindingOverrides() != null) {
for (Map.Entry<String, String> entry : resourceRep.getAuthenticationFlowBindingOverrides().entrySet()) {
Expand Down Expand Up @@ -1559,6 +1566,12 @@ public static void updateClient(ClientRepresentation rep, ClientModel resource)
}
}

if ("saml".equals(rep.getProtocol())
&& (rep.getAttributes() == null
|| !rep.getAttributes().containsKey("saml.artifact.binding.identifier"))) {
resource.setAttribute("saml.artifact.binding.identifier", computeArtifactBindingIdentifierString(rep.getClientId()));
}

if (rep.getAuthenticationFlowBindingOverrides() != null) {
for (Map.Entry<String, String> entry : rep.getAuthenticationFlowBindingOverrides().entrySet()) {
if (entry.getValue() == null || entry.getValue().trim().equals("")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.provider.Provider;

import java.util.stream.Stream;


/**
* Provides a way to create and resolve artifacts for SAML Artifact binding
Expand All @@ -15,12 +14,12 @@ public interface ArtifactResolver extends Provider {
/**
* Returns client model that issued artifact
*
* @param session KeycloakSession for searching for client corresponding client
* @param artifact the artifact
* @param clients stream of clients, the stream will be searched for a client that issued the artifact
* @return the client model that issued the artifact
* @throws ArtifactResolverProcessingException When an error occurs during client search
*/
ClientModel selectSourceClient(String artifact, Stream<ClientModel> clients) throws ArtifactResolverProcessingException;
ClientModel selectSourceClient(KeycloakSession session, String artifact) throws ArtifactResolverProcessingException;

/**
* Creates and stores an artifact
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package org.keycloak.protocol.saml.util;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

public class ArtifactBindingUtils {
public static String artifactToResolverProviderId(String artifact) {
return byteArrayToResolverProviderId(Base64.getDecoder().decode(artifact));
}

public static String byteArrayToResolverProviderId(byte[] ar) {
return String.format("%02X%02X", ar[0], ar[1]);
}

/**
* Computes identifier from the given String, for example, from entityId
*
* @param identifierFrom String that will be turned into an identifier
* @return Base64 of SHA-1 hash of the identifierFrom
*/
public static String computeArtifactBindingIdentifierString(String identifierFrom) {
return Base64.getEncoder().encodeToString(computeArtifactBindingIdentifier(identifierFrom));
}

/**
* Turns byte representation of the identifier into readable String
*
* @param identifier byte representation of the identifier
* @return Base64 of the identifier
*/
public static String getArtifactBindingIdentifierString(byte[] identifier) {
return Base64.getEncoder().encodeToString(identifier);
}

/**
* Computes 20 bytes long byte identifier of the given string, for example, from entityId
*
* @param identifierFrom String that will be turned into an identifier
* @return SHA-1 hash of the given identifierFrom
*/
public static byte[] computeArtifactBindingIdentifier(String identifierFrom) {
try {
MessageDigest sha1Digester = MessageDigest.getInstance("SHA-1");
return sha1Digester.digest(identifierFrom.getBytes(StandardCharsets.UTF_8));
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("JVM does not support required cryptography algorithms: SHA-1/SHA1PRNG.", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
package org.keycloak.protocol.saml;

import com.google.common.base.Charsets;
import com.google.common.base.Strings;
import org.jboss.logging.Logger;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.protocol.saml.util.ArtifactBindingUtils;
import org.keycloak.saml.common.constants.GeneralConstants;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.stream.Stream;
import java.util.Collections;

import static org.keycloak.protocol.saml.DefaultSamlArtifactResolverFactory.TYPE_CODE;
import static org.keycloak.protocol.saml.SamlConfigAttributes.SAML_ARTIFACT_BINDING_IDENTIFIER;

/**
* ArtifactResolver for artifact-04 format.
Expand All @@ -43,18 +43,13 @@ public String resolveArtifact(AuthenticatedClientSessionModel clientSessionModel
}

@Override
public ClientModel selectSourceClient(String artifact, Stream<ClientModel> clients) throws ArtifactResolverProcessingException {
try {
byte[] source = extractSourceFromArtifact(artifact);
public ClientModel selectSourceClient(KeycloakSession session, String artifact) throws ArtifactResolverProcessingException {
byte[] source = extractSourceFromArtifact(artifact);
String identifier = ArtifactBindingUtils.getArtifactBindingIdentifierString(source);

MessageDigest sha1Digester = MessageDigest.getInstance("SHA-1");
return clients.filter(clientModel -> Arrays.equals(source,
sha1Digester.digest(clientModel.getClientId().getBytes(Charsets.UTF_8))))
.findFirst()
.orElseThrow(() -> new ArtifactResolverProcessingException("No client matching the artifact source found"));
} catch (NoSuchAlgorithmException e) {
throw new ArtifactResolverProcessingException(e);
}
return session.clients().searchClientsByAttributes(session.getContext().getRealm(),
Collections.singletonMap(SAML_ARTIFACT_BINDING_IDENTIFIER, identifier), 0, 1)
.findFirst().orElseThrow(() -> new ArtifactResolverProcessingException("No client matching the artifact source found"));
}

@Override
Expand Down Expand Up @@ -109,8 +104,7 @@ public String createArtifact(String entityId) throws ArtifactResolverProcessingE
SecureRandom handleGenerator = SecureRandom.getInstance("SHA1PRNG");
byte[] trimmedIndex = new byte[2];

MessageDigest sha1Digester = MessageDigest.getInstance("SHA-1");
byte[] source = sha1Digester.digest(entityId.getBytes(Charsets.UTF_8));
byte[] source = ArtifactBindingUtils.computeArtifactBindingIdentifier(entityId);

byte[] assertionHandle = new byte[20];
handleGenerator.nextBytes(assertionHandle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.jboss.logging.Logger;
import org.keycloak.models.ClientConfigResolver;
import org.keycloak.models.ClientModel;
import org.keycloak.protocol.saml.util.ArtifactBindingUtils;
import org.keycloak.saml.SignatureAlgorithm;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.util.XmlKeyInfoKeyNameTransformer;
Expand Down Expand Up @@ -258,4 +259,12 @@ public int getAssertionLifespan() {
return -1;
}
}

public void setArtifactBindingIdentifierFrom(String identifierFrom) {
client.setAttribute(SamlConfigAttributes.SAML_ARTIFACT_BINDING_IDENTIFIER, ArtifactBindingUtils.computeArtifactBindingIdentifierString(identifierFrom));
}

public String getArtifactBindingIdentifier() {
return client.getAttribute(SamlConfigAttributes.SAML_ARTIFACT_BINDING_IDENTIFIER);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ public interface SamlConfigAttributes {
String SAML_ENCRYPTION_CERTIFICATE_ATTRIBUTE = "saml.encryption." + CertificateInfoHelper.X509CERTIFICATE;
String SAML_ENCRYPTION_PRIVATE_KEY_ATTRIBUTE = "saml.encryption." + CertificateInfoHelper.PRIVATE_KEY;
String SAML_ASSERTION_LIFESPAN = "saml.assertion.lifespan";
String SAML_ARTIFACT_BINDING_IDENTIFIER = "saml.artifact.binding.identifier";
}
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ public void setupClientDefaults(ClientRepresentation clientRep, ClientModel newC
if (clientRep.isFrontchannelLogout() == null) {
newClient.setFrontchannelLogout(true);
}

client.setArtifactBindingIdentifierFrom(clientRep.getClientId());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
import org.keycloak.protocol.saml.preprocessor.SamlAuthenticationPreprocessor;
import org.keycloak.protocol.saml.profile.ecp.SamlEcpProfileService;
import org.keycloak.protocol.saml.profile.util.Soap;
import org.keycloak.protocol.util.ArtifactBindingUtils;
import org.keycloak.protocol.saml.util.ArtifactBindingUtils;
import org.keycloak.rotation.HardcodedKeyLocator;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.BaseSAML2BindingBuilder;
Expand Down Expand Up @@ -341,7 +341,7 @@ protected void handleArtifact(AsyncResponse asyncResponse, String artifact, Stri
//Find client
ClientModel client;
try {
client = getArtifactResolver(artifact).selectSourceClient(artifact, realm.getClientsStream());
client = getArtifactResolver(artifact).selectSourceClient(session, artifact);

Response error = checkClientValidity(client);
if (error != null) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
import org.keycloak.protocol.LoginProtocolFactory;
import org.keycloak.protocol.oidc.OIDCLoginProtocol;
import org.keycloak.protocol.oidc.mappers.UserSessionNoteMapper;
import org.keycloak.protocol.saml.SamlClient;
import org.keycloak.protocol.saml.SamlConfigAttributes;
import org.keycloak.protocol.saml.SamlProtocol;
import org.keycloak.representations.adapters.config.BaseRealmConfig;
import org.keycloak.representations.adapters.config.PolicyEnforcerConfig;
import org.keycloak.representations.idm.ClientRepresentation;
Expand Down Expand Up @@ -191,14 +194,22 @@ public void enableServiceAccount(ClientModel client) {
}
}

public void clientIdChanged(ClientModel client, String newClientId) {
public void clientIdChanged(ClientModel client, ClientRepresentation newClientRepresentation) {
String newClientId = newClientRepresentation.getClientId();
logger.debugf("Updating clientId from '%s' to '%s'", client.getClientId(), newClientId);

UserModel serviceAccountUser = realmManager.getSession().users().getServiceAccount(client);
if (serviceAccountUser != null) {
String username = ServiceAccountConstants.SERVICE_ACCOUNT_USER_PREFIX + newClientId;
serviceAccountUser.setUsername(username);
}

if (SamlProtocol.LOGIN_PROTOCOL.equals(client.getProtocol())) {
SamlClient samlClient = new SamlClient(client);
samlClient.setArtifactBindingIdentifierFrom(newClientId);

newClientRepresentation.getAttributes().put(SamlConfigAttributes.SAML_ARTIFACT_BINDING_IDENTIFIER, samlClient.getArtifactBindingIdentifier());
}
}

@JsonPropertyOrder({"realm", "realm-public-key", "bearer-only", "auth-server-url", "ssl-required",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ private void updateClientFromRep(ClientRepresentation rep, ClientModel client, K
}

if (rep.getClientId() != null && !rep.getClientId().equals(client.getClientId())) {
new ClientManager(new RealmManager(session)).clientIdChanged(client, rep.getClientId());
new ClientManager(new RealmManager(session)).clientIdChanged(client, rep);
}

if (rep.isFullScopeAllowed() != null && rep.isFullScopeAllowed() != client.isFullScopeAllowed()) {
Expand Down
Loading

0 comments on commit 4dcb695

Please sign in to comment.