Local AI Integration (langchain4j#49)
kuraleta authored Jul 25, 2023
1 parent 8d6f9f6 commit 6e240e4
Showing 16 changed files with 609 additions and 17 deletions.
dev/langchain4j/model/localai/LocalAiChatModel.java
@@ -0,0 +1,82 @@
package dev.langchain4j.model.localai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static java.time.Duration.ofSeconds;

public class LocalAiChatModel implements ChatLanguageModel {

    private final OpenAiClient client;
    private final String modelName;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;
    private final Integer maxRetries;

    @Builder
    public LocalAiChatModel(String baseUrl,
                            String modelName,
                            Double temperature,
                            Double topP,
                            Integer maxTokens,
                            Duration timeout,
                            Integer maxRetries,
                            Boolean logRequests,
                            Boolean logResponses) {

        temperature = temperature == null ? 0.7 : temperature;
        timeout = timeout == null ? ofSeconds(60) : timeout;
        maxRetries = maxRetries == null ? 3 : maxRetries;

        this.client = OpenAiClient.builder()
                .apiKey("ignored")
                .url(ensureNotBlank(baseUrl, "baseUrl"))
                .callTimeout(timeout)
                .connectTimeout(timeout)
                .readTimeout(timeout)
                .writeTimeout(timeout)
                .logRequests(logRequests)
                .logResponses(logResponses)
                .build();
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.temperature = temperature;
        this.topP = topP;
        this.maxTokens = maxTokens;
        this.maxRetries = maxRetries;
    }

    @Override
    public AiMessage sendMessages(List<ChatMessage> messages) {
        return sendMessages(messages, null);
    }

    @Override
    public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {

        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model(modelName)
                .messages(toOpenAiMessages(messages))
                .functions(toFunctions(toolSpecifications))
                .temperature(temperature)
                .topP(topP)
                .maxTokens(maxTokens)
                .build();

        ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries);

        return aiMessageFrom(response);
    }
}
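
A minimal usage sketch for LocalAiChatModel, added for illustration. It assumes a LocalAI server listening on http://localhost:8080, a placeholder model name, and the userMessage(...) static factory from dev.langchain4j.data.message.UserMessage; adjust these to your setup.

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.localai.LocalAiChatModel;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static java.util.Collections.singletonList;

public class LocalAiChatModelExample {

    public static void main(String[] args) {
        // Assumption: a LocalAI server is running locally; URL and model name are placeholders.
        LocalAiChatModel model = LocalAiChatModel.builder()
                .baseUrl("http://localhost:8080")
                .modelName("ggml-gpt4all-j")
                .temperature(0.5)
                .build();

        // sendMessages(...) blocks until the full response arrives, retrying on failure.
        AiMessage answer = model.sendMessages(singletonList(userMessage("Hello!")));
        System.out.println(answer.text());
    }
}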
dev/langchain4j/model/localai/LocalAiEmbeddingModel.java
@@ -0,0 +1,68 @@
package dev.langchain4j.model.localai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.embedding.EmbeddingRequest;
import dev.ai4j.openai4j.embedding.EmbeddingResponse;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.toList;

public class LocalAiEmbeddingModel implements EmbeddingModel {

    private final OpenAiClient client;
    private final String modelName;
    private final Integer maxRetries;

    @Builder
    public LocalAiEmbeddingModel(String baseUrl,
                                 String modelName,
                                 Duration timeout,
                                 Integer maxRetries,
                                 Boolean logRequests,
                                 Boolean logResponses) {

        timeout = timeout == null ? ofSeconds(60) : timeout;
        maxRetries = maxRetries == null ? 3 : maxRetries;

        this.client = OpenAiClient.builder()
                .apiKey("ignored")
                .url(ensureNotBlank(baseUrl, "baseUrl"))
                .callTimeout(timeout)
                .connectTimeout(timeout)
                .readTimeout(timeout)
                .writeTimeout(timeout)
                .logRequests(logRequests)
                .logResponses(logResponses)
                .build();
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.maxRetries = maxRetries;
    }

    @Override
    public List<Embedding> embedAll(List<TextSegment> textSegments) {

        List<String> texts = textSegments.stream()
                .map(TextSegment::text)
                .collect(toList());

        EmbeddingRequest request = EmbeddingRequest.builder()
                .input(texts)
                .model(modelName)
                .build();

        EmbeddingResponse response = withRetry(() -> client.embedding(request).execute(), maxRetries);

        return response.data().stream()
                .map(openAiEmbedding -> Embedding.from(openAiEmbedding.embedding()))
                .collect(toList());
    }
}
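
A usage sketch for LocalAiEmbeddingModel along the same lines. TextSegment.from(...) and Embedding.vector() are assumed to be the relevant factory and accessor; the endpoint and model name are placeholders. Note that embedAll(...) sends all segments in a single request and maps response.data() in order.

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;

import java.util.List;

import static java.util.Arrays.asList;

public class LocalAiEmbeddingModelExample {

    public static void main(String[] args) {
        LocalAiEmbeddingModel model = LocalAiEmbeddingModel.builder()
                .baseUrl("http://localhost:8080")    // placeholder LocalAI endpoint
                .modelName("text-embedding-ada-002") // placeholder embedding model
                .build();

        // Embed two segments in one call and print each vector's dimensionality.
        List<Embedding> embeddings = model.embedAll(asList(
                TextSegment.from("first text"),
                TextSegment.from("second text")));

        embeddings.forEach(embedding -> System.out.println(embedding.vector().length));
    }
}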
dev/langchain4j/model/localai/LocalAiLanguageModel.java
@@ -0,0 +1,71 @@
package dev.langchain4j.model.localai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.completion.CompletionRequest;
import dev.ai4j.openai4j.completion.CompletionResponse;
import dev.langchain4j.model.language.LanguageModel;
import lombok.Builder;

import java.time.Duration;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.time.Duration.ofSeconds;

public class LocalAiLanguageModel implements LanguageModel {

    private final OpenAiClient client;
    private final String modelName;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;
    private final Integer maxRetries;

    @Builder
    public LocalAiLanguageModel(String baseUrl,
                                String modelName,
                                Double temperature,
                                Double topP,
                                Integer maxTokens,
                                Duration timeout,
                                Integer maxRetries,
                                Boolean logRequests,
                                Boolean logResponses) {

        temperature = temperature == null ? 0.7 : temperature;
        timeout = timeout == null ? ofSeconds(60) : timeout;
        maxRetries = maxRetries == null ? 3 : maxRetries;

        this.client = OpenAiClient.builder()
                .apiKey("ignored")
                .url(ensureNotBlank(baseUrl, "baseUrl"))
                .callTimeout(timeout)
                .connectTimeout(timeout)
                .readTimeout(timeout)
                .writeTimeout(timeout)
                .logRequests(logRequests)
                .logResponses(logResponses)
                .build();
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.temperature = temperature;
        this.topP = topP;
        this.maxTokens = maxTokens;
        this.maxRetries = maxRetries;
    }

    @Override
    public String process(String text) {

        CompletionRequest request = CompletionRequest.builder()
                .model(modelName)
                .prompt(text)
                .temperature(temperature)
                .topP(topP)
                .maxTokens(maxTokens)
                .build();

        CompletionResponse response = withRetry(() -> client.completion(request).execute(), maxRetries);

        return response.text();
    }
}
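
LocalAiLanguageModel wraps the plain (non-chat) completion endpoint. A short sketch, using only the API shown above; the endpoint URL and model name are placeholders:

import dev.langchain4j.model.localai.LocalAiLanguageModel;

public class LocalAiLanguageModelExample {

    public static void main(String[] args) {
        LocalAiLanguageModel model = LocalAiLanguageModel.builder()
                .baseUrl("http://localhost:8080") // placeholder LocalAI endpoint
                .modelName("ggml-gpt4all-j")      // placeholder model name
                .maxTokens(50)
                .build();

        // process(...) sends a single prompt and returns the raw completion text.
        String completion = model.process("The sky is");
        System.out.println(completion);
    }
}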
dev/langchain4j/model/localai/LocalAiStreamingChatModel.java
@@ -0,0 +1,98 @@
package dev.langchain4j.model.localai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.ai4j.openai4j.chat.FunctionCall;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toFunctions;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
import static java.time.Duration.ofSeconds;

public class LocalAiStreamingChatModel implements StreamingChatLanguageModel {

    private final OpenAiClient client;
    private final String modelName;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;

    @Builder
    public LocalAiStreamingChatModel(String baseUrl,
                                     String modelName,
                                     Double temperature,
                                     Double topP,
                                     Integer maxTokens,
                                     Duration timeout,
                                     Boolean logRequests,
                                     Boolean logResponses) {

        temperature = temperature == null ? 0.7 : temperature;
        timeout = timeout == null ? ofSeconds(60) : timeout;

        this.client = OpenAiClient.builder()
                .apiKey("ignored")
                .url(ensureNotBlank(baseUrl, "baseUrl"))
                .callTimeout(timeout)
                .connectTimeout(timeout)
                .readTimeout(timeout)
                .writeTimeout(timeout)
                .logRequests(logRequests)
                .logResponses(logResponses)
                .build();
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.temperature = temperature;
        this.topP = topP;
        this.maxTokens = maxTokens;
    }

    @Override
    public void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler) {
        sendMessages(messages, null, handler);
    }

    @Override
    public void sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler handler) {
        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .stream(true)
                .model(modelName)
                .messages(toOpenAiMessages(messages))
                .functions(toFunctions(toolSpecifications))
                .temperature(temperature)
                .topP(topP)
                .maxTokens(maxTokens)
                .build();

        client.chatCompletion(request)
                .onPartialResponse(partialResponse -> handle(partialResponse, handler))
                .onComplete(handler::onComplete)
                .onError(handler::onError)
                .execute();
    }

    private static void handle(ChatCompletionResponse partialResponse,
                               StreamingResponseHandler handler) {
        Delta delta = partialResponse.choices().get(0).delta();
        String content = delta.content();
        FunctionCall functionCall = delta.functionCall();
        if (content != null) {
            handler.onNext(content);
        } else if (functionCall != null) {
            if (functionCall.name() != null) {
                handler.onToolName(functionCall.name());
            } else if (functionCall.arguments() != null) {
                handler.onToolArguments(functionCall.arguments());
            }
        }
    }
}
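
A streaming usage sketch. The shape of StreamingResponseHandler is assumed from the method references above (onNext, onComplete, onError; onToolName and onToolArguments are left to their presumed defaults here), and the endpoint and model name are placeholders.

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.localai.LocalAiStreamingChatModel;

import java.util.List;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static java.util.Collections.singletonList;

public class LocalAiStreamingChatModelExample {

    public static void main(String[] args) {
        LocalAiStreamingChatModel model = LocalAiStreamingChatModel.builder()
                .baseUrl("http://localhost:8080") // placeholder LocalAI endpoint
                .modelName("ggml-gpt4all-j")      // placeholder model name
                .build();

        List<ChatMessage> messages = singletonList(userMessage("Tell me a joke"));

        model.sendMessages(messages, new StreamingResponseHandler() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // tokens arrive as they are generated
            }

            @Override
            public void onComplete() {
                System.out.println("\n[done]");
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}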
dev/langchain4j/model/localai/LocalAiStreamingLanguageModel.java
@@ -0,0 +1,73 @@
package dev.langchain4j.model.localai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.completion.CompletionRequest;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.language.StreamingLanguageModel;
import lombok.Builder;

import java.time.Duration;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.time.Duration.ofSeconds;

public class LocalAiStreamingLanguageModel implements StreamingLanguageModel {

    private final OpenAiClient client;
    private final String modelName;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;

    @Builder
    public LocalAiStreamingLanguageModel(String baseUrl,
                                         String modelName,
                                         Double temperature,
                                         Double topP,
                                         Integer maxTokens,
                                         Duration timeout,
                                         Boolean logRequests,
                                         Boolean logResponses) {

        temperature = temperature == null ? 0.7 : temperature;
        timeout = timeout == null ? ofSeconds(60) : timeout;

        this.client = OpenAiClient.builder()
                .apiKey("ignored")
                .url(ensureNotBlank(baseUrl, "baseUrl"))
                .callTimeout(timeout)
                .connectTimeout(timeout)
                .readTimeout(timeout)
                .writeTimeout(timeout)
                .logRequests(logRequests)
                .logResponses(logResponses)
                .build();
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.temperature = temperature;
        this.topP = topP;
        this.maxTokens = maxTokens;
    }

    @Override
    public void process(String text, StreamingResponseHandler handler) {

        CompletionRequest request = CompletionRequest.builder()
                .model(modelName)
                .prompt(text)
                .temperature(temperature)
                .topP(topP)
                .maxTokens(maxTokens)
                .build();

        client.completion(request)
                .onPartialResponse(partialResponse -> {
                    String partialResponseText = partialResponse.text();
                    if (partialResponseText != null) {
                        handler.onNext(partialResponseText);
                    }
                })
                .onComplete(handler::onComplete)
                .onError(handler::onError)
                .execute();
    }
}
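
And the equivalent sketch for the streaming completion model, under the same assumptions about StreamingResponseHandler and with the same placeholder endpoint and model name:

import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.localai.LocalAiStreamingLanguageModel;

public class LocalAiStreamingLanguageModelExample {

    public static void main(String[] args) {
        LocalAiStreamingLanguageModel model = LocalAiStreamingLanguageModel.builder()
                .baseUrl("http://localhost:8080") // placeholder LocalAI endpoint
                .modelName("ggml-gpt4all-j")      // placeholder model name
                .build();

        model.process("Once upon a time", new StreamingResponseHandler() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // partial completion text, streamed
            }

            @Override
            public void onComplete() {
                System.out.println();
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}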