Skip to content

Commit

Permalink
Ollama Tools Support - only in OllamaChatModel (langchain4j#1558)
Browse files Browse the repository at this point in the history
## Issue
Closes langchain4j#1525
Closes langchain4j#318

## Change
Add new DTOs to support Tool/Function calling.
Add new integration test and model configuration that supports Tool
Calls - Mistral
Modify some classes to support Tool calling, for example
`OllamaChatModel` and `Role`

As stated here https://ollama.com/blog/tool-support, tool support is
not yet possible in the streaming scenario.

I got the best results with the Mistral model, but without a
SystemMessage, it always wants to call a tool, even if it doesn't make
any sense. With Llama 3.1, it properly returns a ToolCallRequest but
ignores the result.

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
bidek authored Aug 16, 2024
1 parent 5ec127d commit 8618d30
Show file tree
Hide file tree
Showing 13 changed files with 296 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ class ChatRequest {
private Options options;
private String format;
private Boolean stream;
private List<Tool> tools;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dev.langchain4j.model.ollama;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

/**
 * Request-side DTO describing one callable function, embedded in a {@link Tool}
 * and sent to the Ollama chat API ("tools" array in the request body).
 * Serialized with snake_case property names; null fields are omitted.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
class Function {
    // Name the model echoes back in its tool_calls when it wants this function invoked.
    private String name;
    // Natural-language description that helps the model decide when to call the function.
    private String description;
    // JSON-schema-like description of the function's accepted arguments.
    private Parameters parameters;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package dev.langchain4j.model.ollama;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.Map;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

/**
 * Response-side DTO: the concrete function invocation the model requested,
 * carried inside a {@link ToolCall} of the Ollama chat response.
 * Deserialized from snake_case JSON; unknown properties are ignored.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
class FunctionCall {
    // Name of the function the model wants to call (matches Function.name in the request).
    private String name;
    // Arguments the model chose, already parsed from JSON into a generic map.
    private Map<String, Object> arguments;
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ class Message {
private Role role;
private String content;
private List<String> images;
private List<ToolCall> toolCalls;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.langchain4j.model.ollama;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand All @@ -16,7 +17,7 @@
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.model.ollama.OllamaMessagesUtils.toOllamaMessages;
import static dev.langchain4j.model.ollama.OllamaMessagesUtils.*;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.time.Duration.ofSeconds;

Expand Down Expand Up @@ -92,6 +93,29 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
    ensureNotEmpty(messages, "messages");

    ChatRequest request = ChatRequest.builder()
            .model(modelName)
            .messages(toOllamaMessages(messages))
            .options(options)
            .format(format)
            // Ollama does not support tool calling in streaming mode (see
            // https://ollama.com/blog/tool-support), so this endpoint is always non-streaming.
            .stream(false)
            // Guard against null/empty tool lists: toOllamaTools() streams its argument
            // and would NPE on null; ChatRequest's NON_NULL serialization simply omits
            // the "tools" field when we pass null here.
            .tools(toolSpecifications == null || toolSpecifications.isEmpty()
                    ? null
                    : toOllamaTools(toolSpecifications))
            .build();

    ChatResponse response = withRetry(() -> client.chat(request), maxRetries);

    // If the model decided to call a tool, surface the calls as ToolExecutionRequests;
    // otherwise return the plain text answer.
    Message message = response.getMessage();
    return Response.from(
            message.getToolCalls() != null
                    ? AiMessage.from(toToolExecutionRequest(message.getToolCalls()))
                    : AiMessage.from(message.getContent()),
            new TokenUsage(response.getPromptEvalCount(), response.getEvalCount())
    );
}

public static OllamaChatModelBuilder builder() {
for (OllamaChatModelBuilderFactory factory : loadFactories(OllamaChatModelBuilderFactory.class)) {
return factory.get();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package dev.langchain4j.model.ollama;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.internal.Json;

import java.util.List;
import java.util.Map;
Expand All @@ -26,6 +29,30 @@ static List<Message> toOllamaMessages(List<ChatMessage> messages) {
).collect(Collectors.toList());
}

/**
 * Maps langchain4j {@link ToolSpecification}s to the Ollama request-side {@link Tool} DTOs.
 * <p>
 * A specification may have no parameters at all (a zero-argument tool), in which case
 * {@code toolSpecification.parameters()} is null; the original code dereferenced it
 * unconditionally and threw a NullPointerException. Here a null spec yields a Tool
 * with null parameters, which the NON_NULL serialization simply omits from the JSON.
 */
static List<Tool> toOllamaTools(List<ToolSpecification> toolSpecifications) {
    return toolSpecifications.stream().map(toolSpecification ->
                    Tool.builder()
                            .function(Function.builder()
                                    .name(toolSpecification.name())
                                    .description(toolSpecification.description())
                                    .parameters(toolSpecification.parameters() == null
                                            ? null
                                            : Parameters.builder()
                                                    .properties(toolSpecification.parameters().properties())
                                                    .required(toolSpecification.parameters().required())
                                                    .build())
                                    .build())
                            .build())
            .collect(Collectors.toList());
}

/**
 * Converts the Ollama response's tool calls into langchain4j {@link ToolExecutionRequest}s,
 * re-serializing each call's argument map back to a JSON string as the framework expects.
 */
static List<ToolExecutionRequest> toToolExecutionRequest(List<ToolCall> toolCalls) {
    return toolCalls.stream()
            .map(ToolCall::getFunction)
            .map(function -> ToolExecutionRequest.builder()
                    .name(function.getName())
                    .arguments(Json.toJson(function.getArguments()))
                    .build())
            .collect(Collectors.toList());
}

private static Message messagesWithImageSupport(UserMessage userMessage) {
Map<ContentType, List<Content>> groupedContents = userMessage.contents().stream()
.collect(Collectors.groupingBy(Content::type));
Expand Down Expand Up @@ -62,6 +89,8 @@ private static Role toOllamaRole(ChatMessageType chatMessageType) {
return Role.USER;
case AI:
return Role.ASSISTANT;
case TOOL_EXECUTION_RESULT:
return Role.TOOL;
default:
throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dev.langchain4j.model.ollama;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;
import java.util.Map;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

/**
 * JSON-schema-style description of a function's parameters, nested inside
 * {@link Function} in the Ollama tool-calling request.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
class Parameters {
    // Always "object" per the JSON-schema convention Ollama expects.
    // @Builder.Default is required: Lombok's @Builder ignores plain field initializers,
    // so without it every instance built via Parameters.builder() (as toOllamaTools does)
    // would carry type == null and the serialized schema would omit "type": "object".
    @Builder.Default
    private String type = "object";
    // Parameter name -> JSON-schema attributes ("type", "description", "enum", ...).
    private Map<String, Map<String, Object>> properties;
    // Names of the parameters the model must always supply.
    private List<String> required;
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ enum Role {

SYSTEM,
USER,
ASSISTANT;
ASSISTANT,
TOOL;

@JsonValue
public String serialize() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package dev.langchain4j.model.ollama;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

/**
 * One entry of the "tools" array sent to the Ollama chat API.
 * Serialized with snake_case property names; null fields are omitted.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
class Tool {
    // Always "function" — the only tool type Ollama supports. Being final with an
    // initializer, Lombok excludes it from the builder and generated constructors,
    // so it cannot be overridden.
    private final String type = "function";
    // The function definition this tool exposes to the model.
    private Function function;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package dev.langchain4j.model.ollama;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;

/**
 * Response-side DTO: one element of the "tool_calls" array returned by Ollama
 * when the model decides to invoke a tool. Unknown JSON properties are ignored.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
class ToolCall {
    // The function invocation (name + arguments) the model requested.
    private FunctionCall function;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dev.langchain4j.model.ollama;

/**
 * Shared Testcontainers infrastructure for tool-calling integration tests:
 * starts a single Ollama container with the tool-capable model pulled, once per JVM.
 */
class AbstractOllamaToolsLanguageModelInfrastructure {

    // Name of the locally committed image that already contains the pulled model,
    // e.g. "tc-<ollama-image>-mistral", so subsequent runs skip the model download.
    private static final String LOCAL_OLLAMA_IMAGE = String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, OllamaImage.TOOL_MODEL);

    static LangChain4jOllamaContainer ollama;

    static {
        // resolve() presumably prefers the cached local image over the upstream one — TODO confirm.
        ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OllamaImage.OLLAMA_IMAGE, LOCAL_OLLAMA_IMAGE))
                .withModel(OllamaImage.TOOL_MODEL);
        ollama.start();
        // Persist the container (with the downloaded model baked in) as a local image for reuse.
        ollama.commitToImage(LOCAL_OLLAMA_IMAGE);
    }


    
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public class OllamaImage {
static final String BAKLLAVA_MODEL = "bakllava";

static final String TINY_DOLPHIN_MODEL = "tinydolphin";
static final String TOOL_MODEL = "mistral";

static final String ALL_MINILM_MODEL = "all-minilm";

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package dev.langchain4j.model.ollama;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;

import java.util.List;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
import static dev.langchain4j.data.message.SystemMessage.systemMessage;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.ollama.OllamaImage.TOOL_MODEL;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Integration test for tool calling via {@link OllamaChatModel}, run against the
 * live Ollama container (tool-capable model) started by the superclass.
 */
class OllamaToolChatModelIT extends AbstractOllamaToolsLanguageModelInfrastructure {


    // A single-function tool spec with one enum-constrained and one free-form string parameter.
    ToolSpecification weatherToolSpecification = ToolSpecification.builder()
            .name("get_current_weather")
            .description("Get the current weather for a location")
            .addParameter("format", STRING, enums("celsius", "fahrenheit"), description("The format to return the weather in, e.g. 'celsius' or 'fahrenheit'"))
            .addParameter("location", STRING, description("The location to get the weather for, e.g. San Francisco, CA"))
            .build();

    ChatLanguageModel ollamaChatModel = OllamaChatModel.builder()
            .baseUrl(ollama.getEndpoint())
            .modelName(TOOL_MODEL)
            .logRequests(true)
            .logResponses(true)
            .build();

    @Test
    void should_execute_a_tool_then_answer() {

        // given
        UserMessage userMessage = userMessage("What is the weather today in Paris?");
        List<ToolSpecification> toolSpecifications = singletonList(weatherToolSpecification);

        // when
        Response<AiMessage> response = ollamaChatModel.generate(singletonList(userMessage), toolSpecifications);

        // then: the model should request a tool call instead of answering in text
        AiMessage aiMessage = response.content();
        assertThat(aiMessage.text()).isNull();
        assertThat(aiMessage.toolExecutionRequests()).hasSize(1);

        ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
        assertThat(toolExecutionRequest.name()).isEqualTo("get_current_weather");
        assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"format\": \"celsius\", \"location\": \"Paris\"}");

        // given: feed a fabricated tool result back into the conversation
        ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "{\"format\": \"celsius\", \"location\": \"Paris\", \"temperature\": \"32\"}");
        List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);

        // when
        Response<AiMessage> secondResponse = ollamaChatModel.generate(messages);

        // then: the final answer should incorporate the tool's result ("32")
        AiMessage secondAiMessage = secondResponse.content();
        assertThat(secondAiMessage.text()).contains("32");
        assertThat(secondAiMessage.toolExecutionRequests()).isNull();
    }


    @Test
    void should_not_execute_a_tool_and_tell_a_joke() {

        // given: tools are available but irrelevant to the request
        List<ToolSpecification> toolSpecifications = singletonList(weatherToolSpecification);

        // when: the system message discourages gratuitous tool use (Mistral otherwise
        // tends to call a tool even when it makes no sense — see PR description)
        List<ChatMessage> chatMessages = asList(
                systemMessage("Use tools only if needed"),
                userMessage("Tell a joke")
        );
        Response<AiMessage> response = ollamaChatModel.generate(chatMessages, toolSpecifications);

        // then: a plain text answer, no tool calls
        AiMessage aiMessage = response.content();
        assertThat(aiMessage.text()).isNotNull();
        assertThat(aiMessage.toolExecutionRequests()).isNull();
    }

}

0 comments on commit 8618d30

Please sign in to comment.