forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ollama Tools Support - only in OllamaChatModel (langchain4j#1558)
## Issue Closes langchain4j#1525 Closes langchain4j#318 ## Change Add new DTOs to support Tool/Function calling. Add new integration test and model configuration that supports Tool Calls - Mistral Modify some class to support Tool calling, for example `OllamaChatModel`, `Role` As stated here https://ollama.com/blog/tool-support, tools support is not possible in streaming scenario - for now I got the best results with the Mistral model, but without a SystemMessage, it always wants to call a tool, even if it doesn't make any sense. With Llama 3.1, it properly returns a ToolCallRequest but ignores the result. ## General checklist <!-- Please double-check the following points and mark them like this: [X] --> - [X] There are no breaking changes - [X] I have added unit and integration tests for my change - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. --> - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
- Loading branch information
Showing
13 changed files
with
296 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Function.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package dev.langchain4j.model.ollama; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.databind.PropertyNamingStrategies; | ||
import com.fasterxml.jackson.databind.annotation.JsonNaming; | ||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; | ||
|
||
@Data | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Builder | ||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
@JsonInclude(NON_NULL) | ||
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) | ||
class Function { | ||
private String name; | ||
private String description; | ||
private Parameters parameters; | ||
} |
26 changes: 26 additions & 0 deletions
26
langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/FunctionCall.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package dev.langchain4j.model.ollama; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.databind.PropertyNamingStrategies; | ||
import com.fasterxml.jackson.databind.annotation.JsonNaming; | ||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
import java.util.Map; | ||
|
||
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; | ||
|
||
@Data | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Builder | ||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
@JsonInclude(NON_NULL) | ||
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) | ||
class FunctionCall { | ||
private String name; | ||
private Map<String, Object> arguments; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Parameters.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package dev.langchain4j.model.ollama; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.databind.PropertyNamingStrategies; | ||
import com.fasterxml.jackson.databind.annotation.JsonNaming; | ||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; | ||
|
||
@Data | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Builder | ||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
@JsonInclude(NON_NULL) | ||
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) | ||
class Parameters { | ||
private String type = "object"; | ||
private Map<String, Map<String,Object>> properties; | ||
private List<String> required; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,8 @@ enum Role { | |
|
||
SYSTEM, | ||
USER, | ||
ASSISTANT; | ||
ASSISTANT, | ||
TOOL; | ||
|
||
@JsonValue | ||
public String serialize() { | ||
|
24 changes: 24 additions & 0 deletions
24
langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Tool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
package dev.langchain4j.model.ollama; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.databind.PropertyNamingStrategies; | ||
import com.fasterxml.jackson.databind.annotation.JsonNaming; | ||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; | ||
|
||
@Data | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Builder | ||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
@JsonInclude(NON_NULL) | ||
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) | ||
class Tool { | ||
private final String type = "function"; | ||
private Function function; | ||
} |
23 changes: 23 additions & 0 deletions
23
langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ToolCall.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package dev.langchain4j.model.ollama; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.databind.PropertyNamingStrategies; | ||
import com.fasterxml.jackson.databind.annotation.JsonNaming; | ||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; | ||
|
||
@Data | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Builder | ||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
@JsonInclude(NON_NULL) | ||
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) | ||
class ToolCall { | ||
private FunctionCall function; | ||
} |
18 changes: 18 additions & 0 deletions
18
...est/java/dev/langchain4j/model/ollama/AbstractOllamaToolsLanguageModelInfrastructure.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package dev.langchain4j.model.ollama; | ||
|
||
class AbstractOllamaToolsLanguageModelInfrastructure { | ||
|
||
private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.TOOL_MODEL); | ||
|
||
static LangChain4jOllamaContainer ollama; | ||
|
||
static { | ||
ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE)) | ||
.withModel(OllamaImage.TOOL_MODEL); | ||
ollama.start(); | ||
ollama.commitToImage(LOCAL_OLLAMA_IMAGE); | ||
} | ||
|
||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
93 changes: 93 additions & 0 deletions
93
langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaToolChatModelIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package dev.langchain4j.model.ollama; | ||
|
||
import dev.langchain4j.agent.tool.ToolExecutionRequest; | ||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.ToolExecutionResultMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.List; | ||
|
||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*; | ||
import static dev.langchain4j.data.message.SystemMessage.systemMessage; | ||
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; | ||
import static dev.langchain4j.data.message.UserMessage.userMessage; | ||
import static dev.langchain4j.model.ollama.OllamaImage.TOOL_MODEL; | ||
import static java.util.Arrays.asList; | ||
import static java.util.Collections.singletonList; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
class OllamaToolChatModelIT extends AbstractOllamaToolsLanguageModelInfrastructure { | ||
|
||
|
||
ToolSpecification weatherToolSpecification = ToolSpecification.builder() | ||
.name("get_current_weather") | ||
.description("Get the current weather for a location") | ||
.addParameter("format", STRING, enums("celsius", "fahrenheit"), description("The format to return the weather in, e.g. 'celsius' or 'fahrenheit'")) | ||
.addParameter("location", STRING, description("The location to get the weather for, e.g. San Francisco, CA")) | ||
.build(); | ||
|
||
ChatLanguageModel ollamaChatModel = OllamaChatModel.builder() | ||
.baseUrl(ollama.getEndpoint()) | ||
.modelName(TOOL_MODEL) | ||
.logRequests(true) | ||
.logResponses(true) | ||
.build(); | ||
|
||
@Test | ||
void should_execute_a_tool_then_answer() { | ||
|
||
// given | ||
UserMessage userMessage = userMessage("What is the weather today in Paris?"); | ||
List<ToolSpecification> toolSpecifications = singletonList(weatherToolSpecification); | ||
|
||
// when | ||
Response<AiMessage> response = ollamaChatModel.generate(singletonList(userMessage), toolSpecifications); | ||
|
||
// then | ||
AiMessage aiMessage = response.content(); | ||
assertThat(aiMessage.text()).isNull(); | ||
assertThat(aiMessage.toolExecutionRequests()).hasSize(1); | ||
|
||
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); | ||
assertThat(toolExecutionRequest.name()).isEqualTo("get_current_weather"); | ||
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"format\": \"celsius\", \"location\": \"Paris\"}"); | ||
|
||
// given | ||
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "{\"format\": \"celsius\", \"location\": \"Paris\", \"temperature\": \"32\"}"); | ||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage); | ||
|
||
// when | ||
Response<AiMessage> secondResponse = ollamaChatModel.generate(messages); | ||
|
||
// then | ||
AiMessage secondAiMessage = secondResponse.content(); | ||
assertThat(secondAiMessage.text()).contains("32"); | ||
assertThat(secondAiMessage.toolExecutionRequests()).isNull(); | ||
} | ||
|
||
|
||
@Test | ||
void should_not_execute_a_tool_and_tell_a_joke() { | ||
|
||
// given | ||
List<ToolSpecification> toolSpecifications = singletonList(weatherToolSpecification); | ||
|
||
// when | ||
List<ChatMessage> chatMessages = asList( | ||
systemMessage("Use tools only if needed"), | ||
userMessage("Tell a joke") | ||
); | ||
Response<AiMessage> response = ollamaChatModel.generate(chatMessages, toolSpecifications); | ||
|
||
// then | ||
AiMessage aiMessage = response.content(); | ||
assertThat(aiMessage.text()).isNotNull(); | ||
assertThat(aiMessage.toolExecutionRequests()).isNull(); | ||
} | ||
|
||
} |