diff --git a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/SchemaHelper.java b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/SchemaHelper.java index 720b6426076..644f638e828 100644 --- a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/SchemaHelper.java +++ b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/SchemaHelper.java @@ -76,6 +76,14 @@ public static Schema fromClass(Class theClass) { } else if (Collection.class.isAssignableFrom(theClass)) { // Because of type erasure, we can't easily know the type of the items in the collection return Schema.newBuilder().setType(Type.ARRAY).build(); + } else if (theClass.isEnum()) { + List enumConstantNames = Arrays.stream(theClass.getEnumConstants()) + .map(Object::toString) + .collect(Collectors.toList()); + return Schema.newBuilder() + .setType(Type.STRING) + .addAllEnum(enumConstantNames) + .build(); } else { // This is some kind of object, let's go through its fields Schema.Builder schemaBuilder = Schema.newBuilder().setType(Type.OBJECT); diff --git a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/SchemaHelperTest.java b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/SchemaHelperTest.java index 14362d1461c..2ae1a958aed 100644 --- a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/SchemaHelperTest.java +++ b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/SchemaHelperTest.java @@ -2,8 +2,14 @@ import com.google.cloud.vertexai.api.Schema; import com.google.cloud.vertexai.api.Type; +import lombok.Data; +import lombok.Getter; import org.junit.jupiter.api.Test; +import java.util.Arrays; +import java.util.Collections; + +import static dev.langchain4j.model.vertexai.SchemaHelper.fromClass; import static org.assertj.core.api.Assertions.assertThat; public class SchemaHelperTest { @@ -17,10 +23,10 @@ class Person { public int age; public boolean isStudent; public String[] friends; - }; + } // when - Schema schema = SchemaHelper.fromClass(Person.class); + Schema schema = fromClass(Person.class); System.out.println("schema = " + schema); // then @@ -32,7 +38,7 @@ class Person { assertThat(schema.getPropertiesMap().get("friends").getType()).isEqualTo(Type.ARRAY); assertThat(schema.getPropertiesMap().get("friends").getItems().getType()).isEqualTo(Type.STRING); } - + @Test void should_convert_json_schema_string_into_schema() { @@ -78,6 +84,38 @@ void should_convert_json_schema_string_into_schema() { assertThat(schema.getPropertiesMap().get("artist-adult").getType()).isEqualTo(Type.BOOLEAN); assertThat(schema.getPropertiesMap().get("artist-pets").getType()).isEqualTo(Type.ARRAY); assertThat(schema.getPropertiesMap().get("artist-pets").getItems().getType()).isEqualTo(Type.STRING); + } + + @Getter + enum Sentiment { + POSITIVE(1.0), NEUTRAL(0.0), NEGATIVE(-1.0); + private final double value; + Sentiment(double val) { + this.value = val; + } + } + + @Data + static class SentimentClassification { + private Sentiment sentiment; + } + @Test + void should_support_enum_schema_without_stackoverflow() { + + // given + Schema schemaFromEnum = fromClass(SentimentClassification.class); + + Schema expectedSchema = Schema.newBuilder() + .setType(Type.OBJECT) + .putProperties("sentiment", Schema.newBuilder() + .setType(Type.STRING) + .addAllEnum(Arrays.asList("POSITIVE", "NEUTRAL", "NEGATIVE")) + .build()) + .addAllRequired(Collections.singletonList("sentiment")) + .build(); + + // then + assertThat(schemaFromEnum).isEqualTo(expectedSchema); } }