Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Added logic that attempts to convert Double to the type of tool (method)
parameter
  • Loading branch information
langchain4j authored Jul 15, 2023
1 parent 755c9d0 commit 907c1eb
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
public class JsonSchemaProperty {

public static final JsonSchemaProperty STRING = type("string");
public static final JsonSchemaProperty INTEGER = type("integer");
public static final JsonSchemaProperty NUMBER = type("number");
public static final JsonSchemaProperty OBJECT = type("object");
public static final JsonSchemaProperty ARRAY = type("array");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Map;

public class ToolExecutor {
Expand Down Expand Up @@ -60,10 +62,58 @@ private Object[] prepareArguments(Map<String, Object> argumentsMap) {
for (int i = 0; i < parameters.length; i++) {
String parameterName = parameters[i].getName();
if (argumentsMap.containsKey(parameterName)) {
arguments[i] = argumentsMap.get(parameterName);
Object argument = argumentsMap.get(parameterName);
Class<?> parameterType = parameters[i].getType();

// Gson always parses numbers into the Double type. If the parameter type is not Double, a conversion attempt is made.
if (argument instanceof Double && !(parameterType == Double.class || parameterType == double.class)) {
Double doubleValue = (Double) argument;

if (parameterType == Float.class || parameterType == float.class) {
if (doubleValue < -Float.MAX_VALUE || doubleValue > Float.MAX_VALUE) {
throw new IllegalArgumentException("Double value " + doubleValue + " is out of range for the float type");
}
argument = doubleValue.floatValue();
} else if (parameterType == BigDecimal.class) {
argument = BigDecimal.valueOf(doubleValue);
}

// Allow conversion to integer types only if double value has no fractional part
if (hasNoFractionalPart(doubleValue)) {
if (parameterType == Integer.class || parameterType == int.class) {
if (doubleValue < Integer.MIN_VALUE || doubleValue > Integer.MAX_VALUE) {
throw new IllegalArgumentException("Double value " + doubleValue + " is out of range for the integer type");
}
argument = doubleValue.intValue();
} else if (parameterType == Long.class || parameterType == long.class) {
if (doubleValue < Long.MIN_VALUE || doubleValue > Long.MAX_VALUE) {
throw new IllegalArgumentException("Double value " + doubleValue + " is out of range for the long type");
}
argument = doubleValue.longValue();
} else if (parameterType == Short.class || parameterType == short.class) {
if (doubleValue < Short.MIN_VALUE || doubleValue > Short.MAX_VALUE) {
throw new IllegalArgumentException("Double value " + doubleValue + " is out of range for the short type");
}
argument = doubleValue.shortValue();
} else if (parameterType == Byte.class || parameterType == byte.class) {
if (doubleValue < Byte.MIN_VALUE || doubleValue > Byte.MAX_VALUE) {
throw new IllegalArgumentException("Double value " + doubleValue + " is out of range for the byte type");
}
argument = doubleValue.byteValue();
} else if (parameterType == BigInteger.class) {
argument = BigDecimal.valueOf(doubleValue).toBigInteger();
}
}
}

arguments[i] = argument;
}
}

return arguments;
}

private static boolean hasNoFractionalPart(Double doubleValue) {
return doubleValue.equals(Math.floor(doubleValue));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.List;
import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -50,21 +52,19 @@ private static Iterable<JsonSchemaProperty> toJsonSchemaProperties(Parameter par
return removeNulls(BOOLEAN, description);
}

if (type == byte.class || type == Byte.class
|| type == short.class || type == Short.class
|| type == int.class || type == Integer.class
|| type == long.class || type == Long.class
|| type == BigInteger.class) {
return removeNulls(INTEGER, description);
}

// TODO put constraints on min and max?
if (type == byte.class
|| type == Byte.class
|| type == short.class
|| type == Short.class
|| type == int.class
|| type == Integer.class
|| type == long.class
|| type == Long.class
|| type == float.class
|| type == Float.class
|| type == double.class
|| type == Double.class // TODO bigdecimal, etc
) {
return removeNulls(NUMBER, description); // TODO test all types!
if (type == float.class || type == Float.class
|| type == double.class || type == Double.class
|| type == BigDecimal.class) {
return removeNulls(NUMBER, description);
}

if (type.isArray()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
package dev.langchain4j.agent.tool;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import java.math.BigDecimal;
import java.math.BigInteger;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class ToolExecutorTest {

TestTool testTool = new TestTool();

private static class TestTool {

@Tool
double doubles(double arg0, Double arg1) {
return arg0 + arg1;
}

@Tool
float floats(float arg0, Float arg1) {
return arg0 + arg1;
}

@Tool
BigDecimal bigDecimals(BigDecimal arg0, BigDecimal arg1) {
return arg0.add(arg1);
}

@Tool
long longs(long arg0, Long arg1) {
return arg0 + arg1;
}

@Tool
int ints(int arg0, Integer arg1) {
return arg0 + arg1;
}

@Tool
short shorts(short arg0, Short arg1) {
return (short) (arg0 + arg1);
}

@Tool
byte bytes(byte arg0, Byte arg1) {
return (byte) (arg0 + arg1);
}

@Tool
BigInteger bigIntegers(BigInteger arg0, BigInteger arg1) {
return arg0.add(arg1);
}
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2}",
"{\"arg0\": 2.0, \"arg1\": 2.0}",
"{\"arg0\": 1.9, \"arg1\": 2.1}",
})
void should_execute_tool_with_parameters_of_type_double(String arguments) throws NoSuchMethodException {
executeAndAssert(arguments, "doubles", double.class, Double.class, "4.0");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2}",
"{\"arg0\": 2.0, \"arg1\": 2.0}",
"{\"arg0\": 1.9, \"arg1\": 2.1}",
})
void should_execute_tool_with_parameters_of_type_float(String arguments) throws NoSuchMethodException {
executeAndAssert(arguments, "floats", float.class, Float.class, "4.0");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": " + Float.MAX_VALUE + "}",
"{\"arg0\": 2, \"arg1\": " + -Double.MAX_VALUE + "}"
})
void should_fail_when_argument_does_not_fit_into_float_type(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "floats", float.class, Float.class, "is out of range for the float type");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2}",
"{\"arg0\": 2.0, \"arg1\": 2.0}",
"{\"arg0\": 1.9, \"arg1\": 2.1}",
})
void should_execute_tool_with_parameters_of_type_BigDecimal(String arguments) throws NoSuchMethodException {
executeAndAssert(arguments, "bigDecimals", BigDecimal.class, BigDecimal.class, "4.0");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2}",
"{\"arg0\": 2.0, \"arg1\": 2.0}"
})
void should_execute_tool_with_parameters_of_type_long(String arguments) throws NoSuchMethodException {
executeAndAssert(arguments, "longs", long.class, Long.class, "4");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2.1}",
"{\"arg0\": 2.1, \"arg1\": 2}"
})
void should_fail_when_argument_is_fractional_number_for_parameter_of_type_long(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "longs", long.class, Long.class, "argument type mismatch");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": " + Double.MAX_VALUE + "}",
"{\"arg0\": 2, \"arg1\": " + -Double.MAX_VALUE + "}"
})
void should_fail_when_argument_does_not_fit_into_long_type(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "longs", long.class, Long.class, "is out of range for the long type");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2}",
"{\"arg0\": 2.0, \"arg1\": 2.0}"
})
void should_execute_tool_with_parameters_of_type_int(String arguments) throws NoSuchMethodException {
executeAndAssert(arguments, "ints", int.class, Integer.class, "4");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2.1}",
"{\"arg0\": 2.1, \"arg1\": 2}"
})
void should_fail_when_argument_is_fractional_number_for_parameter_of_type_int(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "ints", int.class, Integer.class, "argument type mismatch");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": " + Double.MAX_VALUE + "}",
"{\"arg0\": 2, \"arg1\": " + -Double.MAX_VALUE + "}"
})
void should_fail_when_argument_does_not_fit_into_int_type(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "ints", int.class, Integer.class, "is out of range for the integer type");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2}",
"{\"arg0\": 2.0, \"arg1\": 2.0}"
})
void should_execute_tool_with_parameters_of_type_short(String arguments) throws NoSuchMethodException {
executeAndAssert(arguments, "shorts", short.class, Short.class, "4");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2.1}",
"{\"arg0\": 2.1, \"arg1\": 2}"
})
void should_fail_when_argument_is_fractional_number_for_parameter_of_type_short(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "shorts", short.class, Short.class, "argument type mismatch");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": " + Double.MAX_VALUE + "}",
"{\"arg0\": 2, \"arg1\": " + -Double.MAX_VALUE + "}"
})
void should_fail_when_argument_does_not_fit_into_short_type(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "shorts", short.class, Short.class, "is out of range for the short type");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2}",
"{\"arg0\": 2.0, \"arg1\": 2.0}"
})
void should_execute_tool_with_parameters_of_type_byte(String arguments) throws NoSuchMethodException {
executeAndAssert(arguments, "bytes", byte.class, Byte.class, "4");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2.1}",
"{\"arg0\": 2.1, \"arg1\": 2}"
})
void should_fail_when_argument_is_fractional_number_for_parameter_of_type_byte(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "bytes", byte.class, Byte.class, "argument type mismatch");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": " + Double.MAX_VALUE + "}",
"{\"arg0\": 2, \"arg1\": " + -Double.MAX_VALUE + "}"
})
void should_fail_when_argument_does_not_fit_into_byte_type(String arguments) throws NoSuchMethodException {
executeAndExpectFailure(arguments, "bytes", byte.class, Byte.class, "is out of range for the byte type");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"arg0\": 2, \"arg1\": 2}",
"{\"arg0\": 2.0, \"arg1\": 2.0}"
})
void should_execute_tool_with_parameters_of_type_BigInteger(String arguments) throws NoSuchMethodException {
executeAndAssert(arguments, "bigIntegers", BigInteger.class, BigInteger.class, "4");
}

private void executeAndAssert(String arguments, String methodName, Class<?> arg0Type, Class<?> arg1Type, String expectedResult) throws NoSuchMethodException {
ToolExecutionRequest request = ToolExecutionRequest.builder()
.arguments(arguments)
.build();

ToolExecutor toolExecutor = new ToolExecutor(testTool, TestTool.class.getDeclaredMethod(methodName, arg0Type, arg1Type));

String result = toolExecutor.execute(request.argumentsAsMap());

assertThat(result).isEqualTo(expectedResult);
}

private void executeAndExpectFailure(String arguments, String methodName, Class<?> arg0Type, Class<?> arg1Type, String expectedError) throws NoSuchMethodException {
ToolExecutionRequest request = ToolExecutionRequest.builder()
.arguments(arguments)
.build();

ToolExecutor toolExecutor = new ToolExecutor(testTool, TestTool.class.getDeclaredMethod(methodName, arg0Type, arg1Type));

assertThatThrownBy(() -> toolExecutor.execute(request.argumentsAsMap()))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessageContaining(expectedError);
}
}

0 comments on commit 907c1eb

Please sign in to comment.