diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java index ab20a4ed99e..26571f7f110 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ChatRequest.java @@ -27,4 +27,5 @@ class ChatRequest { private Options options; private String format; private Boolean stream; + private List tools; } diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Function.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Function.java new file mode 100644 index 00000000000..29867188503 --- /dev/null +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Function.java @@ -0,0 +1,25 @@ +package dev.langchain4j.model.ollama; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +class Function { + private String name; + private String description; + private Parameters parameters; +} diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/FunctionCall.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/FunctionCall.java new file mode 100644 index 00000000000..adecf243eaa --- /dev/null +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/FunctionCall.java @@ -0,0 +1,26 @@ +package dev.langchain4j.model.ollama; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Map; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +class FunctionCall { + private String name; + private Map arguments; +} diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java index fd7cb358018..3e123f2c445 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Message.java @@ -25,4 +25,5 @@ class Message { private Role role; private String content; private List images; + private List toolCalls; } diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java index 6bba995057f..0366fd36348 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java @@ -1,5 +1,6 @@ package dev.langchain4j.model.ollama; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -16,7 +17,7 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; -import static dev.langchain4j.model.ollama.OllamaMessagesUtils.toOllamaMessages; +import static dev.langchain4j.model.ollama.OllamaMessagesUtils.*; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.time.Duration.ofSeconds; @@ -92,6 +93,29 @@ public Response generate(List messages) { ); } + @Override + public Response generate(List messages, List toolSpecifications) { + ensureNotEmpty(messages, "messages"); + + ChatRequest request = ChatRequest.builder() + .model(modelName) + .messages(toOllamaMessages(messages)) + .options(options) + .format(format) + .stream(false) + .tools(toOllamaTools(toolSpecifications)) + .build(); + + ChatResponse response = withRetry(() -> client.chat(request), maxRetries); + + return Response.from( + response.getMessage().getToolCalls() != null ? + AiMessage.from(toToolExecutionRequest(response.getMessage().getToolCalls())) : + AiMessage.from(response.getMessage().getContent()), + new TokenUsage(response.getPromptEvalCount(), response.getEvalCount()) + ); + } + public static OllamaChatModelBuilder builder() { for (OllamaChatModelBuilderFactory factory : loadFactories(OllamaChatModelBuilderFactory.class)) { return factory.get(); diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaMessagesUtils.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaMessagesUtils.java index a729cc8fd3e..b5516fd2090 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaMessagesUtils.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaMessagesUtils.java @@ -1,6 +1,9 @@ package dev.langchain4j.model.ollama; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.*; +import dev.langchain4j.internal.Json; import java.util.List; import java.util.Map; @@ -26,6 +29,30 @@ static List toOllamaMessages(List messages) { ).collect(Collectors.toList()); } + static List toOllamaTools(List toolSpecifications) { + return toolSpecifications.stream().map(toolSpecification -> + Tool.builder() + .function(Function.builder() + .name(toolSpecification.name()) + .description(toolSpecification.description()) + .parameters(Parameters.builder() + .properties(toolSpecification.parameters().properties()) + .required(toolSpecification.parameters().required()) + .build()) + .build()) + .build()) + .collect(Collectors.toList()); + } + + static List toToolExecutionRequest(List toolCalls) { + return toolCalls.stream().map(toolCall -> + ToolExecutionRequest.builder() + .name(toolCall.getFunction().getName()) + .arguments(Json.toJson(toolCall.getFunction().getArguments())) + .build()) + .collect(Collectors.toList()); + } + private static Message messagesWithImageSupport(UserMessage userMessage) { Map> groupedContents = userMessage.contents().stream() .collect(Collectors.groupingBy(Content::type)); @@ -62,6 +89,8 @@ private static Role toOllamaRole(ChatMessageType chatMessageType) { return Role.USER; case AI: return Role.ASSISTANT; + case TOOL_EXECUTION_RESULT: + return Role.TOOL; default: throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType); } diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Parameters.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Parameters.java new file mode 100644 index 00000000000..4b769e31823 --- /dev/null +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Parameters.java @@ -0,0 +1,28 @@ +package dev.langchain4j.model.ollama; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +class Parameters { + private String type = "object"; + private Map> properties; + private List required; +} diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java index e32e6aadbf6..ffe6855a95e 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Role.java @@ -8,7 +8,8 @@ enum Role { SYSTEM, USER, - ASSISTANT; + ASSISTANT, + TOOL; @JsonValue public String serialize() { diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Tool.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Tool.java new file mode 100644 index 00000000000..8990c7aa336 --- /dev/null +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/Tool.java @@ -0,0 +1,24 @@ +package dev.langchain4j.model.ollama; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +class Tool { + private final String type = "function"; + private Function function; +} diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ToolCall.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ToolCall.java new file mode 100644 index 00000000000..cfed740268d --- /dev/null +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/ToolCall.java @@ -0,0 +1,23 @@ +package dev.langchain4j.model.ollama; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +class ToolCall { + private FunctionCall function; +} diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaToolsLanguageModelInfrastructure.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaToolsLanguageModelInfrastructure.java new file mode 100644 index 00000000000..e87d36fccdb --- /dev/null +++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/AbstractOllamaToolsLanguageModelInfrastructure.java @@ -0,0 +1,18 @@ +package dev.langchain4j.model.ollama; + +class AbstractOllamaToolsLanguageModelInfrastructure { + + private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.TOOL_MODEL); + + static LangChain4jOllamaContainer ollama; + + static { + ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE)) + .withModel(OllamaImage.TOOL_MODEL); + ollama.start(); + ollama.commitToImage(LOCAL_OLLAMA_IMAGE); + } + + + +} diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaImage.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaImage.java index c7628267a02..a6f7746432e 100644 --- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaImage.java +++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaImage.java @@ -14,6 +14,7 @@ public class OllamaImage { static final String BAKLLAVA_MODEL = "bakllava"; static final String TINY_DOLPHIN_MODEL = "tinydolphin"; + static final String TOOL_MODEL = "mistral"; static final String ALL_MINILM_MODEL = "all-minilm"; diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaToolChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaToolChatModelIT.java new file mode 100644 index 00000000000..5f4c944bbfa --- /dev/null +++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaToolChatModelIT.java @@ -0,0 +1,93 @@ +package dev.langchain4j.model.ollama; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static dev.langchain4j.agent.tool.JsonSchemaProperty.*; +import static dev.langchain4j.data.message.SystemMessage.systemMessage; +import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static dev.langchain4j.model.ollama.OllamaImage.TOOL_MODEL; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; + +class OllamaToolChatModelIT extends AbstractOllamaToolsLanguageModelInfrastructure { + + + ToolSpecification weatherToolSpecification = ToolSpecification.builder() + .name("get_current_weather") + .description("Get the current weather for a location") + .addParameter("format", STRING, enums("celsius", "fahrenheit"), description("The format to return the weather in, e.g. 'celsius' or 'fahrenheit'")) + .addParameter("location", STRING, description("The location to get the weather for, e.g. San Francisco, CA")) + .build(); + + ChatLanguageModel ollamaChatModel = OllamaChatModel.builder() + .baseUrl(ollama.getEndpoint()) + .modelName(TOOL_MODEL) + .logRequests(true) + .logResponses(true) + .build(); + + @Test + void should_execute_a_tool_then_answer() { + + // given + UserMessage userMessage = userMessage("What is the weather today in Paris?"); + List toolSpecifications = singletonList(weatherToolSpecification); + + // when + Response response = ollamaChatModel.generate(singletonList(userMessage), toolSpecifications); + + // then + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo("get_current_weather"); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"format\": \"celsius\", \"location\": \"Paris\"}"); + + // given + ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "{\"format\": \"celsius\", \"location\": \"Paris\", \"temperature\": \"32\"}"); + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage); + + // when + Response secondResponse = ollamaChatModel.generate(messages); + + // then + AiMessage secondAiMessage = secondResponse.content(); + assertThat(secondAiMessage.text()).contains("32"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + } + + + @Test + void should_not_execute_a_tool_and_tell_a_joke() { + + // given + List toolSpecifications = singletonList(weatherToolSpecification); + + // when + List chatMessages = asList( + systemMessage("Use tools only if needed"), + userMessage("Tell a joke") + ); + Response response = ollamaChatModel.generate(chatMessages, toolSpecifications); + + // then + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNotNull(); + assertThat(aiMessage.toolExecutionRequests()).isNull(); + } + +} \ No newline at end of file