Skip to content

Commit

Permalink
#29281: Applying feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
victoralfaro-dotcms committed Aug 14, 2024
1 parent 69eed78 commit 3b8be74
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 60 deletions.
5 changes: 4 additions & 1 deletion dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,14 @@ public long minIntervalBetweenCalls() {
@Override
public String toString() {
return "AIModel{" +
"name='" + names + '\'' +
"type=" + type +
", names=" + names +
", tokensPerMinute=" + tokensPerMinute +
", apiPerMinute=" + apiPerMinute +
", maxTokens=" + maxTokens +
", isCompletion=" + isCompletion +
", current=" + current +
", decommissioned=" + decommissioned +
'}';
}

Expand Down
29 changes: 18 additions & 11 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -49,7 +50,7 @@ public class AIModels {

private final ConcurrentMap<String, List<Tuple2<AIModelType, AIModel>>> internalModels = new ConcurrentHashMap<>();
private final ConcurrentMap<Tuple2<String, String>, AIModel> modelsByName = new ConcurrentHashMap<>();
private final Cache<String, List<String>> supportedModelsCache =
private final Cache<String, Set<String>> supportedModelsCache =
Caffeine.newBuilder()
.expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL))
.maximumSize(AI_MODELS_CACHE_SIZE)
Expand Down Expand Up @@ -107,7 +108,11 @@ public void loadModels(final String host, final List<AIModel> loading) {
* @return an Optional containing the found AIModel, or an empty Optional if not found
*/
public Optional<AIModel> findModel(final String host, final String modelName) {
return Optional.ofNullable(modelsByName.get(Tuple.of(host, modelName.toLowerCase())));
final String lowered = modelName.toLowerCase();
final Set<String> supported = getOrPullSupportedModels();
return supported.contains(lowered)
? Optional.ofNullable(modelsByName.get(Tuple.of(host, lowered)))
: Optional.empty();
}

/**
Expand Down Expand Up @@ -146,10 +151,10 @@ public void resetModels(final String host) {
* Retrieves the list of supported models, either from the cache or by fetching them
* from an external source if the cache is empty or expired.
*
* @return a list of supported model names
* @return a set of supported model names
*/
public List<String> getOrPullSupportedModels() {
final List<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
public Set<String> getOrPullSupportedModels() {
final Set<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
if (CollectionUtils.isNotEmpty(cached)) {
return cached;
}
Expand All @@ -160,17 +165,18 @@ public List<String> getOrPullSupportedModels() {
throw new DotRuntimeException("App dotAI config without API urls or API key");
}

final CircuitBreakerUrl.Response<OpenAIModels> response = Try
.of(() -> fetchOpenAIModels(appConfig))
.getOrElseThrow(() -> new DotRuntimeException("Error fetching OpenAI supported models"));
final CircuitBreakerUrl.Response<OpenAIModels> response = fetchOpenAIModels(appConfig);
if (Objects.nonNull(response.getResponse().getError())) {
throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage());
}

final List<String> supported = response
final Set<String> supported = response
.getResponse()
.getData()
.stream()
.map(OpenAIModel::getId)
.map(String::toLowerCase)
.collect(Collectors.toList());
.collect(Collectors.toSet());
supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported);

return supported;
Expand Down Expand Up @@ -204,7 +210,7 @@ private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final
.setTimeout(AI_MODELS_FETCH_TIMEOUT)
.setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS)
.setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey()))
.setThrowWhenNot2xx(false)
.setThrowWhenNot2xx(true)
.build()
.doResponse(OpenAIModels.class);

Expand All @@ -215,6 +221,7 @@ private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final
"Error fetching OpenAI supported models from [%s] (status code: [%d])",
OPEN_AI_MODELS_URL,
response.getStatusCode()));
throw new DotRuntimeException("Error fetching OpenAI supported models");
}

