Skip to content

Commit

Permalink
Fix a enum serialization issue leading to stackoverflow when creating…
Browse files Browse the repository at this point in the history
… schemas from classes (langchain4j#1450)

- **Fixes langchain4j#1447 Make vector embedding calculation in batch mode to speed
up the calculation of all embeddings for each label.**
- **Remove unneeded import**
- **[Gemini] Fix enum schema handling**
  • Loading branch information
glaforge authored Jul 12, 2024
1 parent 7b9366c commit 03aaa76
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ public static Schema fromClass(Class<?> theClass) {
} else if (Collection.class.isAssignableFrom(theClass)) {
// Because of type erasure, we can't easily know the type of the items in the collection
return Schema.newBuilder().setType(Type.ARRAY).build();
} else if (theClass.isEnum()) {
List<String> enumConstantNames = Arrays.stream(theClass.getEnumConstants())
.map(Object::toString)
.collect(Collectors.toList());
return Schema.newBuilder()
.setType(Type.STRING)
.addAllEnum(enumConstantNames)
.build();
} else {
// This is some kind of object, let's go through its fields
Schema.Builder schemaBuilder = Schema.newBuilder().setType(Type.OBJECT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@

import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Type;
import lombok.Data;
import lombok.Getter;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.Collections;

import static dev.langchain4j.model.vertexai.SchemaHelper.fromClass;
import static org.assertj.core.api.Assertions.assertThat;

public class SchemaHelperTest {
Expand All @@ -17,10 +23,10 @@ class Person {
public int age;
public boolean isStudent;
public String[] friends;
};
}

// when
Schema schema = SchemaHelper.fromClass(Person.class);
Schema schema = fromClass(Person.class);
System.out.println("schema = " + schema);

// then
Expand All @@ -32,7 +38,7 @@ class Person {
assertThat(schema.getPropertiesMap().get("friends").getType()).isEqualTo(Type.ARRAY);
assertThat(schema.getPropertiesMap().get("friends").getItems().getType()).isEqualTo(Type.STRING);
}

@Test
void should_convert_json_schema_string_into_schema() {

Expand Down Expand Up @@ -78,6 +84,38 @@ void should_convert_json_schema_string_into_schema() {
assertThat(schema.getPropertiesMap().get("artist-adult").getType()).isEqualTo(Type.BOOLEAN);
assertThat(schema.getPropertiesMap().get("artist-pets").getType()).isEqualTo(Type.ARRAY);
assertThat(schema.getPropertiesMap().get("artist-pets").getItems().getType()).isEqualTo(Type.STRING);
}

@Getter
enum Sentiment {
POSITIVE(1.0), NEUTRAL(0.0), NEGATIVE(-1.0);
private final double value;
Sentiment(double val) {
this.value = val;
}
}

@Data
static class SentimentClassification {
private Sentiment sentiment;
}

@Test
void should_support_enum_schema_without_stackoverflow() {

// given
Schema schemaFromEnum = fromClass(SentimentClassification.class);

Schema expectedSchema = Schema.newBuilder()
.setType(Type.OBJECT)
.putProperties("sentiment", Schema.newBuilder()
.setType(Type.STRING)
.addAllEnum(Arrays.asList("POSITIVE", "NEUTRAL", "NEGATIVE"))
.build())
.addAllRequired(Collections.singletonList("sentiment"))
.build();

// then
assertThat(schemaFromEnum).isEqualTo(expectedSchema);
}
}

0 comments on commit 03aaa76

Please sign in to comment.