|
5 | 5 | import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
|
6 | 6 | import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
|
7 | 7 |
|
8 |
| -import javax.annotation.Nonnull; |
9 |
| -import javax.annotation.Nullable; |
10 | 8 | import java.lang.reflect.Field;
|
11 | 9 | import java.util.ArrayList;
|
12 | 10 | import java.util.Collections;
|
13 | 11 | import java.util.HashSet;
|
14 | 12 | import java.util.List;
|
| 13 | +import java.util.Set; |
15 | 14 | import java.util.stream.Collectors;
|
16 | 15 |
|
17 | 16 | /**
|
@@ -50,6 +49,50 @@ public List<VectorStoreRecordField> getAllFields() {
|
50 | 49 | return fields;
|
51 | 50 | }
|
52 | 51 |
|
| 52 | + public List<VectorStoreRecordField> getNonVectorFields() { |
| 53 | + List<VectorStoreRecordField> fields = new ArrayList<>(); |
| 54 | + fields.add(keyField); |
| 55 | + fields.addAll(dataFields); |
| 56 | + return fields; |
| 57 | + } |
| 58 | + |
| 59 | + private List<Field> getDeclaredFields(Class<?> recordClass, List<VectorStoreRecordField> fields, String fieldType) { |
| 60 | + List<Field> declaredFields = new ArrayList<>(); |
| 61 | + for (VectorStoreRecordField field : fields) { |
| 62 | + try { |
| 63 | + Field declaredField = recordClass.getDeclaredField(field.getName()); |
| 64 | + declaredFields.add(declaredField); |
| 65 | + } catch (NoSuchFieldException e) { |
| 66 | + throw new IllegalArgumentException( |
| 67 | + String.format("%s field not found in record class: %s", fieldType, field.getName())); |
| 68 | + } |
| 69 | + } |
| 70 | + return declaredFields; |
| 71 | + } |
| 72 | + |
| 73 | + public Field getKeyDeclaredField(Class<?> recordClass) { |
| 74 | + try { |
| 75 | + return recordClass.getDeclaredField(keyField.getName()); |
| 76 | + } catch (NoSuchFieldException e) { |
| 77 | + throw new IllegalArgumentException( |
| 78 | + "Key field not found in record class: " + keyField.getName()); |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + public List<Field> getDataDeclaredFields(Class<?> recordClass) { |
| 83 | + return getDeclaredFields( |
| 84 | + recordClass, |
| 85 | + dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), |
| 86 | + "Data"); |
| 87 | + } |
| 88 | + |
| 89 | + public List<Field> getVectorDeclaredFields(Class<?> recordClass) { |
| 90 | + return getDeclaredFields( |
| 91 | + recordClass, |
| 92 | + vectorFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), |
| 93 | + "Vector"); |
| 94 | + } |
| 95 | + |
53 | 96 | private VectorStoreRecordDefinition(
|
54 | 97 | VectorStoreRecordKeyField keyField,
|
55 | 98 | List<VectorStoreRecordDataField> dataFields,
|
@@ -148,71 +191,19 @@ public static VectorStoreRecordDefinition fromRecordClass(Class<?> recordClass)
|
148 | 191 | return checkFields(keyFields, dataFields, vectorFields);
|
149 | 192 | }
|
150 | 193 |
|
151 |
| - private static String getSupportedTypesString(@Nullable HashSet<Class<?>> types) { |
152 |
| - if (types == null || types.isEmpty()) { |
153 |
| - return ""; |
154 |
| - } |
155 |
| - return types.stream().map(Class::getName).collect(Collectors.joining(", ")); |
156 |
| - } |
157 |
| - |
158 |
| - public static void validateSupportedKeyTypes(@Nonnull Class<?> recordClass, |
159 |
| - @Nonnull VectorStoreRecordDefinition recordDefinition, |
160 |
| - @Nonnull HashSet<Class<?>> supportedTypes) { |
161 |
| - String supportedTypesString = getSupportedTypesString(supportedTypes); |
162 |
| - |
163 |
| - try { |
164 |
| - Field declaredField = recordClass.getDeclaredField(recordDefinition.keyField.getName()); |
165 | 194 |
|
| 195 | + public static void validateSupportedTypes(List<Field> declaredFields, Set<Class<?>> supportedTypes) { |
| 196 | + Set<Class<?>> unsupportedTypes = new HashSet<>(); |
| 197 | + for (Field declaredField : declaredFields) { |
166 | 198 | if (!supportedTypes.contains(declaredField.getType())) {
|
167 |
| - throw new IllegalArgumentException( |
168 |
| - "Unsupported key field type: " + declaredField.getType().getName() |
169 |
| - + ". Supported types are: " + supportedTypesString); |
170 |
| - } |
171 |
| - } catch (NoSuchFieldException e) { |
172 |
| - throw new IllegalArgumentException( |
173 |
| - "Key field not found in record class: " + recordDefinition.keyField.getName()); |
174 |
| - } |
175 |
| - } |
176 |
| - |
177 |
| - public static void validateSupportedDataTypes(@Nonnull Class<?> recordClass, |
178 |
| - @Nonnull VectorStoreRecordDefinition recordDefinition, |
179 |
| - @Nonnull HashSet<Class<?>> supportedTypes) { |
180 |
| - String supportedTypesString = getSupportedTypesString(supportedTypes); |
181 |
| - |
182 |
| - for (VectorStoreRecordDataField field : recordDefinition.dataFields) { |
183 |
| - try { |
184 |
| - Field declaredField = recordClass.getDeclaredField(field.getName()); |
185 |
| - |
186 |
| - if (!supportedTypes.contains(declaredField.getType())) { |
187 |
| - throw new IllegalArgumentException( |
188 |
| - "Unsupported data field type: " + declaredField.getType().getName() |
189 |
| - + ". Supported types are: " + supportedTypesString); |
190 |
| - } |
191 |
| - } catch (NoSuchFieldException e) { |
192 |
| - throw new IllegalArgumentException( |
193 |
| - "Data field not found in record class: " + field.getName()); |
| 199 | + unsupportedTypes.add(declaredField.getType()); |
194 | 200 | }
|
195 | 201 | }
|
196 |
| - } |
197 |
| - |
198 |
| - public static void validateSupportedVectorTypes(@Nonnull Class<?> recordClass, |
199 |
| - @Nonnull VectorStoreRecordDefinition recordDefinition, |
200 |
| - @Nonnull HashSet<Class<?>> supportedTypes) { |
201 |
| - String supportedTypesString = getSupportedTypesString(supportedTypes); |
202 |
| - |
203 |
| - for (VectorStoreRecordVectorField field : recordDefinition.vectorFields) { |
204 |
| - try { |
205 |
| - Field declaredField = recordClass.getDeclaredField(field.getName()); |
206 |
| - |
207 |
| - if (!supportedTypes.contains(declaredField.getType())) { |
208 |
| - throw new IllegalArgumentException( |
209 |
| - "Unsupported vector field type: " + declaredField.getType().getName() |
210 |
| - + ". Supported types are: " + supportedTypesString); |
211 |
| - } |
212 |
| - } catch (NoSuchFieldException e) { |
213 |
| - throw new IllegalArgumentException( |
214 |
| - "Vector field not found in record class: " + field.getName()); |
215 |
| - } |
| 202 | + if (!unsupportedTypes.isEmpty()) { |
| 203 | + throw new IllegalArgumentException( |
| 204 | + String.format("Unsupported field types found in record class: %s. Supported types: %s", |
| 205 | + unsupportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")), |
| 206 | + supportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")))); |
216 | 207 | }
|
217 | 208 | }
|
218 | 209 | }
|
0 commit comments