return response;
Expand Down
85 changes: 59 additions & 26 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import com.liferay.util.StringPool;
import io.vavr.control.Try;
import org.apache.commons.lang3.StringUtils;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
Expand All @@ -25,9 +29,11 @@ public class AppConfig implements Serializable {
private static final String AI_API_URL_KEY = "AI_API_URL";
private static final String AI_IMAGE_API_URL_KEY = "AI_IMAGE_API_URL";
private static final String AI_EMBEDDINGS_API_URL_KEY = "AI_EMBEDDINGS_API_URL";

private static final String SYSTEM_HOST = "System Host";
public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?");

private static final AtomicReference<AppConfig> SYSTEM_HOST_CONFIG = new AtomicReference<>();

private final String host;
private final String apiKey;
private final transient AIModel model;
Expand All @@ -45,6 +51,9 @@ public class AppConfig implements Serializable {

public AppConfig(final String host, final Map<String, Secret> secrets) {
this.host = host;
if (SYSTEM_HOST.equalsIgnoreCase(host)) {
setSystemHostConfig(this);
}

final AIAppUtil aiAppUtil = AIAppUtil.get();
apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY, AI_API_KEY_KEY);
Expand Down Expand Up @@ -73,18 +82,36 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {

configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

Logger.debug(getClass(), () -> "apiKey: " + apiKey);
Logger.debug(getClass(), () -> "apiUrl: " + apiUrl);
Logger.debug(getClass(), () -> "apiImageUrl: " + apiImageUrl);
Logger.debug(getClass(), () -> "embeddingsUrl: " + apiEmbeddingsUrl);
Logger.debug(getClass(), () -> "rolePrompt: " + rolePrompt);
Logger.debug(getClass(), () -> "textPrompt: " + textPrompt);
Logger.debug(getClass(), () -> "model: " + model);
Logger.debug(getClass(), () -> "imagePrompt: " + imagePrompt);
Logger.debug(getClass(), () -> "imageModel: " + imageModel);
Logger.debug(getClass(), () -> "imageSize: " + imageSize);
Logger.debug(getClass(), () -> "embeddingsModel: " + embeddingsModel);
Logger.debug(getClass(), () -> "listerIndexer: " + listenerIndexer);
Logger.debug(this, this::toString);
}

/**
* Retrieves the system host configuration.
*
* @return the system host configuration
*/
public static AppConfig getSystemHostConfig() {
if (Objects.isNull(SYSTEM_HOST_CONFIG.get())) {
setSystemHostConfig(ConfigService.INSTANCE.config());
}
return SYSTEM_HOST_CONFIG.get();
}

/**
* Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING}
* property instead of the usual Log4j configuration.
*
* @param clazz The {@link Class} to log the message for.
* @param message The {@link Supplier} with the message to log.
*/
public static void debugLogger(final Class<?> clazz, final Supplier<String> message) {
if (getSystemHostConfig().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(clazz, message.get());
}
}

public static void setSystemHostConfig(final AppConfig systemHostConfig) {
AppConfig.SYSTEM_HOST_CONFIG.set(systemHostConfig);
}

/**
Expand Down Expand Up @@ -287,19 +314,6 @@ public AIModel resolveModelOrThrow(final String modelName) {
return aiModel;
}

/**
* Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING}
* property instead of the usual Log4j configuration.
*
* @param clazz The {@link Class} to log the message for.
* @param message The {@link Supplier} with the message to log.
*/
public static void debugLogger(final Class<?> clazz, final Supplier<String> message) {
if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(clazz, message.get());
}
}

/**
* Checks if the configuration is enabled.
*
Expand All @@ -309,4 +323,23 @@ public boolean isEnabled() {
return Stream.of(apiUrl, apiImageUrl, apiEmbeddingsUrl, apiKey).allMatch(StringUtils::isNotBlank);
}

@Override
public String toString() {
return "AppConfig{\n" +
" host='" + host + "',\n" +
" apiKey='" + Optional.ofNullable(apiKey).map(key -> "*****").orElse(StringPool.BLANK) + "',\n" +
" model=" + model + "',\n" +
" imageModel=" + imageModel + "',\n" +
" embeddingsModel=" + embeddingsModel + "',\n" +
" apiUrl='" + apiUrl + "',\n" +
" apiImageUrl='" + apiImageUrl + "',\n" +
" apiEmbeddingsUrl='" + apiEmbeddingsUrl + "',\n" +
" rolePrompt='" + rolePrompt + "',\n" +
" textPrompt='" + textPrompt + "',\n" +
" imagePrompt='" + imagePrompt + "',\n" +
" imageSize='" + imageSize + "',\n" +
" listenerIndexer='" + listenerIndexer + "'\n" +
'}';
}

}
10 changes: 5 additions & 5 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.dotcms.ai.app;

import com.liferay.util.StringPool;

public enum AppKeys {

API_KEY("apiKey", null),
Expand All @@ -22,12 +24,12 @@ public enum AppKeys {
IMAGE_MODEL_TOKENS_PER_MINUTE("imageModelTokensPerMinute", "0"),
IMAGE_MODEL_API_PER_MINUTE("imageModelApiPerMinute", "50"),
IMAGE_MODEL_MAX_TOKENS("imageModelMaxTokens", "0"),
IMAGE_MODEL_COMPLETION("imageModelCompletion", AppKeys.FALSE),
IMAGE_MODEL_COMPLETION("imageModelCompletion", StringPool.FALSE),
EMBEDDINGS_MODEL_NAMES("embeddingsModelNames", null),
EMBEDDINGS_MODEL_TOKENS_PER_MINUTE("embeddingsModelTokensPerMinute", "1000000"),
EMBEDDINGS_MODEL_API_PER_MINUTE("embeddingsModelApiPerMinute", "3000"),
EMBEDDINGS_MODEL_MAX_TOKENS("embeddingsModelMaxTokens", "8191"),
EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", AppKeys.FALSE),
EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", StringPool.FALSE),
EMBEDDINGS_SPLIT_AT_TOKENS("com.dotcms.ai.embeddings.split.at.tokens", "512"),
EMBEDDINGS_MINIMUM_TEXT_LENGTH_TO_INDEX("com.dotcms.ai.embeddings.minimum.text.length", "64"),
EMBEDDINGS_MINIMUM_FILE_SIZE_TO_INDEX("com.dotcms.ai.embeddings.minimum.file.size", "1024"),
Expand All @@ -39,7 +41,7 @@ public enum AppKeys {
EMBEDDINGS_CACHE_TTL_SECONDS("com.dotcms.ai.embeddings.cache.ttl.seconds", "600"),
EMBEDDINGS_CACHE_SIZE("com.dotcms.ai.embeddings.cache.size", "1000"),
EMBEDDINGS_DB_DELETE_OLD_ON_UPDATE("com.dotcms.ai.embeddings.delete.old.on.update", "true"),
DEBUG_LOGGING("com.dotcms.ai.debug.logging", AppKeys.FALSE),
DEBUG_LOGGING("com.dotcms.ai.debug.logging", StringPool.FALSE),
COMPLETION_TEMPERATURE("com.dotcms.ai.completion.default.temperature", "1"),
COMPLETION_ROLE_PROMPT(
"com.dotcms.ai.completion.role.prompt",
Expand All @@ -52,8 +54,6 @@ public enum AppKeys {
AI_MODELS_CACHE_TTL("com.dotcms.ai.models.supported.ttl", "28800"),
AI_MODELS_CACHE_SIZE("com.dotcms.ai.models.supported.size", "64");

private static final String FALSE = "false";

public static final String APP_KEY = "dotAI";

public final String key;
Expand Down
17 changes: 11 additions & 6 deletions dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,23 @@ public static void doRequest(final String urlIn,
final AppConfig appConfig,
final JSONObject json,
final OutputStream out) {
AppConfig.debugLogger(
OpenAIRequest.class,
() -> String.format(
"Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s",
urlIn,
method,
System.lineSeparator(),
appConfig.toString(),
System.lineSeparator(),
json.toString(2)));

if (!appConfig.isEnabled()) {
AppConfig.debugLogger(OpenAIRequest.class, () -> "dotAI is not enabled and will not send request.");
AppConfig.debugLogger(OpenAIRequest.class, () -> "App dotAI is not enabled and will not send request.");
throw new DotRuntimeException("App dotAI config without API urls or API key");
}

final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL));

if (appConfig.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.debug(OpenAIRequest.class, "posting: " + json);
}

final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L)
+ model.minIntervalBetweenCalls()
- System.currentTimeMillis();
Expand Down
2 changes: 2 additions & 0 deletions dotCMS/src/main/java/com/liferay/util/StringPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,6 @@ public class StringPool {

public static final String TRUE = Boolean.TRUE.toString();

public static final String FALSE = Boolean.FALSE.toString();

}
4 changes: 4 additions & 0 deletions dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,8 @@ static Map<String, Secret> aiAppSecrets(final WireMockServer wireMockServer, fin
return aiAppSecrets(wireMockServer, host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL);
}

static void removeSecrets(final Host host) throws DotDataException, DotSecurityException {
APILocator.getAppsAPI().removeSecretsForSite(host, APILocator.systemUser());
}

}
Loading

0 comments on commit 3b8be74

Please sign in to comment.