Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions dotCMS/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,34 @@
<!-- <docker.platforms>linux/arm64,linux/amd64</docker.platforms>-->
<tomcat.run.data.folder>${project.build.directory}/tomcat-run-data/${context.name}</tomcat.run.data.folder>
<cleanup.data.folder>false</cleanup.data.folder>
<langchain4j.version>1.3.0</langchain4j.version>
</properties>
<dependencies>
<!-- LangChain4j core + OpenAI (use OpenAI-compatible for gateways) -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-anthropic</artifactId>
<version>${langchain4j.version}</version>
</dependency>

<!-- https://mvnrepository.com/artifact/dev.langchain4j/langchain4j-embeddings-all-minilm-l6-v2 -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<version>0.36.2</version>
</dependency>

<!-- Force dependency resolution when using also-make option even though these do not provide jar dependencies -->
<dependency>
<groupId>com.dotcms</groupId>
Expand Down
85 changes: 85 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package com.dotcms.ai.api;

import com.dotcms.ai.config.AiModelConfig;

public class CompletionRequest {

private final AiModelConfig chatModelConfig;
private final String vendorModelPath;
private final float temperature;
private final String prompt;

private CompletionRequest(final Builder builder) {

this.chatModelConfig = builder.chatModelConfig;
this.vendorModelPath = builder.vendorModelPath;
this.temperature = builder.temperature;
this.prompt = builder.prompt;
}

public String getVendorModelPath() {
return vendorModelPath;
}

public float getTemperature() {
return temperature;
}

public AiModelConfig getChatModelConfig() {
return this.chatModelConfig;
}

public String getPrompt() {
return prompt;
}

@Override
public String toString() {
return "CompletionRequest{" +
"chatModelConfig=" + chatModelConfig +
", vendorModelPath='" + vendorModelPath + '\'' +
", temperature=" + temperature +
", prompt='" + prompt + '\'' +
'}';
}

public static Builder builder() {
return new Builder();
}

public String getSystemPrompt() {
return null; // todo: fill this thig
}

public static final class Builder {
private AiModelConfig chatModelConfig;
private String vendorModelPath;
private Float temperature;
private String prompt;


public Builder chatModelConfig(AiModelConfig modelConfig) {
this.chatModelConfig = modelConfig;
return this;
}

public Builder vendorModelPath(String vendorModelPath) {
this.vendorModelPath = vendorModelPath;
return this;
}

public Builder prompt(String prompt) {
this.prompt = prompt;
return this;
}

public Builder temperature(Float temperature) {
this.temperature = temperature;
return this;
}

public CompletionRequest build() {
return new CompletionRequest(this);
}
}
}
26 changes: 26 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.dotcms.ai.api;

public class CompletionResponse {

private final String messageText;
private final Object aiMessage;
private final Object metatada;

public CompletionResponse(String messageText, Object aiMessage, Object metatada) {
this.messageText = messageText;
this.aiMessage = aiMessage;
this.metatada = metatada;
}

public String getMessageText() {
return messageText;
}

public Object getAiMessage() {
return aiMessage;
}

public Object getMetatada() {
return metatada;
}
}
7 changes: 7 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,11 @@ JSONObject prompt(String systemPrompt,
*/
void rawStream(CompletionsForm promptForm, OutputStream out);

/**
* this method takes a prompt in the request and returns a json AI response based on the parameters
* passed in.
* @param completionRequest
* @return
*/
CompletionResponse raw(CompletionRequest completionRequest);
}
42 changes: 40 additions & 2 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
Original file line number Diff line number Diff line change
@@ -1,31 +1,41 @@
package com.dotcms.ai.api;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.api.provider.VendorModelProviderFactory;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.client.AIProxyClient;
import com.dotcms.ai.client.JSONObjectAIRequest;
import com.dotcms.ai.config.AiModelConfig;
import com.dotcms.ai.db.EmbeddingsDTO;
import com.dotcms.ai.domain.AIResponse;
import com.dotcms.ai.client.JSONObjectAIRequest;
import com.dotcms.ai.domain.Model;
import com.dotcms.ai.exception.DotAIModelNotFoundException;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotcms.ai.util.AIUtil;
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.analytics.Util;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.cdi.CDIUtils;
import com.dotcms.mock.request.FakeHttpRequest;
import com.dotcms.mock.response.BaseResponse;
import com.dotcms.rendering.velocity.util.VelocityUtil;
import com.dotmarketing.business.APILocator;
import com.dotmarketing.business.web.WebAPILocator;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Config;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.StringUtils;
import com.dotmarketing.util.UtilMethods;
import com.dotmarketing.util.json.JSONArray;
import com.dotmarketing.util.json.JSONObject;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.response.ChatResponse;
import io.vavr.Lazy;
import io.vavr.Tuple2;
import io.vavr.control.Try;
Expand All @@ -38,6 +48,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

