Skip to content

Commit

Permalink
#30361: allowing stream chat to dotAI
Browse files Browse the repository at this point in the history
  • Loading branch information
victoralfaro-dotcms committed Oct 17, 2024
1 parent be08096 commit 71e2d84
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 39 deletions.
14 changes: 9 additions & 5 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
*/
public interface AIClientStrategy {

AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build();
/**
 * No-op strategy: sends nothing to the AI client and writes nothing.
 *
 * It hands back the caller-supplied stream untouched or, when the caller passed
 * {@code null}, a fresh empty in-memory buffer. Returning a non-null stream matters
 * because callers may invoke {@code toString()} on the result when they did not
 * provide a stream of their own; returning {@code null} here would make them fail
 * with a NullPointerException.
 */
AIClientStrategy NOOP = (client, handler, request, output) ->
        // Fully-qualified name: only OutputStream is known to be imported in this file.
        output != null ? output : new java.io.ByteArrayOutputStream();

/**
* Applies the strategy to the given AI client request and handles the response.
Expand All @@ -32,10 +35,11 @@ public interface AIClientStrategy {
* @param handler the response evaluator to handle the response
* @param request the AI request to be processed
* @param output the output stream to which the response will be written
* @return result output stream
*/
void applyStrategy(AIClient client,
AIResponseEvaluator handler,
AIRequest<? extends Serializable> request,
OutputStream output);
OutputStream applyStrategy(AIClient client,
AIResponseEvaluator handler,
AIRequest<? extends Serializable> request,
OutputStream output);

}
12 changes: 8 additions & 4 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.dotcms.ai.client;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Optional;

/**
* Default implementation of the {@link AIClientStrategy} interface.
Expand All @@ -22,11 +24,13 @@
public class AIDefaultStrategy implements AIClientStrategy {

@Override
public void applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
public OutputStream applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream incoming) {
final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new);
client.sendRequest(request, output);
return output;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,24 @@ public class AIModelFallbackStrategy implements AIClientStrategy {
* @param request the AI request to be processed
* @param output the output stream to which the response will be written
* @throws DotAIAllModelsExhaustedException if all models are exhausted and no successful response is obtained
* @return result output stream
*/
@Override
public void applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
public OutputStream applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request);
final Tuple2<AIModel, Model> modelTuple = resolveModel(jsonRequest);

final AIResponseData firstAttempt = sendAttempt(client, handler, jsonRequest, output, modelTuple);
if (firstAttempt.isSuccess()) {
return;
return output;
}

runFallbacks(client, handler, jsonRequest, output, modelTuple);

return output;
}

