Skip to content

Commit

Permalink
Check for duplicate method names in a list of tool specifications (la…
Browse files Browse the repository at this point in the history
…ngchain4j#1375)

## Issue

When a class has several annotated methods with `@Tools` that have the
same name (but different parameters), the model does not know which one
to invoke. So the idea if to check for method name duplicates and throw
an exception.

## Change

So now, the `toolSpecificationsFrom` method invokes a
`validateSpecifications`. For now, this method only checks that there is
no duplicate method names (but it could be used for other checks in the
future if needed). The `validateSpecifications` method returns void if
the list of tool specifications is valid. If not, it throws an
`IllegalArgumentException`.

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
  • Loading branch information
agoncal authored Jul 1, 2024
1 parent 7cda3ec commit 544526b
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,37 @@
package dev.langchain4j.agent.tool;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.ARRAY;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.BOOLEAN;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.NUMBER;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.enums;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.from;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.items;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.objectItems;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import dev.langchain4j.model.output.structured.Description;
import static java.lang.String.format;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.toList;

import java.lang.reflect.*;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.*;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.toList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
* Utility methods for {@link ToolSpecification}s.
Expand All @@ -27,10 +48,12 @@ private ToolSpecifications() {
* @return the {@link ToolSpecification}s.
*/
public static List<ToolSpecification> toolSpecificationsFrom(Class<?> classWithTools) {
return stream(classWithTools.getDeclaredMethods())
List<ToolSpecification> toolSpecifications = stream(classWithTools.getDeclaredMethods())
.filter(method -> method.isAnnotationPresent(Tool.class))
.map(ToolSpecifications::toolSpecificationFrom)
.collect(toList());
validateSpecifications(toolSpecifications);
return toolSpecifications;
}

/**
Expand All @@ -44,6 +67,23 @@ public static List<ToolSpecification> toolSpecificationsFrom(Object objectWithTo
return toolSpecificationsFrom(objectWithTools.getClass());
}

/**
* Validates all the {@link ToolSpecification}s. The validation checks for duplicate method names.
* Throws {@link IllegalArgumentException} if validation fails
*
* @param toolSpecifications list of ToolSpecification to be validated.
*/
public static void validateSpecifications(List<ToolSpecification> toolSpecifications) throws IllegalArgumentException {

// Checks for duplicates methods
Set<String> names = new HashSet<>();
for (ToolSpecification toolSpecification : toolSpecifications) {
if (!names.add(toolSpecification.name())) {
throw new IllegalArgumentException(format("Tool names must be unique. The tool '%s' appears several times", toolSpecification.name()));
}
}
}

/**
* Returns the {@link ToolSpecification} for the given method annotated with @{@link Tool}.
*
Expand Down Expand Up @@ -106,7 +146,7 @@ static Iterable<JsonSchemaProperty> toJsonSchemaProperties(Parameter parameter)
return removeNulls(OBJECT, schema(type), description);
}

static JsonSchemaProperty schema(Class<?> structured){
static JsonSchemaProperty schema(Class<?> structured) {
return schema(structured, new HashSet<>());
}

Expand All @@ -116,21 +156,21 @@ private static JsonSchemaProperty schema(Class<?> structured, Set<Class<?>> visi
}

visited.add(structured);
Map<String,Object> properties = new HashMap<>();
Map<String, Object> properties = new HashMap<>();
for (Field field : structured.getDeclaredFields()) {
String name = field.getName();
if ( name.equals("this$0") || java.lang.reflect.Modifier.isStatic(field.getModifiers())) {
if (name.equals("this$0") || java.lang.reflect.Modifier.isStatic(field.getModifiers())) {
// Skip inner class reference.
continue;
}
Iterable<JsonSchemaProperty> schemaProperties = toJsonSchemaProperties(field, visited);
Map<Object,Object> objectMap = new HashMap<>();
for(JsonSchemaProperty jsonSchemaProperty : schemaProperties) {
Map<Object, Object> objectMap = new HashMap<>();
for (JsonSchemaProperty jsonSchemaProperty : schemaProperties) {
objectMap.put(jsonSchemaProperty.key(), jsonSchemaProperty.value());
}
properties.put(name, objectMap);
}
return from( "properties", properties );
return from("properties", properties);
}

private static Iterable<JsonSchemaProperty> toJsonSchemaProperties(Field field, Set<Class<?>> visited) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package dev.langchain4j.agent.tool;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.items;

import dev.langchain4j.model.output.structured.Description;
import static java.util.Arrays.asList;
import lombok.Data;
import org.assertj.core.api.WithAssertions;
import org.junit.jupiter.api.Test;
Expand All @@ -11,10 +11,12 @@
import java.lang.reflect.Parameter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.*;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.items;
import static java.util.Arrays.asList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

class ToolSpecificationsTest implements WithAssertions {

Expand Down Expand Up @@ -95,6 +97,34 @@ public int unused(int i) {
}
}

@SuppressWarnings("unused")
public static class InvalidToolsWithDuplicateMethodNames {

@Tool
public int duplicateMethod(String typeString) {
return 42;
}

@Tool
public int duplicateMethod(int typeInt) {
return 42;
}
}

@SuppressWarnings("unused")
public static class InvalidToolsWithDuplicateNames {

@Tool(name = "duplicate_name")
public int oneMethod(String typeString) {
return 42;
}

@Tool(name = "duplicate_name")
public int aDifferentMethod(int typeInt) {
return 42;
}
}

private static Method getF() throws NoSuchMethodException {
return Wrapper.class.getMethod("f",
String.class,//0
Expand Down Expand Up @@ -173,6 +203,24 @@ public void test_toolSpecificationsFrom() {
.containsExactlyInAnyOrder("f", "func_name");
}

@Test
public void test_toolSpecificationsFrom_with_duplicate_method_names() {
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateMethodNames()))
.withMessage("Tool names must be unique. The tool 'duplicateMethod' appears several times")
.withNoCause();

}

@Test
public void test_toolSpecificationsFrom_with_duplicate_names() {
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> ToolSpecifications.toolSpecificationsFrom(new InvalidToolsWithDuplicateNames()))
.withMessage("Tool names must be unique. The tool 'duplicate_name' appears several times")
.withNoCause();

}

@Test
public void test_toolName_memoryId() throws NoSuchMethodException {
Method method = Wrapper.class.getMethod("g", String.class);
Expand Down

0 comments on commit 544526b

Please sign in to comment.