Expand All @@ -53,14 +64,21 @@ public class CompletionsAPIImpl implements CompletionsAPI {
Lazy.of(() -> Config.getIntProperty(DEFAULT_AI_MAX_NUMBER_OF_TOKENS, 16384));

private final AppConfig config;
private final VendorModelProviderFactory modelProviderFactory;

public CompletionsAPIImpl(final AppConfig config) {
this(config, CDIUtils.getBeanThrows(VendorModelProviderFactory.class));
}

public CompletionsAPIImpl(final AppConfig config,
final VendorModelProviderFactory modelProviderFactory) {
final Lazy<AppConfig> defaultConfig = Lazy.of(() -> ConfigService.INSTANCE.config(
Try.of(() -> WebAPILocator
.getHostWebAPI()
.getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
.getOrElse(APILocator.systemHost())));
this.config = Optional.ofNullable(config).orElse(defaultConfig.get());
this.modelProviderFactory = modelProviderFactory;
}

@Override
Expand Down Expand Up @@ -144,6 +162,26 @@ public JSONObject raw(final CompletionsForm promptForm) {
return raw(jsonObject, UtilMethods.extractUserIdOrNull(promptForm.user));
}

@Override
public CompletionResponse raw(final CompletionRequest completionRequest) {

Logger.debug(this, ()-> "Doing raw request: " + completionRequest);
final AiModelConfig modelConfig = completionRequest.getChatModelConfig();
final String vendorName = AIUtil.getVendorFromPath(completionRequest.getVendorModelPath());
final Float temperature = completionRequest.getTemperature();
final ChatModel chatModel = this.modelProviderFactory.get(vendorName,
Objects.nonNull(temperature)? AiModelConfig.withTemperature(modelConfig, temperature).build(): // if the temperature is set, lets override the config
completionRequest.getChatModelConfig());
final String userPrompt = completionRequest.getPrompt();
final String systemPrompt = completionRequest.getSystemPrompt();
final UserMessage userMessage = new UserMessage(userPrompt);
final List<ChatMessage> messages = StringUtils.isSet(systemPrompt)?
List.of(new SystemMessage(systemPrompt), userMessage):List.of(userMessage);
final ChatResponse chatResponse = chatModel.chat(messages);
return new CompletionResponse(chatResponse.aiMessage().text(),
chatResponse.aiMessage(), chatResponse.metadata());
}

@Override
public void rawStream(final CompletionsForm promptForm, final OutputStream output) {
final JSONObject json = buildRequestJson(promptForm);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.dotcms.ai.api.provider;

import com.dotcms.ai.config.AiModelConfig;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;

/**
* Provides a Chat Model based on a configuration
* @author jsanca
*/
public interface ChatModelProvider {

/**
* Create a new ChatModel based on the config
* @param config AiModelConfig
* @return
*/
ChatModel create(AiModelConfig config);

/**
* Create new Streaming ChatModel
* @param config AiModelConfig
* @return StreamingChatModel
*/
StreamingChatModel createStreaming(AiModelConfig config);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.dotcms.ai.api.provider;

import com.dotcms.ai.config.AiModelConfig;
import dev.langchain4j.model.embedding.EmbeddingModel;

/**
* Provides an Embedding Model based on a configuration
* @author jsanca
*/
public interface EmbeddingModelProvider {

/**
* Creates embedding based on a configuration
* @param config
* @return EmbeddingModel
*/
EmbeddingModel createEmbedding(AiModelConfig config);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.dotcms.ai.api.provider;

/**
* Just groups model providers
* @author jsanca
*/
public interface VendorModelProvider extends ChatModelProvider, EmbeddingModelProvider {

String getVendorName();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package com.dotcms.ai.api.provider;

import com.dotcms.ai.api.provider.anthropic.AnthropicVendorModelProviderImpl;
import com.dotcms.ai.api.provider.openai.OpenAiVendorModelProviderImpl;
import com.dotcms.ai.config.AiModelConfig;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;

import javax.enterprise.context.ApplicationScoped;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* Factory to create the AI Model Providers
* @author jsanca
*/
@ApplicationScoped
public class VendorModelProviderFactory {

private final Map<String, VendorModelProvider> providers = new ConcurrentHashMap<>();

public VendorModelProviderFactory() {
this(List.of(new OpenAiVendorModelProviderImpl(), new AnthropicVendorModelProviderImpl()));
}

public VendorModelProviderFactory(final List<VendorModelProvider> providerList) {
for (final VendorModelProvider provider : providerList) {
addProvider(provider);
}
}

// todo: add osgi support probably will need a proxy that gets injected this class and exposes statically to the activator
public void addProvider(final VendorModelProvider provider) {
providers.put(provider.getVendorName().toLowerCase(), provider);
}

public ChatModel get(final String providerName, final AiModelConfig config) {

final VendorModelProvider provider = providers.get(providerName.toLowerCase());
if (provider == null) {
// todo: if eventually have a default one, by config on, use instead
throw new IllegalArgumentException("Unknown model provider: " + providerName);
}
return provider.create(config);
}

public StreamingChatModel getStreaming(final String providerName, final AiModelConfig config) {

final VendorModelProvider provider = providers.get(providerName.toLowerCase());
if (provider == null) {
// todo: if eventually have a default one, by config on, use instead
throw new IllegalArgumentException("Unknown model provider: " + providerName);
}
return provider.createStreaming(config);
}

public EmbeddingModel getEmbedding(final String providerName, final AiModelConfig config) {

final VendorModelProvider provider = providers.get(providerName.toLowerCase());
if (provider == null) {
// todo: if eventually have a default one, by config on, use instead
throw new IllegalArgumentException("Unknown model provider: " + providerName);
}
return provider.createEmbedding(config);
}
}
Loading
Loading