Skip to content

Commit

Permalink
1446 : Extraction of json block before parse (langchain4j#1516)
Browse files Browse the repository at this point in the history
Extracts json block before parse and added support for injecting custom
ServiceOutputParser to AiService.

## Issue
Closes langchain4j#1446

## Change
Added method in DefaultServiceOutputParser to find start and end index
of json blocks and substrings result before parse.

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
patpe authored Aug 26, 2024
1 parent 9842ab6 commit 8cd1e73
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio

response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());

Object parsedResponse;
parsedResponse = serviceOutputParser.parse(response, returnType);
Object parsedResponse = serviceOutputParser.parse(response, returnType);
if (typeHasRawClass(returnType, Result.class)) {
return Result.builder()
.content(parsedResponse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
Expand All @@ -18,6 +20,18 @@

public class ServiceOutputParser {

/**
* JSON Pattern:<br />
*
* <i>\\{(?:[^{}]|\\{(?:[^{}]|\\{(?:[^{}]|\\{[^{}]*\\})*\\})*\\})*\\}</i>: Matches JSON objects, accounting for nested objects.
*/
private static final String JSON_PATTERN_REGEX = "\\{(?:[^{}]|\\{(?:[^{}]|\\{(?:[^{}]|\\{[^{}]*\\})*\\})*\\})*\\}";

/**
* Pattern.DOTALL: This flag makes the dot . match all characters, including newline characters.
*/
private static final Pattern JSON_BLOCK_PATTERN = Pattern.compile(JSON_PATTERN_REGEX, Pattern.DOTALL);

private final OutputParserFactory outputParserFactory;

public ServiceOutputParser() {
Expand Down Expand Up @@ -60,7 +74,8 @@ public Object parse(Response<AiMessage> response, Type returnType) {
return optionalOutputParser.get().parse(text);
}

return Json.fromJson(text, rawReturnClass);
String extractedJsonBlock = extractJsonBlock(text);
return Json.fromJson(extractedJsonBlock, rawReturnClass);
}

public String outputFormatInstructions(Type returnType) {
Expand Down Expand Up @@ -106,7 +121,7 @@ public String outputFormatInstructions(Type returnType) {
return "\nYou must answer strictly in the following JSON format: " + jsonStructure((rawClass), new HashSet<>());
}

public static String jsonStructure(Class<?> structured, Set<Class<?>> visited) {
private static String jsonStructure(Class<?> structured, Set<Class<?>> visited) {
StringBuilder jsonSchema = new StringBuilder();

jsonSchema.append("{\n");
Expand Down Expand Up @@ -195,4 +210,14 @@ private static String simpleTypeName(Type type) {
return type.getTypeName();
}
}

private String extractJsonBlock(String text) {
Matcher matcher = JSON_BLOCK_PATTERN.matcher(text);

if (matcher.find()) {
return matcher.group();
}

return text;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package dev.langchain4j.service.output;

import com.google.gson.JsonSyntaxException;
import com.google.gson.reflect.TypeToken;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.structured.Description;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.stubbing.Answer;

import java.io.Serializable;
Expand All @@ -18,6 +21,7 @@
import java.util.concurrent.atomic.AtomicReference;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -91,6 +95,79 @@ private void testWhetherProperOutputParserWasCalled(AiMessage aiMessage, Type ra
assertInstanceOf(expectedOutputParserType, capturedOutputParser);
}

/********************************************************************************************
* Json output parse tests
********************************************************************************************/

@ParameterizedTest
@ValueSource(strings = {
"{\"key\":\"value\"}",
"```\n{\"key\":\"value\"}\n```",
"```json\n{\"key\":\"value\"}\n```",
"Sure, here is your JSON:\n```\n{\"key\":\"value\"}\n```\nLet me know if you need more help."
})
void makeSureJsonBlockIsExtractedBeforeParse(String json) {
// Given
AiMessage aiMessage = AiMessage.aiMessage(json);
Response<AiMessage> responseStub = Response.from(aiMessage);
sut = new ServiceOutputParser();

// When
Object result = sut.parse(responseStub, KeyProperty.class);

// Then
assertInstanceOf(KeyProperty.class, result);

KeyProperty keyProperty = (KeyProperty) result;
assertThat(keyProperty.key).isEqualTo("value");
}

@ParameterizedTest
@ValueSource(strings = {
"{\"keyProperty\" : {\"key\" : \"value\"}}",
"```\n{\"keyProperty\" :\n {\"key\" : \"value\"}\n}\n```",
"```json\n{\"keyProperty\" :\n {\"key\" : \"value\"}\n}\n```",
"Sure, here is your JSON:\n```\n{\"keyProperty\" :\n {\"key\" : \"value\"}\n}\n```\nLet me know if you need more help."
})
void makeSureNestedJsonBlockIsExtractedBeforeParse(String json) {
// Given
AiMessage aiMessage = AiMessage.aiMessage(json);
Response<AiMessage> responseStub = Response.from(aiMessage);
sut = new ServiceOutputParser();

// When
Object result = sut.parse(responseStub, KeyPropertyWrapper.class);

// Then
assertInstanceOf(KeyPropertyWrapper.class, result);

KeyPropertyWrapper keyProperty = (KeyPropertyWrapper) result;
assertThat(keyProperty.keyProperty.key).isEqualTo("value");
}

@ParameterizedTest
@ValueSource(strings = {
"\"key\":\"value\"}",
"{\"key\":\"value\""
})
void illegalJsonBlockNotExtractedAndFailsParse(String json) {
// Given
AiMessage aiMessage = AiMessage.aiMessage(json);
Response<AiMessage> responseStub = Response.from(aiMessage);
sut = new ServiceOutputParser();

// When / Then
assertThatExceptionOfType(JsonSyntaxException.class).isThrownBy(() -> sut.parse(responseStub, KeyProperty.class));
}

static class KeyPropertyWrapper {
KeyProperty keyProperty;
}

static class KeyProperty {
String key;
}

/********************************************************************************************
* Output format instructions tests
********************************************************************************************/
Expand Down

0 comments on commit 8cd1e73

Please sign in to comment.