Skip to content

Commit 21082b8

Browse files
author
Milder Hernandez Cagua
committed
Update VectorStoreRecordCollection
1 parent 61219b5 commit 21082b8

File tree

5 files changed

+77
-81
lines changed

5 files changed

+77
-81
lines changed

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,15 @@ public AzureAISearchVectorStoreRecordCollection(
9090
: options.getRecordDefinition();
9191

9292
// Validate supported types
93-
VectorStoreRecordDefinition.validateSupportedKeyTypes(this.options.getRecordClass(),
94-
this.recordDefinition, supportedKeyTypes);
95-
VectorStoreRecordDefinition.validateSupportedDataTypes(this.options.getRecordClass(),
96-
this.recordDefinition, supportedDataTypes);
97-
VectorStoreRecordDefinition.validateSupportedVectorTypes(this.options.getRecordClass(),
98-
this.recordDefinition, supportedVectorTypes);
93+
VectorStoreRecordDefinition.validateSupportedTypes(
94+
Collections.singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())),
95+
supportedKeyTypes);
96+
VectorStoreRecordDefinition.validateSupportedTypes(
97+
recordDefinition.getDataDeclaredFields(this.options.getRecordClass()),
98+
supportedDataTypes);
99+
VectorStoreRecordDefinition.validateSupportedTypes(
100+
recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()),
101+
supportedVectorTypes);
99102

100103
// Add non-vector fields to the list
101104
nonVectorFields.add(this.recordDefinition.getKeyField().getName());

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ public RedisVectorStoreRecordCollection(
8181
}
8282

8383
// Validate supported types
84-
VectorStoreRecordDefinition.validateSupportedKeyTypes(options.getRecordClass(),
85-
recordDefinition, supportedKeyTypes);
86-
VectorStoreRecordDefinition.validateSupportedVectorTypes(options.getRecordClass(),
87-
recordDefinition, supportedVectorTypes);
84+
VectorStoreRecordDefinition.validateSupportedTypes(
85+
Collections.singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())),
86+
supportedKeyTypes);
87+
VectorStoreRecordDefinition.validateSupportedTypes(
88+
recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()),
89+
supportedVectorTypes);
8890

