Skip to content

Commit

Permalink
WIP: Register tools programmatically
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed Jun 19, 2024
1 parent 86fbc5b commit fae22cb
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package dev.langchain4j.agent.tool;

import lombok.AllArgsConstructor;

import java.util.function.Function;

/**
* TODO
*/
@AllArgsConstructor // TODO
public class ToolSomething {
// TODO name
// TODO location

private final String name;
private final String description;
private final Class<?> argumentClass; // TODO needed? can get from function?
private final Function<?, ?> function;

public String name() {
return name;
}

public String description() {
return description;
}

public Class<?> argumentClass() {
return argumentClass;
}

public Function<?, ?> function() {
return function;
}

// TODO ctor? builder?
public static <T> ToolSomething from(String name,
String description,
Class<T> argumentClass,
Function<T, ?> function) {
return new ToolSomething(name, description, argumentClass, function);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ public static ToolSpecification toolSpecificationFrom(Method method) {
return builder.build();
}

public static ToolSpecification toolSpecificationFrom(ToolSomething tool) {

return ToolSpecification.builder()
.name(tool.name())
.description(tool.description())
.parameters(tool)
.build();
}

/**
* Convert a {@link Parameter} to a {@link JsonSchemaProperty}.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package dev.langchain4j.agent.tool;

public class DefaultFunctionToolExecutor implements ToolExecutor {



@Override
public String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId) {
return null;
}
}
48 changes: 45 additions & 3 deletions langchain4j/src/main/java/dev/langchain4j/service/AiServices.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dev.langchain4j.agent.tool.DefaultToolExecutor;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSomething;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand Down Expand Up @@ -37,6 +38,7 @@
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;

/**
Expand Down Expand Up @@ -303,7 +305,7 @@ public AiServices<T> moderationModel(ModerationModel moderationModel) {
* @see Tool
*/
public AiServices<T> tools(Object... objectsWithTools) {
return tools(Arrays.asList(objectsWithTools));
return tools(asList(objectsWithTools));
}

/**
Expand All @@ -318,8 +320,13 @@ public AiServices<T> tools(Object... objectsWithTools) {
*/
public AiServices<T> tools(List<Object> objectsWithTools) { // TODO Collection?
// TODO validate uniqueness of tool names
context.toolSpecifications = new ArrayList<>();
context.toolExecutors = new HashMap<>();

if (context.toolSpecifications == null) {
context.toolSpecifications = new ArrayList<>();
}
if (context.toolExecutors == null) {
context.toolExecutors = new HashMap<>();
}

for (Object objectWithTool : objectsWithTools) {
if (objectWithTool instanceof Class) {
Expand All @@ -338,6 +345,41 @@ public AiServices<T> tools(List<Object> objectsWithTools) { // TODO Collection?
return this;
}

/**
* TODO
*
* @param tools
* @return
*/
public AiServices<T> tools(ToolSomething... tools) {
return tools((Collection<ToolSomething>) asList(tools));
}

/**
* TODO
*
* @param tools
* @return
*/
public AiServices<T> tools(Collection<ToolSomething> tools) {

// TODO validate uniqueness of tool names

if (context.toolSpecifications == null) {
context.toolSpecifications = new ArrayList<>();
}
if (context.toolExecutors == null) {
context.toolExecutors = new HashMap<>();
}

for (ToolSomething tool : tools) {

context.toolSpecifications.add(toolSpecification);
}

return this;
}

/**
* Deprecated. Use {@link #contentRetriever(ContentRetriever)}
* (e.g. {@link EmbeddingStoreContentRetriever}) instead.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.agent.tool.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
Expand All @@ -14,14 +11,17 @@
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.output.structured.Description;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
Expand Down Expand Up @@ -109,6 +109,61 @@ static class TransactionService {
}
}

static class UserDetails {

@Description("")
String name;

@Description("")
String surname;

@Description("")
String bookingNumber;
}

static class BookingChecker implements Function<UserDetails, String> {

@Override
public String apply(UserDetails userDetails) {
if (userDetails.bookingNumber.equals("123-456")) {
return "Booking found. Booking period: 1 July 2024 - 10 July 2024";
} else {
return "Booking not found";
}
}
}



@Test
void test() {

OpenAiChatModel model = OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();

ToolSomething.from(
"get_booking_details",
"Get booking details",
UserDetails.class, // single param / multiple params. types? Map? How to describe params?
new BookingChecker()
);

// TODO how to register them in Spring boot app? Return a list of those as a bean?

AiServices.builder(Assistant.class)
.chatLanguageModel(model)
.tools()
.build();

}


@ParameterizedTest
@MethodSource("models")
void should_execute_a_tool_then_answer(ChatLanguageModel chatLanguageModel) {
Expand Down

0 comments on commit fae22cb

Please sign in to comment.