Skip to content

Commit

Permalink
Improve ScheduledTask task-name handling
Browse files Browse the repository at this point in the history
This PR introduces a String getTaskName() default method to
the ScheduledTask interface and adjusts call sites to use the
implementation derived task name where possible.

Previously, ScheduledTask names were passed around separately, which
lead to unhelpful debug messages.
We now give ScheduledTask implementations control over their task-name
which allows for more flexible naming.

Enlist call StoreSyncEvent.fire(...) to after transaction to ensure realm is present in database.
Ensure that Realm is already committed before updating sync via UserStorageSyncManager
Align Sync task name generation for cancellation to support SyncFederationTest
Only log a message if sync task was actually canceled.

Signed-off-by: Thomas Darimont <thomas.darimont@googlemail.com>
  • Loading branch information
thomasdarimont authored and pedroigor committed Feb 2, 2024
1 parent ff5a5fa commit 277af02
Show file tree
Hide file tree
Showing 13 changed files with 223 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.keycloak.common.util.Time;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.utils.SessionTimeoutHelper;
import org.keycloak.timer.ScheduledTask;
import org.keycloak.timer.TimerProvider;

/**
Expand All @@ -36,14 +37,22 @@ public abstract class AbstractLastSessionRefreshStoreFactory {
// Max count of lastSessionRefreshes. If count of lastSessionRefreshes reach this value, the message is sent to second DC
public static final int DEFAULT_MAX_COUNT = 100;



protected void setupPeriodicTimer(KeycloakSession kcSession, AbstractLastSessionRefreshStore store, long timerIntervalMs, String eventKey) {
TimerProvider timer = kcSession.getProvider(TimerProvider.class);
timer.scheduleTask((KeycloakSession keycloakSession) -> {
timer.scheduleTask(new PropagateLastSessionRefreshTask(store), timerIntervalMs, eventKey);
}

public static class PropagateLastSessionRefreshTask implements ScheduledTask {

private final AbstractLastSessionRefreshStore store;

store.checkSendingMessage(keycloakSession, Time.currentTime());
public PropagateLastSessionRefreshTask(AbstractLastSessionRefreshStore store) {
this.store = store;
}

}, timerIntervalMs, eventKey);
@Override
public void run(KeycloakSession session) {
store.checkSendingMessage(session, Time.currentTime());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ protected String detectDialect(Connection connection) {
protected void startGlobalStats(KeycloakSession session, int globalStatsIntervalSecs) {
logger.debugf("Started Hibernate statistics with the interval %s seconds", globalStatsIntervalSecs);
TimerProvider timer = session.getProvider(TimerProvider.class);
timer.scheduleTask(new HibernateStatsReporter(emf), globalStatsIntervalSecs * 1000, "ReportHibernateGlobalStats");
timer.scheduleTask(new HibernateStatsReporter(emf), globalStatsIntervalSecs * 1000);
}

void migration(MigrationStrategy strategy, boolean initializeEmpty, String schema, File databaseUpdateFile, Connection connection, KeycloakSession session) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@

import org.jboss.logging.Logger;
import org.keycloak.common.util.Time;
import org.keycloak.events.EventStoreProvider;
import org.keycloak.events.EventStoreProviderFactory;
import org.keycloak.models.KeycloakSession;
import org.keycloak.provider.InvalidationHandler;
import org.keycloak.storage.datastore.PeriodicEventInvalidation;
import org.keycloak.timer.ScheduledTask;

public class ClearExpiredAdminEvents implements ScheduledTask {

protected static final Logger logger = Logger.getLogger(ClearExpiredAdminEvents.class);
Expand All @@ -34,7 +32,7 @@ public void run(KeycloakSession session) {
long currentTimeMillis = Time.currentTimeMillis();
session.invalidate(PeriodicEventInvalidation.JPA_EVENT_STORE);
long took = Time.currentTimeMillis() - currentTimeMillis;
logger.debugf("ClearExpiredEvents finished in %d ms", took);
logger.debugf("%s finished in %d ms", getTaskName(), took);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ public void run(KeycloakSession session) {
session.sessions().removeAllExpired();

long took = Time.currentTimeMillis() - currentTimeMillis;
logger.debugf("ClearExpiredUserSessions finished in %d ms", took);
logger.debugf("%s finished in %d ms", getTaskName(), took);
}

@Override
public String getTaskName() {
return TASK_NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.keycloak.credential.CredentialInput;
import org.keycloak.credential.CredentialProvider;
import org.keycloak.credential.CredentialProviderFactory;
import org.keycloak.models.AbstractKeycloakTransaction;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientScopeModel;
import org.keycloak.models.CredentialValidationOutput;
Expand Down Expand Up @@ -789,7 +790,7 @@ public void preRemove(RealmModel realm, ComponentModel component) {
if (!component.getProviderType().equals(UserStorageProvider.class.getName())) return;
localStorage().preRemove(realm, component);
if (getFederatedStorage() != null) getFederatedStorage().preRemove(realm, component);
new UserStorageSyncManager().notifyToRefreshPeriodicSync(session, realm, new UserStorageProviderModel(component), true);
UserStorageSyncManager.notifyToRefreshPeriodicSync(session, realm, new UserStorageProviderModel(component), true);

}

Expand All @@ -813,8 +814,19 @@ public void close() {
public void onCreate(KeycloakSession session, RealmModel realm, ComponentModel model) {
ComponentFactory factory = ComponentUtil.getComponentFactory(session, model);
if (!(factory instanceof UserStorageProviderFactory)) return;
new UserStorageSyncManager().notifyToRefreshPeriodicSync(session, realm, new UserStorageProviderModel(model), false);

// enlistAfterCompletion(..) as we need to ensure that the realm is available in the system
session.getTransactionManager().enlistAfterCompletion(new AbstractKeycloakTransaction() {
@Override
protected void commitImpl() {
UserStorageSyncManager.notifyToRefreshPeriodicSync(session, realm, new UserStorageProviderModel(model), false);
}

@Override
protected void rollbackImpl() {
// NOOP
}
});
}

@Override
Expand All @@ -825,7 +837,7 @@ public void onUpdate(KeycloakSession session, RealmModel realm, ComponentModel o
UserStorageProviderModel newP= new UserStorageProviderModel(newModel);
if (old.getChangedSyncPeriod() != newP.getChangedSyncPeriod() || old.getFullSyncPeriod() != newP.getFullSyncPeriod()
|| old.isImportEnabled() != newP.isImportEnabled()) {
new UserStorageSyncManager().notifyToRefreshPeriodicSync(session, realm, new UserStorageProviderModel(newModel), false);
UserStorageSyncManager.notifyToRefreshPeriodicSync(session, realm, new UserStorageProviderModel(newModel), false);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.keycloak.storage.datastore;

import org.jboss.logging.Logger;
import org.keycloak.Config;
import org.keycloak.Config.Scope;
import org.keycloak.migration.MigrationModelManager;
Expand All @@ -35,11 +36,17 @@
import org.keycloak.storage.StoreMigrateRepresentationEvent;
import org.keycloak.storage.StoreSyncEvent;
import org.keycloak.storage.managers.UserStorageSyncManager;
import org.keycloak.timer.ScheduledTask;
import org.keycloak.timer.TimerProvider;
import java.util.Arrays;
import java.util.List;

public class DefaultDatastoreProviderFactory implements DatastoreProviderFactory, ProviderEventListener {

private static final String PROVIDER_ID = "legacy";

private static final Logger logger = Logger.getLogger(DefaultDatastoreProviderFactory.class);

private long clientStorageProviderTimeout;
private long roleStorageProviderTimeout;
private Runnable onClose;
Expand Down Expand Up @@ -94,18 +101,32 @@ public void onEvent(ProviderEvent event) {
}
}

public static void setupScheduledTasks(final KeycloakSessionFactory sessionFactory) {
public void setupScheduledTasks(final KeycloakSessionFactory sessionFactory) {
long interval = Config.scope("scheduled").getLong("interval", 900L) * 1000;

try (KeycloakSession session = sessionFactory.create()) {
TimerProvider timer = session.getProvider(TimerProvider.class);
if (timer != null) {
timer.schedule(new ClusterAwareScheduledTaskRunner(sessionFactory, new ClearExpiredEvents(), interval), interval, "ClearExpiredEvents");
timer.schedule(new ClusterAwareScheduledTaskRunner(sessionFactory, new ClearExpiredAdminEvents(), interval), interval, "ClearExpiredAdminEvents");
timer.schedule(new ClusterAwareScheduledTaskRunner(sessionFactory, new ClearExpiredClientInitialAccessTokens(), interval), interval, "ClearExpiredClientInitialAccessTokens");
timer.schedule(new ClusterAwareScheduledTaskRunner(sessionFactory, new ClearExpiredUserSessions(), interval), interval, ClearExpiredUserSessions.TASK_NAME);
UserStorageSyncManager.bootstrapPeriodic(sessionFactory, timer);
scheduleTasks(sessionFactory, timer, interval);
}
}
}

protected void scheduleTasks(KeycloakSessionFactory sessionFactory, TimerProvider timer, long interval) {
for (ScheduledTask task : getScheduledTasks()) {
scheduleTask(timer, sessionFactory, task, interval);
}

UserStorageSyncManager.bootstrapPeriodic(sessionFactory, timer);
}

protected List<ScheduledTask> getScheduledTasks() {
return Arrays.asList(new ClearExpiredEvents(), new ClearExpiredAdminEvents(), new ClearExpiredClientInitialAccessTokens(), new ClearExpiredUserSessions());
}

protected void scheduleTask(TimerProvider timer, KeycloakSessionFactory sessionFactory, ScheduledTask task, long interval) {
timer.schedule(new ClusterAwareScheduledTaskRunner(sessionFactory, task, interval), interval);
logger.debugf("Scheduled cluster task %s with interval %s ms", task.getTaskName(), interval);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.keycloak.storage.user.ImportSynchronization;
import org.keycloak.storage.user.SynchronizationResult;
import org.keycloak.timer.TimerProvider;
import org.keycloak.timer.TimerProvider.TimerTaskContext;

import java.util.Objects;
import java.util.concurrent.Callable;
Expand Down Expand Up @@ -66,7 +67,7 @@ public void run(KeycloakSession session) {
providers.forEachOrdered(provider -> {
UserStorageProviderFactory factory = (UserStorageProviderFactory) session.getKeycloakSessionFactory().getProviderFactory(UserStorageProvider.class, provider.getProviderId());
if (factory instanceof ImportSynchronization && provider.isImportEnabled()) {
refreshPeriodicSyncForProvider(sessionFactory, timer, provider, realm.getId());
refreshPeriodicSyncForProvider(sessionFactory, timer, provider, realm);
}
});
});
Expand Down Expand Up @@ -198,56 +199,81 @@ public static void notifyToRefreshPeriodicSync(KeycloakSession session, RealmMod


// Executed once it receives notification that some UserFederationProvider was created or updated
protected static void refreshPeriodicSyncForProvider(final KeycloakSessionFactory sessionFactory, TimerProvider timer, final UserStorageProviderModel provider, final String realmId) {
logger.debugf("Going to refresh periodic sync for provider '%s' . Full sync period: %d , changed users sync period: %d",
provider.getName(), provider.getFullSyncPeriod(), provider.getChangedSyncPeriod());
protected static void refreshPeriodicSyncForProvider(final KeycloakSessionFactory sessionFactory, TimerProvider timer, final UserStorageProviderModel provider, final RealmModel realm) {
logger.debugf("Going to refresh periodic sync settings for provider '%s' in realm '%s' with realmId '%s'. Full sync period: %d , changed users sync period: %d",
provider.getName(), realm.getName(), realm.getId(), provider.getFullSyncPeriod(), provider.getChangedSyncPeriod());

String fullSyncTaskName = createSyncTaskName(provider, UserStorageSyncTask.SyncMode.FULL);
if (provider.getFullSyncPeriod() > 0) {
// We want periodic full sync for this provider
timer.schedule(new Runnable() {

@Override
public void run() {
try {
boolean shouldPerformSync = shouldPerformNewPeriodicSync(provider.getLastSync(), provider.getChangedSyncPeriod());
if (shouldPerformSync) {
syncAllUsers(sessionFactory, realmId, provider);
} else {
logger.debugf("Ignored periodic full sync with storage provider %s due small time since last sync", provider.getName());
}
} catch (Throwable t) {
logger.error("Error occurred during full sync of users", t);
}
}

}, provider.getFullSyncPeriod() * 1000, provider.getId() + "-FULL");
// schedule periodic full sync for this provider
UserStorageSyncTask task = new UserStorageSyncTask(provider, realm, sessionFactory, UserStorageSyncTask.SyncMode.FULL);
timer.schedule(task, provider.getFullSyncPeriod() * 1000, fullSyncTaskName);
} else {
timer.cancelTask(provider.getId() + "-FULL");
// cancel potentially dangling task
timer.cancelTask(fullSyncTaskName);
}

String changedSyncTaskName = createSyncTaskName(provider, UserStorageSyncTask.SyncMode.CHANGED);
if (provider.getChangedSyncPeriod() > 0) {
// We want periodic sync of just changed users for this provider
timer.schedule(new Runnable() {
// schedule periodic changed user sync for this provider
UserStorageSyncTask task = new UserStorageSyncTask(provider, realm, sessionFactory, UserStorageSyncTask.SyncMode.CHANGED);
timer.schedule(task, provider.getChangedSyncPeriod() * 1000, changedSyncTaskName);
} else {
// cancel potentially dangling task
timer.cancelTask(changedSyncTaskName);
}
}

@Override
public void run() {
try {
boolean shouldPerformSync = shouldPerformNewPeriodicSync(provider.getLastSync(), provider.getChangedSyncPeriod());
if (shouldPerformSync) {
syncChangedUsers(sessionFactory, realmId, provider);
} else {
logger.debugf("Ignored periodic changed-users sync with storage provider %s due small time since last sync", provider.getName());
}
} catch (Throwable t) {
logger.error("Error occurred during sync of changed users", t);
}
}
public static class UserStorageSyncTask implements Runnable {

}, provider.getChangedSyncPeriod() * 1000, provider.getId() + "-CHANGED");
private final UserStorageProviderModel provider;

} else {
timer.cancelTask(provider.getId() + "-CHANGED");
private final RealmModel realm;

private final KeycloakSessionFactory sessionFactory;

private final SyncMode syncMode;

public static enum SyncMode {
FULL, CHANGED
}

public UserStorageSyncTask(UserStorageProviderModel provider, RealmModel realm, KeycloakSessionFactory sessionFactory, SyncMode syncMode) {
this.provider = provider;
this.realm = realm;
this.sessionFactory = sessionFactory;
this.syncMode = syncMode;
}

@Override
public void run() {

try {
boolean shouldPerformSync = shouldPerformNewPeriodicSync(provider.getLastSync(), provider.getChangedSyncPeriod());

if (!shouldPerformSync) {
logger.debugf("Ignored periodic %s users-sync with storage provider %s due small time since last sync in realm %s", //
syncMode, provider.getName(), realm.getName());
return;
}

switch (syncMode) {
case FULL:
syncAllUsers(sessionFactory, realm.getId(), provider);
break;
case CHANGED:
syncChangedUsers(sessionFactory, realm.getId(), provider);
break;
}
} catch (Throwable t) {
logger.errorf(t,"Error occurred during %s users-sync in realm %s", //
syncMode, realm.getName());
}
}
}

public static String createSyncTaskName(UserStorageProviderModel model, UserStorageSyncTask.SyncMode syncMode) {
return UserStorageSyncTask.class.getSimpleName() + "-" + model.getId() + "-" + syncMode;
}

// Skip syncing if there is short time since last sync time.
Expand All @@ -264,9 +290,17 @@ private static boolean shouldPerformNewPeriodicSync(int lastSyncTime, int period

// Executed once it receives notification that some UserFederationProvider was removed
protected static void removePeriodicSyncForProvider(TimerProvider timer, UserStorageProviderModel fedProvider) {
logger.debugf("Removing periodic sync for provider %s", fedProvider.getName());
timer.cancelTask(fedProvider.getId() + "-FULL");
timer.cancelTask(fedProvider.getId() + "-CHANGED");
cancelPeriodicSyncForProviderIfPresent(timer, fedProvider, UserStorageSyncTask.SyncMode.FULL);
cancelPeriodicSyncForProviderIfPresent(timer, fedProvider, UserStorageSyncTask.SyncMode.CHANGED);
}

protected static void cancelPeriodicSyncForProviderIfPresent(TimerProvider timer, UserStorageProviderModel providerModel, UserStorageSyncTask.SyncMode syncMode) {
String taskName = createSyncTaskName(providerModel, syncMode);
TimerTaskContext existingTask = timer.cancelTask(taskName);
if (existingTask != null) {
logger.debugf("Cancelled periodic sync task with task-name '%s' for provider with id '%s' and name '%s'",
taskName, providerModel.getId(), providerModel.getName());
}
}

// Update interval of last sync for given UserFederationProviderModel. Do it in separate transaction
Expand Down Expand Up @@ -310,7 +344,8 @@ public void run(KeycloakSession session) {
if (fedEvent.isRemoved()) {
removePeriodicSyncForProvider(timer, fedEvent.getStorageProvider());
} else {
refreshPeriodicSyncForProvider(sessionFactory, timer, fedEvent.getStorageProvider(), fedEvent.getRealmId());
RealmModel realm = session.realms().getRealm(fedEvent.getRealmId());
refreshPeriodicSyncForProvider(sessionFactory, timer, fedEvent.getStorageProvider(), realm);
}
}

Expand Down
Loading

0 comments on commit 277af02

Please sign in to comment.