Skip to content

Commit

Permalink
Tools: support schema for custom POJO parameters (langchain4j#708)
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed Jun 20, 2024
1 parent 6762d33 commit 66b8200
Showing 1 changed file with 31 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,19 @@
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.mistralai.MistralAiChatModelName;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.output.structured.Description;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;

Expand All @@ -35,6 +33,7 @@
import static dev.langchain4j.model.mistralai.MistralAiChatModelName.MISTRAL_LARGE_LATEST;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO_0613;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.service.AiServicesWithToolsIT.Operator.EQUALS;
import static dev.langchain4j.service.AiServicesWithToolsIT.TemperatureUnit.Kelvin;
import static dev.langchain4j.service.AiServicesWithToolsIT.TransactionService.EXPECTED_SPECIFICATION;
import static java.util.Arrays.asList;
Expand All @@ -54,6 +53,13 @@ static Stream<ChatLanguageModel> models() {
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build(),
MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.modelName(MISTRAL_LARGE_LATEST)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build()
// TODO other models supporting tools
);
Expand Down Expand Up @@ -568,43 +574,51 @@ void should_use_tool_with_enum_parameter(ChatLanguageModel chatLanguageModel) {

static class QueryService {



@Tool("Execute the query and return the result")
String executeQuery(@P("query to execute") Query query) {
if(query == null){
return "query cannot be null ";
}
System.out.printf("query to execute", Json.toJson(query) );
assertThat(query).isNotNull();
System.out.println("query to execute: " + Json.toJson(query));

assertThat(query.select).containsExactly("name");
assertThat(query.where).containsExactly(new Condition("country", EQUALS, "India"));
assertThat(query.limit).isEqualTo(3);

return "[ {name = Amar}, {name= Akbar} ,{ name = Antony} ]";
return "Amar, Akbar, Antony";
}
}

@Data
static class Query{
static class Query {

@Description("List of fields to fetch records")
List<String> select;

@Description("List of conditions to filter on. Pass null if no condition")
List<Condition> where;

@Description("limit on number of records")
Integer limit;

@Description("offset for fetching records")
Integer offset;

}

@Data
static class Condition{
@AllArgsConstructor
static class Condition {

@Description("Field to filter on")
String field;

@Description("Operator to apply")
Operator operator;

@Description("Value to compare with")
Object value;
}

enum Operator{
enum Operator {

EQUALS,
NOT_EQUALS,
IS_NULL,
Expand All @@ -613,7 +627,7 @@ enum Operator{

@ParameterizedTest
@MethodSource("models")
void should_use_tool_with_pojo(ChatLanguageModel chatLanguageModel) {
void should_use_tool_with_pojo(ChatLanguageModel chatLanguageModel) {

// given
QueryService queryService = spy(new QueryService());
Expand All @@ -628,9 +642,8 @@ void should_use_tool_with_pojo(ChatLanguageModel chatLanguageModel) {
.tools(queryService)
.build();

Response<AiMessage> response = assistant.chat("List 5 users where country is India");

assertThat(response.content().text()).contains("Amar");
Response<AiMessage> response = assistant.chat("List names of 3 users where country is India");

assertThat(response.content().text()).contains("Amar", "Akbar", "Antony");
}
}

0 comments on commit 66b8200

Please sign in to comment.