8991
// If mapper is not provided, set a default one
9092
if (options.getVectorStoreRecordMapper() == null) {

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,41 +9,40 @@
99
import java.util.List;
1010

1111
public interface VectorStoreRecordCollection<Key, Record> {
12-
1312
/**
1413
* Gets the name of the collection.
1514
*
1615
* @return The name of the collection.
1716
*/
18-
public String getCollectionName();
17+
String getCollectionName();
1918

2019
/**
2120
* Checks if the collection exists in the store.
2221
*
2322
* @return A Mono emitting a boolean indicating if the collection exists.
2423
*/
25-
public Mono<Boolean> collectionExistsAsync();
24+
Mono<Boolean> collectionExistsAsync();
2625

2726
/**
2827
* Creates the collection in the store.
2928
*
3029
* @return A Mono representing the completion of the creation operation.
3130
*/
32-
public Mono<Void> createCollectionAsync();
31+
Mono<Void> createCollectionAsync();
3332

3433
/**
3534
* Creates the collection in the store if it does not exist.
3635
*
3736
* @return A Mono representing the completion of the creation operation.
3837
*/
39-
public Mono<Void> createCollectionIfNotExistsAsync();
38+
Mono<Void> createCollectionIfNotExistsAsync();
4039

4140
/**
4241
* Deletes the collection from the store.
4342
*
4443
* @return A Mono representing the completion of the deletion operation.
4544
*/
46-
public Mono<Void> deleteCollectionAsync();
45+
Mono<Void> deleteCollectionAsync();
4746

4847
/**
4948
* Gets a record from the store.

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ public VolatileVectorStoreRecordCollection(String collectionName,
4343
}
4444

4545
// Validate the key type
46-
VectorStoreRecordDefinition.validateSupportedKeyTypes(options.getRecordClass(),
47-
recordDefinition, supportedKeyTypes);
46+
VectorStoreRecordDefinition.validateSupportedTypes(
47+
Collections.singletonList(recordDefinition.getKeyDeclaredField(options.getRecordClass())),
48+
supportedKeyTypes);
4849
}
4950

5051
VolatileVectorStoreRecordCollection(String collectionName,

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java

Lines changed: 54 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
66
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
77

8-
import javax.annotation.Nonnull;
9-
import javax.annotation.Nullable;
108
import java.lang.reflect.Field;
119
import java.util.ArrayList;
1210
import java.util.Collections;
1311
import java.util.HashSet;
1412
import java.util.List;
13+
import java.util.Set;
1514
import java.util.stream.Collectors;
1615

1716
/**
@@ -50,6 +49,50 @@ public List<VectorStoreRecordField> getAllFields() {
5049
return fields;
5150
}
5251

52+
public List<VectorStoreRecordField> getNonVectorFields() {
53+
List<VectorStoreRecordField> fields = new ArrayList<>();
54+
fields.add(keyField);
55+
fields.addAll(dataFields);
56+
return fields;
57+
}
58+
59+
private List<Field> getDeclaredFields(Class<?> recordClass, List<VectorStoreRecordField> fields, String fieldType) {
60+
List<Field> declaredFields = new ArrayList<>();
61+
for (VectorStoreRecordField field : fields) {
62+
try {
63+
Field declaredField = recordClass.getDeclaredField(field.getName());
64+
declaredFields.add(declaredField);
65+
} catch (NoSuchFieldException e) {
66+
throw new IllegalArgumentException(
67+
String.format("%s field not found in record class: %s", fieldType, field.getName()));
68+
}
69+
}
70+
return declaredFields;
71+
}
72+
73+
public Field getKeyDeclaredField(Class<?> recordClass) {
74+
try {
75+
return recordClass.getDeclaredField(keyField.getName());
76+
} catch (NoSuchFieldException e) {
77+
throw new IllegalArgumentException(
78+
"Key field not found in record class: " + keyField.getName());
79+
}
80+
}
81+
82+
public List<Field> getDataDeclaredFields(Class<?> recordClass) {
83+
return getDeclaredFields(
84+
recordClass,
85+
dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()),
86+
"Data");
87+
}
88+
89+
public List<Field> getVectorDeclaredFields(Class<?> recordClass) {
90+
return getDeclaredFields(
91+
recordClass,
92+
vectorFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()),
93+
"Vector");
94+
}
95+
5396
private VectorStoreRecordDefinition(
5497
VectorStoreRecordKeyField keyField,
5598
List<VectorStoreRecordDataField> dataFields,
@@ -148,71 +191,19 @@ public static VectorStoreRecordDefinition fromRecordClass(Class<?> recordClass)
148191
return checkFields(keyFields, dataFields, vectorFields);
149192
}
150193

151-
private static String getSupportedTypesString(@Nullable HashSet<Class<?>> types) {
152-
if (types == null || types.isEmpty()) {
153-
return "";
154-
}
155-
return types.stream().map(Class::getName).collect(Collectors.joining(", "));
156-
}
157-
158-
public static void validateSupportedKeyTypes(@Nonnull Class<?> recordClass,
159-
@Nonnull VectorStoreRecordDefinition recordDefinition,
160-
@Nonnull HashSet<Class<?>> supportedTypes) {
161-
String supportedTypesString = getSupportedTypesString(supportedTypes);
162-
163-
try {
164-
Field declaredField = recordClass.getDeclaredField(recordDefinition.keyField.getName());
165194

195+
public static void validateSupportedTypes(List<Field> declaredFields, Set<Class<?>> supportedTypes) {
196+
Set<Class<?>> unsupportedTypes = new HashSet<>();
197+
for (Field declaredField : declaredFields) {
166198
if (!supportedTypes.contains(declaredField.getType())) {
167-
throw new IllegalArgumentException(
168-
"Unsupported key field type: " + declaredField.getType().getName()
169-
+ ". Supported types are: " + supportedTypesString);
170-
}
171-
} catch (NoSuchFieldException e) {
172-
throw new IllegalArgumentException(
173-
"Key field not found in record class: " + recordDefinition.keyField.getName());
174-
}
175-
}
176-
177-
public static void validateSupportedDataTypes(@Nonnull Class<?> recordClass,
178-
@Nonnull VectorStoreRecordDefinition recordDefinition,
179-
@Nonnull HashSet<Class<?>> supportedTypes) {
180-
String supportedTypesString = getSupportedTypesString(supportedTypes);
181-
182-
for (VectorStoreRecordDataField field : recordDefinition.dataFields) {
183-
try {
184-
Field declaredField = recordClass.getDeclaredField(field.getName());
185-
186-
if (!supportedTypes.contains(declaredField.getType())) {
187-
throw new IllegalArgumentException(
188-
"Unsupported data field type: " + declaredField.getType().getName()
189-
+ ". Supported types are: " + supportedTypesString);
190-
}
191-
} catch (NoSuchFieldException e) {
192-
throw new IllegalArgumentException(
193-
"Data field not found in record class: " + field.getName());
199+
unsupportedTypes.add(declaredField.getType());
194200
}
195201
}
196-
}
197-
198-
public static void validateSupportedVectorTypes(@Nonnull Class<?> recordClass,
199-
@Nonnull VectorStoreRecordDefinition recordDefinition,
200-
@Nonnull HashSet<Class<?>> supportedTypes) {
201-
String supportedTypesString = getSupportedTypesString(supportedTypes);
202-
203-
for (VectorStoreRecordVectorField field : recordDefinition.vectorFields) {
204-
try {
205-
Field declaredField = recordClass.getDeclaredField(field.getName());
206-
207-
if (!supportedTypes.contains(declaredField.getType())) {
208-
throw new IllegalArgumentException(
209-
"Unsupported vector field type: " + declaredField.getType().getName()
210-
+ ". Supported types are: " + supportedTypesString);
211-
}
212-
} catch (NoSuchFieldException e) {
213-
throw new IllegalArgumentException(
214-
"Vector field not found in record class: " + field.getName());
215-
}
202+
if (!unsupportedTypes.isEmpty()) {
203+
throw new IllegalArgumentException(
204+
String.format("Unsupported field types found in record class: %s. Supported types: %s",
205+
unsupportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")),
206+
supportedTypes.stream().map(Class::getName).collect(Collectors.joining(", "))));
216207
}
217208
}
218209
}

0 commit comments

Comments
 (0)