private static Tuple2<AIModel, Model> resolveModel(final JSONObjectAIRequest request) {
Expand All @@ -96,11 +99,7 @@ private static Tuple2<AIModel, Model> resolveModel(final JSONObjectAIRequest req
}

private static boolean isSameAsFirst(final Model firstAttempt, final Model model) {
if (firstAttempt.equals(model)) {
return true;
}

return false;
return firstAttempt.equals(model);
}

private static boolean isOperational(final Model model) {
Expand All @@ -114,31 +113,45 @@ private static boolean isOperational(final Model model) {
return true;
}

private static AIResponseData doSend(final AIClient client, final AIRequest<? extends Serializable> request) {
final ByteArrayOutputStream output = new ByteArrayOutputStream();
/**
 * Tells whether the given request asks for a streamed response.
 *
 * Reads the {@link AiKeys#STREAM} entry from the request payload, passing
 * {@code false} as the fallback value to {@code optBoolean}.
 *
 * @param request the JSON-based AI request whose payload is inspected
 * @return the boolean found under {@link AiKeys#STREAM}, or the {@code false} fallback
 */
private static boolean isStream(final JSONObjectAIRequest request) {
return request.getPayload().optBoolean(AiKeys.STREAM, false);
}

private static AIResponseData doSend(final AIClient client,
final JSONObjectAIRequest request,
final OutputStream incoming) {
final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new);
client.sendRequest(request, output);

final AIResponseData responseData = new AIResponseData();
responseData.setResponse(output.toString());
IOUtils.closeQuietly(output);
if (!isStream(request)) {
IOUtils.closeQuietly(output);
}

return responseData;
}

private static void redirectOutput(final OutputStream output, final String response) {
private static void redirectOutput(final JSONObjectAIRequest request,
final OutputStream output,
final String response) {
if (isStream(request)) {
return;
}

try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) {
IOUtils.copy(input, output);
} catch (IOException e) {
throw new DotRuntimeException(e);
}
}

private static void notifyFailure(final AIModel aiModel, final AIRequest<? extends Serializable> request) {
private static void notifyFailure(final AIModel aiModel, final JSONObjectAIRequest request) {
AIAppValidator.get().validateModelsUsage(aiModel, request.getUserId());
}

private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,
final AIRequest<? extends Serializable> request,
final JSONObjectAIRequest request,
final AIResponseData responseData) {
final AIModel aiModel = modelTuple._1;
final Model model = modelTuple._2;
Expand Down Expand Up @@ -177,7 +190,7 @@ private static AIResponseData sendAttempt(final AIClient client,
final Tuple2<AIModel, Model> modelTuple) {

final AIResponseData responseData = Try
.of(() -> doSend(client, request))
.of(() -> doSend(client, request, output))
.getOrElseGet(exception -> fromException(evaluator, exception));

if (!responseData.isSuccess()) {
Expand All @@ -200,7 +213,7 @@ private static AIResponseData sendAttempt(final AIClient client,
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName()));
redirectOutput(output, responseData.getResponse());
redirectOutput(request, output, responseData.getResponse());
} else {
logFailure(modelTuple, responseData);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,11 @@ public static AIProxiedClient of(final AIClient client, final AIProxyStrategy st
* @return the AI response
*/
public <T extends Serializable> AIResponse sendToAI(final AIRequest<T> request, final OutputStream output) {
final OutputStream finalOutput = Optional.ofNullable(output).orElseGet(ByteArrayOutputStream::new);

strategy.applyStrategy(client, responseEvaluator, request, finalOutput);
final OutputStream resultOutput = strategy.applyStrategy(client, responseEvaluator, request, output);

return Optional.ofNullable(output)
.map(out -> AIResponse.EMPTY)
.orElseGet(() -> AIResponse.builder().withResponse(finalOutput.toString()).build());
.orElseGet(() -> AIResponse.builder().withResponse(resultOutput.toString()).build());
}

}
24 changes: 21 additions & 3 deletions dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.vavr.Tuple2;
import io.vavr.control.Try;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpStatus;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.apache.http.client.methods.HttpUriRequest;
Expand All @@ -29,7 +30,9 @@
import org.apache.http.impl.client.HttpClients;

import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Optional;
Expand Down Expand Up @@ -129,17 +132,19 @@ public <T extends Serializable> void sendRequest(final AIRequest<T> request, fin

lastRestCall.put(aiModel, System.currentTimeMillis());

try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
try (final CloseableHttpClient httpClient = HttpClients.createDefault()) {
final StringEntity jsonEntity = new StringEntity(payload.toString(), ContentType.APPLICATION_JSON);
final HttpUriRequest httpRequest = AIClient.resolveMethod(jsonRequest.getMethod(), jsonRequest.getUrl());
httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON);
httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey());

if (!payload.getAsMap().isEmpty()) {
Try.run(() -> HttpEntityEnclosingRequestBase.class.cast(httpRequest).setEntity(jsonEntity));
Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity));
}

try (CloseableHttpResponse response = httpClient.execute(httpRequest)) {
try (final CloseableHttpResponse response = httpClient.execute(httpRequest)) {
onStreamCheckFotStatusCode(modelName, payload, response);

final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent());
final byte[] buffer = new byte[1024];
int len;
Expand All @@ -161,4 +166,17 @@ public <T extends Serializable> void sendRequest(final AIRequest<T> request, fin
}
}

/**
 * Rejects client-error responses up front when the request was sent in stream mode.
 *
 * Nothing happens unless the payload carries a true {@link AiKeys#STREAM} flag
 * (read with a {@code false} fallback). For streamed requests the HTTP status code
 * is classified by family, and any status in the client-error family is reported
 * as a {@link DotAIModelNotFoundException} naming the model.
 *
 * @param modelName name of the model used in the request, included in the error message
 * @param payload   the JSON payload that was sent, inspected for the stream flag
 * @param response  the HTTP response whose status line is examined
 * @throws DotAIModelNotFoundException if streaming was requested and the status is a client error
 */
private static void onStreamCheckFotStatusCode(final String modelName,
                                               final JSONObject payload,
                                               final CloseableHttpResponse response) {
    final boolean streaming = payload.optBoolean(AiKeys.STREAM, false);
    if (!streaming) {
        // Non-streamed calls are not checked here.
        return;
    }

    final Response.Status.Family family =
            Response.Status.Family.familyOf(response.getStatusLine().getStatusCode());
    if (family != Response.Status.Family.CLIENT_ERROR) {
        return;
    }

    throw new DotAIModelNotFoundException(String.format(
            "Model used [%s] in request in stream mode is not found",
            modelName));
}

}
1 change: 0 additions & 1 deletion dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ public Builder withResponse(final String response) {
return this;
}


/**
 * Creates a new {@link AIResponse} from this builder.
 *
 * @return a new {@link AIResponse} constructed with this builder instance
 */
public AIResponse build() {
return new AIResponse(this);
}
Expand Down
11 changes: 7 additions & 4 deletions dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,15 @@ public class CompletionsResource {
public final Response summarizeFromContent(@Context final HttpServletRequest request,
@Context final HttpServletResponse response,
final CompletionsForm formIn) {
final CompletionsForm resolvedForm = resolveForm(request, response, formIn);
return getResponse(
request,
response,
formIn,
() -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(formIn),
() -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(resolvedForm),
output -> APILocator.getDotAIAPI()
.getCompletionsAPI()
.summarizeStream(formIn, new LineReadingOutputStream(output)));
.summarizeStream(resolvedForm, new LineReadingOutputStream(output)));
}

/**
Expand All @@ -81,14 +82,15 @@ public final Response summarizeFromContent(@Context final HttpServletRequest req
public final Response rawPrompt(@Context final HttpServletRequest request,
@Context final HttpServletResponse response,
final CompletionsForm formIn) {
final CompletionsForm resolvedForm = resolveForm(request, response, formIn);
return getResponse(
request,
response,
formIn,
() -> APILocator.getDotAIAPI().getCompletionsAPI().raw(formIn),
() -> APILocator.getDotAIAPI().getCompletionsAPI().raw(resolvedForm),
output -> APILocator.getDotAIAPI()
.getCompletionsAPI()
.rawStream(formIn, new LineReadingOutputStream(output)));
.rawStream(resolvedForm, new LineReadingOutputStream(output)));
}

/**
Expand Down Expand Up @@ -180,6 +182,7 @@ private static Response getResponse(final HttpServletRequest request,

final JSONObject jsonResponse = noStream.get();
jsonResponse.put(AiKeys.TOTAL_TIME, System.currentTimeMillis() - startTime + "ms");

return Response.ok(jsonResponse.toString(), MediaType.APPLICATION_JSON).build();
}

Expand Down

0 comments on commit 71e2d84

Please sign in to comment.