From dc6bba61dda62d1143128ed55e6a2e35fae608be Mon Sep 17 00:00:00 2001 From: LangChain4j Date: Mon, 7 Oct 2024 09:58:58 +0200 Subject: [PATCH] Fix #1863: Allow using Map as a return type of AI Service method. (#1873) ## Issue Fixes #1863 ## Change Allow using `Map` as a return type of AI Service method. Added new method ` T fromJson(String json, Type type)` to `JsonCodec`, Quarkus extension will need to implement it. cc @geoand ## General checklist - [ ] There are no breaking changes - [X] I have added unit and integration tests for my change - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green - [X] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable) --- docs/docs/tutorials/5-ai-services.md | 3 +- .../langchain4j/internal/GsonJsonCodec.java | 35 +++++++++-------- .../java/dev/langchain4j/internal/Json.java | 29 +++++++++++++- .../service/output/ServiceOutputParser.java | 7 ++-- .../dev/langchain4j/service/AiServicesIT.java | 38 +++++++++++++++++-- 5 files changed, 88 insertions(+), 24 deletions(-) diff --git a/docs/docs/tutorials/5-ai-services.md b/docs/docs/tutorials/5-ai-services.md index 6b7325a5718..697a3b74a42 100644 --- a/docs/docs/tutorials/5-ai-services.md +++ b/docs/docs/tutorials/5-ai-services.md @@ -256,9 +256,10 @@ Currently, AI Services support the following return types: - `byte`/`short`/`int`/`BigInteger`/`long`/`float`/`double`/`BigDecimal` - `Date`/`LocalDate`/`LocalTime`/`LocalDateTime` - `List`/`Set`, if you want to get the answer in the form of a list of bullet points +- `Map` - `Result`, if you need to access `TokenUsage`, `FinishReason`, sources (`Content`s retrieved during RAG) and executed tools, aside from `T`, which can be of any type listed above. For example: `Result`, `Result` -Unless the return type is `String` or `AiMessage`, the AI Service will automatically append instructions +Unless the return type is `String`, `AiMessage`, or `Map`, the AI Service will automatically append instructions to the end of the `UserMessage` indicating the format in which the LLM should respond. Before the method returns, the AI Service will parse the output of the LLM into the desired type. diff --git a/langchain4j-core/src/main/java/dev/langchain4j/internal/GsonJsonCodec.java b/langchain4j-core/src/main/java/dev/langchain4j/internal/GsonJsonCodec.java index 95c20f5a9df..d3635535c9c 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/internal/GsonJsonCodec.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/internal/GsonJsonCodec.java @@ -1,16 +1,27 @@ package dev.langchain4j.internal; -import com.google.gson.*; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; +import com.google.gson.JsonSerializer; import com.google.gson.stream.JsonWriter; -import java.io.*; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStreamWriter; +import java.lang.reflect.Type; import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; -import java.util.Map; -import static java.time.format.DateTimeFormatter.*; +import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE; +import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME; +import static java.time.format.DateTimeFormatter.ISO_LOCAL_TIME; class GsonJsonCodec implements Json.JsonCodec { @@ -87,22 +98,16 @@ public String toJson(Object o) { return GSON.toJson(o); } - /** - * Reads a JSON string and returns an object of the specified type. - * - *

