Skip to content

Commit

Permalink
Fix langchain4j#1863: Allow using Map<K,V> as a return type of AI Ser…
Browse files Browse the repository at this point in the history
…vice method. (langchain4j#1873)

## Issue
Fixes langchain4j#1863

## Change
Allow using `Map<K,V>` as a return type of AI Service method.
Added new method `<T> T fromJson(String json, Type type)` to
`JsonCodec`, Quarkus extension will need to implement it. cc @geoand

## General checklist
- [ ] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [X] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
dliubarskyi authored Oct 7, 2024
1 parent 5bbed84 commit dc6bba6
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 24 deletions.
3 changes: 2 additions & 1 deletion docs/docs/tutorials/5-ai-services.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,10 @@ Currently, AI Services support the following return types:
- `byte`/`short`/`int`/`BigInteger`/`long`/`float`/`double`/`BigDecimal`
- `Date`/`LocalDate`/`LocalTime`/`LocalDateTime`
- `List<String>`/`Set<String>`, if you want to get the answer in the form of a list of bullet points
- `Map<K, V>`
- `Result<T>`, if you need to access `TokenUsage`, `FinishReason`, sources (`Content`s retrieved during RAG) and executed tools, aside from `T`, which can be of any type listed above. For example: `Result<String>`, `Result<MyCustomPojo>`

Unless the return type is `String` or `AiMessage`, the AI Service will automatically append instructions
Unless the return type is `String`, `AiMessage`, or `Map<K, V>`, the AI Service will automatically append instructions
to the end of the `UserMessage` indicating the format in which the LLM should respond.
Before the method returns, the AI Service will parse the output of the LLM into the desired type.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
package dev.langchain4j.internal;

import com.google.gson.*;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializer;
import com.google.gson.stream.JsonWriter;

import java.io.*;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.Map;

import static java.time.format.DateTimeFormatter.*;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_TIME;

class GsonJsonCodec implements Json.JsonCodec {

Expand Down Expand Up @@ -87,22 +98,16 @@ public String toJson(Object o) {
return GSON.toJson(o);
}

/**
* Reads a JSON string and returns an object of the specified type.
*
* <p>As a special case, if the type is {@link Map}, the returned object is
* parsed as a {@code Map<String, String>}.
*
* @param json JSON string to be parsed.
* @param type Type of the object to be returned.
* @return the object represented by the JSON string.
* @param <T> Type of the object to be returned.
*/
@Override
public <T> T fromJson(String json, Class<T> type) {
return GSON.fromJson(json, type);
}

@Override
public <T> T fromJson(String json, Type type) {
return GSON.fromJson(json, type);
}

@Override
public InputStream toInputStream(Object o, Class<?> type) throws IOException {
try (
Expand Down
29 changes: 27 additions & 2 deletions langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;

import static dev.langchain4j.spi.ServiceHelper.loadFactories;

Expand All @@ -30,6 +31,16 @@ public interface JsonCodec {
*/
String toJson(Object o);

/**
* Convert the given JSON string to an object of the given class.
*
* @param json the JSON string.
* @param type the class of the object.
* @param <T> the type of the object.
* @return the object.
*/
<T> T fromJson(String json, Class<T> type);

/**
* Convert the given JSON string to an object of the given type.
*
Expand All @@ -38,7 +49,7 @@ public interface JsonCodec {
* @param <T> the type of the object.
* @return the object.
*/
<T> T fromJson(String json, Class<T> type);
<T> T fromJson(String json, Type type);

/**
* Convert the given object to an {@link InputStream}.
Expand Down Expand Up @@ -72,6 +83,20 @@ public static String toJson(Object o) {
return CODEC.toJson(o);
}

/**
* Convert the given JSON string to an object of the given class.
*
* @param json the JSON string.
* @param type the class of the object.
* @param <T> the type of the object.
* @return the object.
* @deprecated use Jackson's ObjectMapper
*/
@Deprecated
public static <T> T fromJson(String json, Class<T> type) {
return CODEC.fromJson(json, type);
}

/**
* Convert the given JSON string to an object of the given type.
*
Expand All @@ -82,7 +107,7 @@ public static String toJson(Object o) {
* @deprecated use Jackson's ObjectMapper
*/
@Deprecated
public static <T> T fromJson(String json, Class<T> type) {
public static <T> T fromJson(String json, Type type) {
return CODEC.fromJson(json, type);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ public Object parse(Response<AiMessage> response, Type returnType) {
}

try {
return Json.fromJson(text, rawReturnClass);
return Json.fromJson(text, returnType);
} catch (Exception e) {
String jsonBlock = extractJsonBlock(text);
return Json.fromJson(jsonBlock, rawReturnClass);
return Json.fromJson(jsonBlock, returnType);
}
}

Expand All @@ -97,7 +97,8 @@ public String outputFormatInstructions(Type returnType) {
if (rawClass == String.class
|| rawClass == AiMessage.class
|| rawClass == TokenStream.class
|| rawClass == Response.class) {
|| rawClass == Response.class
|| rawClass == Map.class) {
return "";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,28 @@
import java.time.LocalTime;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.data.message.SystemMessage.systemMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
import static dev.langchain4j.service.AiServicesIT.Ingredient.*;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.*;
import static dev.langchain4j.service.AiServicesIT.Ingredient.OIL;
import static dev.langchain4j.service.AiServicesIT.Ingredient.PEPPER;
import static dev.langchain4j.service.AiServicesIT.Ingredient.SALT;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.COMFORT_ISSUE;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.MAINTENANCE_ISSUE;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.OVERALL_EXPERIENCE_ISSUE;
import static dev.langchain4j.service.AiServicesIT.IssueCategory.SERVICE_ISSUE;
import static dev.langchain4j.service.AiServicesIT.Sentiment.POSITIVE;
import static java.time.Month.JULY;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.assertj.core.data.MapEntry.entry;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

@ExtendWith(MockitoExtension.class)
public class AiServicesIT {
Expand Down Expand Up @@ -326,6 +335,29 @@ void test_extract_list_of_enums_with_descriptions() {
verify(chatLanguageModel).supportedCapabilities();
}

interface MapExtractor {

@UserMessage("Return a JSON map with the age of each person in the following text: {{it}}")
Map<String, Integer> extractAges(String text);
}

@Test
void should_extract_map() {

MapExtractor mapExtractor = AiServices.create(MapExtractor.class, chatLanguageModel);

String text = "Klaus is 42 and Francine is 47";

Map<String, Integer> ages = mapExtractor.extractAges(text);

assertThat(ages).containsExactly(entry("Klaus", 42), entry("Francine", 47));

verify(chatLanguageModel).generate(singletonList(userMessage(
"Return a JSON map with the age of each person in the following text: " + text
)));
verify(chatLanguageModel).supportedCapabilities();
}

@ToString
static class Address {
private Integer streetNumber;
Expand Down

0 comments on commit dc6bba6

Please sign in to comment.