forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Local AI Integration (langchain4j#49)
- Loading branch information
Showing
16 changed files
with
609 additions
and
17 deletions.
There are no files selected for viewing
82 changes: 82 additions & 0 deletions
82
langchain4j/src/main/java/dev/langchain4j/model/localai/LocalAiChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package dev.langchain4j.model.localai; | ||
|
||
import dev.ai4j.openai4j.OpenAiClient; | ||
import dev.ai4j.openai4j.chat.ChatCompletionRequest; | ||
import dev.ai4j.openai4j.chat.ChatCompletionResponse; | ||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import lombok.Builder; | ||
|
||
import java.time.Duration; | ||
import java.util.List; | ||
|
||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*; | ||
import static java.time.Duration.ofSeconds; | ||
|
||
public class LocalAiChatModel implements ChatLanguageModel { | ||
|
||
private final OpenAiClient client; | ||
private final String modelName; | ||
private final Double temperature; | ||
private final Double topP; | ||
private final Integer maxTokens; | ||
private final Integer maxRetries; | ||
|
||
@Builder | ||
public LocalAiChatModel(String baseUrl, | ||
String modelName, | ||
Double temperature, | ||
Double topP, | ||
Integer maxTokens, | ||
Duration timeout, | ||
Integer maxRetries, | ||
Boolean logRequests, | ||
Boolean logResponses) { | ||
|
||
temperature = temperature == null ? 0.7 : temperature; | ||
timeout = timeout == null ? ofSeconds(60) : timeout; | ||
maxRetries = maxRetries == null ? 3 : maxRetries; | ||
|
||
this.client = OpenAiClient.builder() | ||
.apiKey("ignored") | ||
.url(ensureNotBlank(baseUrl, "baseUrl")) | ||
.callTimeout(timeout) | ||
.connectTimeout(timeout) | ||
.readTimeout(timeout) | ||
.writeTimeout(timeout) | ||
.logRequests(logRequests) | ||
.logResponses(logResponses) | ||
.build(); | ||
this.modelName = ensureNotBlank(modelName, "modelName"); | ||
this.temperature = temperature; | ||
this.topP = topP; | ||
this.maxTokens = maxTokens; | ||
this.maxRetries = maxRetries; | ||
} | ||
|
||
@Override | ||
public AiMessage sendMessages(List<ChatMessage> messages) { | ||
return sendMessages(messages, null); | ||
} | ||
|
||
@Override | ||
public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) { | ||
|
||
ChatCompletionRequest request = ChatCompletionRequest.builder() | ||
.model(modelName) | ||
.messages(toOpenAiMessages(messages)) | ||
.functions(toFunctions(toolSpecifications)) | ||
.temperature(temperature) | ||
.topP(topP) | ||
.maxTokens(maxTokens) | ||
.build(); | ||
|
||
ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries); | ||
|
||
return aiMessageFrom(response); | ||
} | ||
} |
68 changes: 68 additions & 0 deletions
68
langchain4j/src/main/java/dev/langchain4j/model/localai/LocalAiEmbeddingModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package dev.langchain4j.model.localai; | ||
|
||
import dev.ai4j.openai4j.OpenAiClient; | ||
import dev.ai4j.openai4j.embedding.EmbeddingRequest; | ||
import dev.ai4j.openai4j.embedding.EmbeddingResponse; | ||
import dev.langchain4j.data.embedding.Embedding; | ||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.model.embedding.EmbeddingModel; | ||
import lombok.Builder; | ||
|
||
import java.time.Duration; | ||
import java.util.List; | ||
|
||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static java.time.Duration.ofSeconds; | ||
import static java.util.stream.Collectors.toList; | ||
|
||
public class LocalAiEmbeddingModel implements EmbeddingModel { | ||
|
||
private final OpenAiClient client; | ||
private final String modelName; | ||
private final Integer maxRetries; | ||
|
||
@Builder | ||
public LocalAiEmbeddingModel(String baseUrl, | ||
String modelName, | ||
Duration timeout, | ||
Integer maxRetries, | ||
Boolean logRequests, | ||
Boolean logResponses) { | ||
|
||
timeout = timeout == null ? ofSeconds(60) : timeout; | ||
maxRetries = maxRetries == null ? 3 : maxRetries; | ||
|
||
this.client = OpenAiClient.builder() | ||
.apiKey("ignored") | ||
.url(ensureNotBlank(baseUrl, "baseUrl")) | ||
.callTimeout(timeout) | ||
.connectTimeout(timeout) | ||
.readTimeout(timeout) | ||
.writeTimeout(timeout) | ||
.logRequests(logRequests) | ||
.logResponses(logResponses) | ||
.build(); | ||
this.modelName = ensureNotBlank(modelName, "modelName"); | ||
this.maxRetries = maxRetries; | ||
} | ||
|
||
@Override | ||
public List<Embedding> embedAll(List<TextSegment> textSegments) { | ||
|
||
List<String> texts = textSegments.stream() | ||
.map(TextSegment::text) | ||
.collect(toList()); | ||
|
||
EmbeddingRequest request = EmbeddingRequest.builder() | ||
.input(texts) | ||
.model(modelName) | ||
.build(); | ||
|
||
EmbeddingResponse response = withRetry(() -> client.embedding(request).execute(), maxRetries); | ||
|
||
return response.data().stream() | ||
.map(openAiEmbedding -> Embedding.from(openAiEmbedding.embedding())) | ||
.collect(toList()); | ||
} | ||
} |
71 changes: 71 additions & 0 deletions
71
langchain4j/src/main/java/dev/langchain4j/model/localai/LocalAiLanguageModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package dev.langchain4j.model.localai; | ||
|
||
import dev.ai4j.openai4j.OpenAiClient; | ||
import dev.ai4j.openai4j.completion.CompletionRequest; | ||
import dev.ai4j.openai4j.completion.CompletionResponse; | ||
import dev.langchain4j.model.language.LanguageModel; | ||
import lombok.Builder; | ||
|
||
import java.time.Duration; | ||
|
||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static java.time.Duration.ofSeconds; | ||
|
||
public class LocalAiLanguageModel implements LanguageModel { | ||
|
||
private final OpenAiClient client; | ||
private final String modelName; | ||
private final Double temperature; | ||
private final Double topP; | ||
private final Integer maxTokens; | ||
private final Integer maxRetries; | ||
|
||
@Builder | ||
public LocalAiLanguageModel(String baseUrl, | ||
String modelName, | ||
Double temperature, | ||
Double topP, | ||
Integer maxTokens, | ||
Duration timeout, | ||
Integer maxRetries, | ||
Boolean logRequests, | ||
Boolean logResponses) { | ||
|
||
temperature = temperature == null ? 0.7 : temperature; | ||
timeout = timeout == null ? ofSeconds(60) : timeout; | ||
maxRetries = maxRetries == null ? 3 : maxRetries; | ||
|
||
this.client = OpenAiClient.builder() | ||
.apiKey("ignored") | ||
.url(ensureNotBlank(baseUrl, "baseUrl")) | ||
.callTimeout(timeout) | ||
.connectTimeout(timeout) | ||
.readTimeout(timeout) | ||
.writeTimeout(timeout) | ||
.logRequests(logRequests) | ||
.logResponses(logResponses) | ||
.build(); | ||
this.modelName = ensureNotBlank(modelName, "modelName"); | ||
this.temperature = temperature; | ||
this.topP = topP; | ||
this.maxTokens = maxTokens; | ||
this.maxRetries = maxRetries; | ||
} | ||
|
||
@Override | ||
public String process(String text) { | ||
|
||
CompletionRequest request = CompletionRequest.builder() | ||
.model(modelName) | ||
.prompt(text) | ||
.temperature(temperature) | ||
.topP(topP) | ||
.maxTokens(maxTokens) | ||
.build(); | ||
|
||
CompletionResponse response = withRetry(() -> client.completion(request).execute(), maxRetries); | ||
|
||
return response.text(); | ||
} | ||
} |
98 changes: 98 additions & 0 deletions
98
langchain4j/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package dev.langchain4j.model.localai; | ||
|
||
import dev.ai4j.openai4j.OpenAiClient; | ||
import dev.ai4j.openai4j.chat.ChatCompletionRequest; | ||
import dev.ai4j.openai4j.chat.ChatCompletionResponse; | ||
import dev.ai4j.openai4j.chat.Delta; | ||
import dev.ai4j.openai4j.chat.FunctionCall; | ||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import lombok.Builder; | ||
|
||
import java.time.Duration; | ||
import java.util.List; | ||
|
||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toFunctions; | ||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages; | ||
import static java.time.Duration.ofSeconds; | ||
|
||
public class LocalAiStreamingChatModel implements StreamingChatLanguageModel { | ||
|
||
private final OpenAiClient client; | ||
private final String modelName; | ||
private final Double temperature; | ||
private final Double topP; | ||
private final Integer maxTokens; | ||
|
||
@Builder | ||
public LocalAiStreamingChatModel(String baseUrl, | ||
String modelName, | ||
Double temperature, | ||
Double topP, | ||
Integer maxTokens, | ||
Duration timeout, | ||
Boolean logRequests, | ||
Boolean logResponses) { | ||
|
||
temperature = temperature == null ? 0.7 : temperature; | ||
timeout = timeout == null ? ofSeconds(60) : timeout; | ||
|
||
this.client = OpenAiClient.builder() | ||
.apiKey("ignored") | ||
.url(ensureNotBlank(baseUrl, "baseUrl")) | ||
.callTimeout(timeout) | ||
.connectTimeout(timeout) | ||
.readTimeout(timeout) | ||
.writeTimeout(timeout) | ||
.logRequests(logRequests) | ||
.logResponses(logResponses) | ||
.build(); | ||
this.modelName = ensureNotBlank(modelName, "modelName"); | ||
this.temperature = temperature; | ||
this.topP = topP; | ||
this.maxTokens = maxTokens; | ||
} | ||
|
||
@Override | ||
public void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler) { | ||
sendMessages(messages, null, handler); | ||
} | ||
|
||
@Override | ||
public void sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler handler) { | ||
ChatCompletionRequest request = ChatCompletionRequest.builder() | ||
.stream(true) | ||
.model(modelName) | ||
.messages(toOpenAiMessages(messages)) | ||
.functions(toFunctions(toolSpecifications)) | ||
.temperature(temperature) | ||
.topP(topP) | ||
.maxTokens(maxTokens) | ||
.build(); | ||
|
||
client.chatCompletion(request) | ||
.onPartialResponse(partialResponse -> handle(partialResponse, handler)) | ||
.onComplete(handler::onComplete) | ||
.onError(handler::onError) | ||
.execute(); | ||
} | ||
|
||
private static void handle(ChatCompletionResponse partialResponse, | ||
StreamingResponseHandler handler) { | ||
Delta delta = partialResponse.choices().get(0).delta(); | ||
String content = delta.content(); | ||
FunctionCall functionCall = delta.functionCall(); | ||
if (content != null) { | ||
handler.onNext(content); | ||
} else if (functionCall != null) { | ||
if (functionCall.name() != null) { | ||
handler.onToolName(functionCall.name()); | ||
} else if (functionCall.arguments() != null) { | ||
handler.onToolArguments(functionCall.arguments()); | ||
} | ||
} | ||
} | ||
} |
73 changes: 73 additions & 0 deletions
73
langchain4j/src/main/java/dev/langchain4j/model/localai/LocalAiStreamingLanguageModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package dev.langchain4j.model.localai; | ||
|
||
import dev.ai4j.openai4j.OpenAiClient; | ||
import dev.ai4j.openai4j.completion.CompletionRequest; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.language.StreamingLanguageModel; | ||
import lombok.Builder; | ||
|
||
import java.time.Duration; | ||
|
||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static java.time.Duration.ofSeconds; | ||
|
||
public class LocalAiStreamingLanguageModel implements StreamingLanguageModel { | ||
|
||
private final OpenAiClient client; | ||
private final String modelName; | ||
private final Double temperature; | ||
private final Double topP; | ||
private final Integer maxTokens; | ||
|
||
@Builder | ||
public LocalAiStreamingLanguageModel(String baseUrl, | ||
String modelName, | ||
Double temperature, | ||
Double topP, | ||
Integer maxTokens, | ||
Duration timeout, | ||
Boolean logRequests, | ||
Boolean logResponses) { | ||
|
||
temperature = temperature == null ? 0.7 : temperature; | ||
timeout = timeout == null ? ofSeconds(60) : timeout; | ||
|
||
this.client = OpenAiClient.builder() | ||
.apiKey("ignored") | ||
.url(ensureNotBlank(baseUrl, "baseUrl")) | ||
.callTimeout(timeout) | ||
.connectTimeout(timeout) | ||
.readTimeout(timeout) | ||
.writeTimeout(timeout) | ||
.logRequests(logRequests) | ||
.logResponses(logResponses) | ||
.build(); | ||
this.modelName = ensureNotBlank(modelName, "modelName"); | ||
this.temperature = temperature; | ||
this.topP = topP; | ||
this.maxTokens = maxTokens; | ||
} | ||
|
||
@Override | ||
public void process(String text, StreamingResponseHandler handler) { | ||
|
||
CompletionRequest request = CompletionRequest.builder() | ||
.model(modelName) | ||
.prompt(text) | ||
.temperature(temperature) | ||
.topP(topP) | ||
.maxTokens(maxTokens) | ||
.build(); | ||
|
||
client.completion(request) | ||
.onPartialResponse(partialResponse -> { | ||
String partialResponseText = partialResponse.text(); | ||
if (partialResponseText != null) { | ||
handler.onNext(partialResponseText); | ||
} | ||
}) | ||
.onComplete(handler::onComplete) | ||
.onError(handler::onError) | ||
.execute(); | ||
} | ||
} |
Oops, something went wrong.