As a special case, if the type is {@link Map}, the returned object is - * parsed as a {@code Map}. - * - * @param json JSON string to be parsed. - * @param type Type of the object to be returned. - * @return the object represented by the JSON string. - * @param Type of the object to be returned. - */ @Override public T fromJson(String json, Class type) { return GSON.fromJson(json, type); } + @Override + public T fromJson(String json, Type type) { + return GSON.fromJson(json, type); + } + @Override public InputStream toInputStream(Object o, Class type) throws IOException { try ( diff --git a/langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java b/langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java index 977cb09b898..ac839473dae 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java @@ -4,6 +4,7 @@ import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.Type; import static dev.langchain4j.spi.ServiceHelper.loadFactories; @@ -30,6 +31,16 @@ public interface JsonCodec { */ String toJson(Object o); + /** + * Convert the given JSON string to an object of the given class. + * + * @param json the JSON string. + * @param type the class of the object. + * @param the type of the object. + * @return the object. + */ + T fromJson(String json, Class type); + /** * Convert the given JSON string to an object of the given type. * @@ -38,7 +49,7 @@ public interface JsonCodec { * @param the type of the object. * @return the object. */ - T fromJson(String json, Class type); + T fromJson(String json, Type type); /** * Convert the given object to an {@link InputStream}. @@ -72,6 +83,20 @@ public static String toJson(Object o) { return CODEC.toJson(o); } + /** + * Convert the given JSON string to an object of the given class. + * + * @param json the JSON string. + * @param type the class of the object. + * @param the type of the object. + * @return the object. + * @deprecated use Jackson's ObjectMapper + */ + @Deprecated + public static T fromJson(String json, Class type) { + return CODEC.fromJson(json, type); + } + /** * Convert the given JSON string to an object of the given type. * @@ -82,7 +107,7 @@ public static String toJson(Object o) { * @deprecated use Jackson's ObjectMapper */ @Deprecated - public static T fromJson(String json, Class type) { + public static T fromJson(String json, Type type) { return CODEC.fromJson(json, type); } diff --git a/langchain4j/src/main/java/dev/langchain4j/service/output/ServiceOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/service/output/ServiceOutputParser.java index f91ec2b2639..7c1658f7eb1 100644 --- a/langchain4j/src/main/java/dev/langchain4j/service/output/ServiceOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/service/output/ServiceOutputParser.java @@ -74,10 +74,10 @@ public Object parse(Response response, Type returnType) { } try { - return Json.fromJson(text, rawReturnClass); + return Json.fromJson(text, returnType); } catch (Exception e) { String jsonBlock = extractJsonBlock(text); - return Json.fromJson(jsonBlock, rawReturnClass); + return Json.fromJson(jsonBlock, returnType); } } @@ -97,7 +97,8 @@ public String outputFormatInstructions(Type returnType) { if (rawClass == String.class || rawClass == AiMessage.class || rawClass == TokenStream.class - || rawClass == Response.class) { + || rawClass == Response.class + || rawClass == Map.class) { return ""; } diff --git a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java index 78afb0d8451..2b4a5845fd1 100644 --- a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java +++ b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java @@ -23,19 +23,28 @@ import java.time.LocalTime; import java.util.Arrays; import java.util.List; +import java.util.Map; import static dev.langchain4j.data.message.SystemMessage.systemMessage; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI; -import static dev.langchain4j.service.AiServicesIT.Ingredient.*; -import static dev.langchain4j.service.AiServicesIT.IssueCategory.*; +import static dev.langchain4j.service.AiServicesIT.Ingredient.OIL; +import static dev.langchain4j.service.AiServicesIT.Ingredient.PEPPER; +import static dev.langchain4j.service.AiServicesIT.Ingredient.SALT; +import static dev.langchain4j.service.AiServicesIT.IssueCategory.COMFORT_ISSUE; +import static dev.langchain4j.service.AiServicesIT.IssueCategory.MAINTENANCE_ISSUE; +import static dev.langchain4j.service.AiServicesIT.IssueCategory.OVERALL_EXPERIENCE_ISSUE; +import static dev.langchain4j.service.AiServicesIT.IssueCategory.SERVICE_ISSUE; import static dev.langchain4j.service.AiServicesIT.Sentiment.POSITIVE; import static java.time.Month.JULY; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.data.MapEntry.entry; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; @ExtendWith(MockitoExtension.class) public class AiServicesIT { @@ -326,6 +335,29 @@ void test_extract_list_of_enums_with_descriptions() { verify(chatLanguageModel).supportedCapabilities(); } + interface MapExtractor { + + @UserMessage("Return a JSON map with the age of each person in the following text: {{it}}") + Map extractAges(String text); + } + + @Test + void should_extract_map() { + + MapExtractor mapExtractor = AiServices.create(MapExtractor.class, chatLanguageModel); + + String text = "Klaus is 42 and Francine is 47"; + + Map ages = mapExtractor.extractAges(text); + + assertThat(ages).containsExactly(entry("Klaus", 42), entry("Francine", 47)); + + verify(chatLanguageModel).generate(singletonList(userMessage( + "Return a JSON map with the age of each person in the following text: " + text + ))); + verify(chatLanguageModel).supportedCapabilities(); + } + @ToString static class Address { private Integer streetNumber;