diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnSchema.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnSchema.java index 78c0e24a94..3b4fff7031 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnSchema.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnSchema.java @@ -180,8 +180,11 @@ public ColumnSchema(@NonNull Wal.ColumnSchema schema) { public ColumnSchemaDesc toColumnSchemaDesc() { var builder = ColumnSchemaDesc.builder() .name(this.name) - .index(this.offset) - .type(this.type.name()); + .index(this.offset); + if (this.type == null) { + return builder.build(); + } + builder.type(this.type.name()); switch (this.type) { case LIST: case TUPLE: diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/DataStore.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/DataStore.java index 3f0d0b08cf..e97f0c2e0b 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/DataStore.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/DataStore.java @@ -340,11 +340,12 @@ public RecordList query(DataStoreQueryRequest req) { } Map columnSchemaMap; if (!req.isEncodeWithType()) { - columnSchemaMap = table.getSchema().getColumnSchemaList().stream() - .filter(col -> columns.containsKey(col.getName())) + // columnSchemaMap is useless except column name + columnSchemaMap = table.getSchema().getColumnNames().stream() + .filter(columns::containsKey) .map(col -> { - var ret = new ColumnSchema(col); - ret.setName(columns.get(col.getName())); + var ret = new ColumnSchema(col, 0); + ret.setName(columns.get(col)); return ret; }) .collect(Collectors.toMap(ColumnSchema::getName, Function.identity())); @@ -613,11 +614,11 @@ private R scanRecords(DataStoreScanRequest req, ResultResolver resultReso .collect(Collectors.toMap(Entry::getKey, entry -> info.getColumnPrefix() + entry.getValue())); } - ret.columnSchemaMap = ret.schema.getColumnSchemaList().stream() - .filter(c -> ret.columns.containsKey(c.getName())) + ret.columnSchemaMap = ret.schema.getColumnNames().stream() + .filter(c -> ret.columns.containsKey(c)) .map(c -> { - var schema = new ColumnSchema(c); - schema.setName(ret.columns.get(c.getName())); + var schema = new ColumnSchema(c, 0); + schema.setName(ret.columns.get(c)); return schema; }) .collect(Collectors.toMap(ColumnSchema::getName, Function.identity())); @@ -647,25 +648,7 @@ private R scanRecords(DataStoreScanRequest req, ResultResolver resultReso for (var entry : table.columnSchemaMap.entrySet()) { var columnName = entry.getKey(); var columnSchema = entry.getValue(); - var old = columnSchemaMap.putIfAbsent(columnName, columnSchema); - if (old != null && !old.isSameType(columnSchema)) { - for (var t : tables) { - if (t.columnSchemaMap.get(columnName) != null) { - throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE, - MessageFormat.format( - "conflicting column type. {0}: column={1}, alias={2}, type={3}, " - + "{4}: column={5}, alias={6}, type={7}", - t.tableName, - columnName, - t.columns.get(columnName), - t.columnSchemaMap.get(columnName), - table.tableName, - columnName, - table.columns.get(columnName), - table.columnSchemaMap.get(columnName))); - } - } - } + columnSchemaMap.putIfAbsent(columnName, columnSchema); } } } @@ -762,14 +745,12 @@ private MemoryTable getTable(String tableName, boolean allowNull, boolean create private Map getColumnAliases(TableSchema schema, Map columns) { if (columns == null || columns.isEmpty()) { - return schema.getColumnSchemaList().stream() - .map(ColumnSchema::getName) + return schema.getColumnNames().stream() .collect(Collectors.toMap(Function.identity(), Function.identity())); } else { var ret = new HashMap(); var invalidColumns = new HashSet<>(columns.keySet()); - for (var columnSchema : schema.getColumnSchemaList()) { - var columnName = columnSchema.getName(); + for (var columnName : schema.getColumnNames()) { var alias = columns.get(columnName); if (alias != null) { ret.put(columnName, alias); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/TableSchema.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/TableSchema.java index 1197ccc105..ef14aa1f1c 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/TableSchema.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/TableSchema.java @@ -41,6 +41,9 @@ public class TableSchema { private final Map columnSchemaIndex = new HashMap<>(); + /** + * Only used for parquet data encoding and decoding + */ @Getter private final List columnSchemaList = new ArrayList<>(); @@ -161,15 +164,15 @@ public void update(@NonNull Wal.TableSchema schema) { } } - public ColumnSchema getColumnSchemaByName(@NonNull String name) { - var index = this.columnSchemaIndex.get(name); - if (index == null) { - return null; - } - return this.getColumnSchemaByIndex(index); + public String getColumnNameByIndex(int index) { + return this.columnSchemaList.get(index).getName(); + } + + public Integer getColumnIndexByName(@NonNull String name) { + return this.columnSchemaIndex.get(name); } - public ColumnSchema getColumnSchemaByIndex(int index) { - return this.columnSchemaList.get(index); + public List getColumnNames() { + return this.columnSchemaList.stream().map(ColumnSchema::getName).collect(Collectors.toList()); } -} \ No newline at end of file +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/MemoryTableImpl.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/MemoryTableImpl.java index 2480f96639..ce93a99e56 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/MemoryTableImpl.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/MemoryTableImpl.java @@ -508,8 +508,8 @@ public Iterator query( throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE, "order by column should not be null"); } - var colSchema = this.schema.getColumnSchemaByName(col.getColumnName()); - if (colSchema == null) { + var idx = this.schema.getColumnIndexByName(col.getColumnName()); + if (idx == null) { throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE, "unknown orderBy column " + col); } @@ -566,7 +566,7 @@ public Iterator scan( } else { ColumnType startKeyType; if (startType == null) { - startKeyType = this.schema.getColumnSchemaByName(this.schema.getKeyColumn()).getType(); + startKeyType = this.schema.getKeyColumnSchema().getType(); } else { try { startKeyType = ColumnType.valueOf(startType); @@ -588,7 +588,7 @@ public Iterator scan( } else { ColumnType endKeyType; if (endType == null) { - endKeyType = this.schema.getColumnSchemaByName(this.schema.getKeyColumn()).getType(); + endKeyType = this.schema.getKeyColumnSchema().getType(); } else { try { endKeyType = ColumnType.valueOf(endType); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/WalRecordDecoder.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/WalRecordDecoder.java index ca42580767..a106a7f978 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/WalRecordDecoder.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/WalRecordDecoder.java @@ -16,7 +16,6 @@ package ai.starwhale.mlops.datastore.impl; -import ai.starwhale.mlops.datastore.ColumnSchema; import ai.starwhale.mlops.datastore.ColumnType; import ai.starwhale.mlops.datastore.TableSchema; import ai.starwhale.mlops.datastore.Wal; @@ -51,9 +50,9 @@ public static Map decodeRecord(@NonNull TableSchema recordSch ret.put(MemoryTableImpl.DELETED_FLAG_COLUMN_NAME, BoolValue.TRUE); continue; } - var colSchema = recordSchema.getColumnSchemaByIndex(index); + var colName = recordSchema.getColumnNameByIndex(index); try { - ret.put(colSchema.getName(), WalRecordDecoder.decodeValue(colSchema, col)); + ret.put(colName, WalRecordDecoder.decodeValue(col)); } catch (Exception e) { throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE, MessageFormat.format("failed to decode wal {0}", col.toString()), @@ -63,16 +62,11 @@ public static Map decodeRecord(@NonNull TableSchema recordSch return ret; } - public static BaseValue decodeValue(ColumnSchema columnSchema, Wal.Column col) { + public static BaseValue decodeValue(Wal.Column col) { if (col.getNullValue()) { return null; } - ColumnType type; - if (columnSchema != null) { - type = columnSchema.getType(); - } else { - type = ColumnType.getTypeByIndex(col.getType()); - } + var type = ColumnType.getTypeByIndex(col.getType()); switch (type) { case UNKNOWN: return null; @@ -95,66 +89,47 @@ public static BaseValue decodeValue(ColumnSchema columnSchema, Wal.Column col) { case BYTES: return new BytesValue(ByteBuffer.wrap(col.getBytesValue().toByteArray())); case LIST: - return WalRecordDecoder.decodeList(columnSchema, col); + return WalRecordDecoder.decodeList(col); case TUPLE: - return WalRecordDecoder.decodeTuple(columnSchema, col); + return WalRecordDecoder.decodeTuple(col); case MAP: - return WalRecordDecoder.decodeMap(columnSchema, col); + return WalRecordDecoder.decodeMap(col); case OBJECT: - return WalRecordDecoder.decodeObject(columnSchema, col); + return WalRecordDecoder.decodeObject(col); default: throw new IllegalArgumentException("invalid type " + type); } } - private static BaseValue decodeList(ColumnSchema columnSchema, @NonNull Wal.Column col) { + private static BaseValue decodeList(@NonNull Wal.Column col) { var ret = new ListValue(); var values = col.getListValueList(); - for (var i = 0; i < values.size(); i++) { - ColumnSchema schema = null; - if (columnSchema != null) { - var sparse = columnSchema.getSparseElementSchema(); - if (sparse != null) { - schema = sparse.get(i); - } - if (schema == null) { - schema = columnSchema.getElementSchema(); - } - } - ret.add(WalRecordDecoder.decodeValue(schema, values.get(i))); + for (Wal.Column value : values) { + ret.add(WalRecordDecoder.decodeValue(value)); } return ret; } - private static BaseValue decodeTuple(ColumnSchema columnSchema, @NonNull Wal.Column col) { + private static BaseValue decodeTuple(@NonNull Wal.Column col) { var ret = new TupleValue(); - ret.addAll((ListValue) WalRecordDecoder.decodeList(columnSchema, col)); + ret.addAll((ListValue) WalRecordDecoder.decodeList(col)); return ret; } - private static BaseValue decodeMap(ColumnSchema columnSchema, @NonNull Wal.Column col) { + private static BaseValue decodeMap(@NonNull Wal.Column col) { var ret = new MapValue(); - // the type info in columnSchema is not reliable, do not use it for (var entry : col.getMapValueList()) { - ret.put(WalRecordDecoder.decodeValue(null, entry.getKey()), - WalRecordDecoder.decodeValue(null, entry.getValue())); + ret.put(WalRecordDecoder.decodeValue(entry.getKey()), + WalRecordDecoder.decodeValue(entry.getValue())); } return ret; } - private static BaseValue decodeObject(ColumnSchema columnSchema, @NonNull Wal.Column col) { - String pythonType; - Map attrMap; - if (columnSchema != null) { - pythonType = columnSchema.getPythonType(); - attrMap = columnSchema.getAttributesSchema(); - } else { - pythonType = col.getStringValue(); - attrMap = Map.of(); - } + private static BaseValue decodeObject(@NonNull Wal.Column col) { + var pythonType = col.getStringValue(); var ret = new ObjectValue(pythonType); col.getObjectValueMap().forEach( - (k, v) -> ret.put(k, WalRecordDecoder.decodeValue(columnSchema == null ? null : attrMap.get(k), v))); + (k, v) -> ret.put(k, WalRecordDecoder.decodeValue(v))); return ret; } } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/WalRecordEncoder.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/WalRecordEncoder.java index 4f537e7ecd..80b11bb6f7 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/WalRecordEncoder.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/WalRecordEncoder.java @@ -41,7 +41,7 @@ public static Wal.Record.Builder encodeRecord(@NonNull TableSchema schema, @NonN } else { col = BaseValue.encodeWal(entry.getValue()); } - col.setIndex(schema.getColumnSchemaByName(name).getIndex()); + col.setIndex(schema.getColumnIndexByName(name)); ret.addColumns(col); } catch (Exception e) { throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE, diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/parquet/SwReadSupport.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/parquet/SwReadSupport.java index da4ce6637d..11052865af 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/parquet/SwReadSupport.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/parquet/SwReadSupport.java @@ -157,7 +157,7 @@ private void mergeList(ListValue value, Wal.Column wal) { wal.getListValueList() .forEach(col -> { if (col.getIndex() >= 0) { - value.add(col.getIndex(), WalRecordDecoder.decodeValue(null, col)); + value.add(col.getIndex(), WalRecordDecoder.decodeValue(col)); } else { var element = value.get(-col.getIndex() - 1); switch (element.getColumnType()) { @@ -182,7 +182,7 @@ private void mergeList(ListValue value, Wal.Column wal) { private void mergeMap(MapValue value, Wal.Column wal) { for (var entry : wal.getMapValueList()) { this.mergeMapEntry(value, - WalRecordDecoder.decodeValue(null, entry.getKey()), + WalRecordDecoder.decodeValue(entry.getKey()), entry.getValue()); } } @@ -194,7 +194,7 @@ private void mergeObject(ObjectValue value, Wal.Column wal) { private void mergeMapEntry(Map value, T k, Wal.Column v) { var old = value.get(k); if (old == null) { - value.put(k, WalRecordDecoder.decodeValue(null, v)); + value.put(k, WalRecordDecoder.decodeValue(v)); } else { switch (old.getColumnType()) { case LIST: diff --git a/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java index 567dad7f8f..15b83108c8 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java @@ -357,8 +357,8 @@ public void testUpdate() throws InterruptedException { assertThat("t2", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("t2", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("b").build())); assertThat("t2", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("k", "00000003", "b", "00000002")))); @@ -757,9 +757,9 @@ public void testQueryDefault() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("a").type("INT32").build(), - ColumnSchemaDesc.builder().name("x").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("a").build(), + ColumnSchemaDesc.builder().name("x").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("k", "00000000", "a", "00000005"), @@ -789,8 +789,8 @@ public void testQuery() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("b").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("k", "00000001", "b", "00000004"), @@ -828,8 +828,8 @@ public void testQuery() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("b").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("k", "00000001", "b", "00000004")))); @@ -854,9 +854,9 @@ public void testQuery() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build(), - ColumnSchemaDesc.builder().name("x").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("b").build(), + ColumnSchemaDesc.builder().name("x").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(new HashMap<>() { @@ -886,9 +886,9 @@ public void testQuery() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build(), - ColumnSchemaDesc.builder().name("x").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("b").build(), + ColumnSchemaDesc.builder().name("x").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(new HashMap<>() { @@ -1051,8 +1051,8 @@ public void testEqualNull() { var resp = DataStoreControllerTest.this.controller.queryTable(this.req); assertThat(resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat(Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("b").build())); assertThat(Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(new HashMap<>() { { @@ -1417,8 +1417,8 @@ public void testScanDefault() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("a").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("a").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("k", "00000000", "a", "00000005"), @@ -1445,9 +1445,9 @@ public void testScan() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("a").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("a").build(), + ColumnSchemaDesc.builder().name("b").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("k", "00000001", "b", "00000001", "a", "00000010"), @@ -1479,9 +1479,9 @@ public void testScan() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("a").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("a").build(), + ColumnSchemaDesc.builder().name("b").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("k", "00000001", "b", "00000001", "a", "00000010")))); @@ -1506,9 +1506,9 @@ public void testScan() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("a").type("INT32").build(), - ColumnSchemaDesc.builder().name("b").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("a").build(), + ColumnSchemaDesc.builder().name("b").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("k", "1", "b", "1", "a", "16")))); @@ -1533,10 +1533,10 @@ public void testScan() { assertThat("test", resp.getStatusCode().is2xxSuccessful(), is(true)); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getColumnTypes(), - containsInAnyOrder(ColumnSchemaDesc.builder().name("k").type("INT32").build(), - ColumnSchemaDesc.builder().name("a").type("INT32").build(), - ColumnSchemaDesc.builder().name("xb").type("INT32").build(), - ColumnSchemaDesc.builder().name("xa").type("INT32").build())); + containsInAnyOrder(ColumnSchemaDesc.builder().name("k").build(), + ColumnSchemaDesc.builder().name("a").build(), + ColumnSchemaDesc.builder().name("xb").build(), + ColumnSchemaDesc.builder().name("xa").build())); assertThat("test", Objects.requireNonNull(resp.getBody()).getData().getRecords(), is(List.of(Map.of("xa", "4", "xb", "1", "k", "1", "a", "16")))); diff --git a/server/controller/src/test/java/ai/starwhale/mlops/datastore/DataStoreTest.java b/server/controller/src/test/java/ai/starwhale/mlops/datastore/DataStoreTest.java index 3a60dbe461..b8c6719785 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/datastore/DataStoreTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/datastore/DataStoreTest.java @@ -238,15 +238,15 @@ public void testUpdate() { public void testMigration() { var srcTable = "t1"; var srcDesc = new TableSchemaDesc("k", - List.of(ColumnSchemaDesc.string().name("k").build(), - ColumnSchemaDesc.int32().name("a").build())); + List.of(ColumnSchemaDesc.string().name("k").build(), + ColumnSchemaDesc.int32().name("a").build())); this.dataStore.update(srcTable, - srcDesc, - List.of(Map.of("k", "0", "a", "5"), - Map.of("k", "1", "a", "4"), - Map.of("k", "2", "a", "3"), - Map.of("k", "3", "a", "2"), - Map.of("k", "4", "a", "1")) + srcDesc, + List.of(Map.of("k", "0", "a", "5"), + Map.of("k", "1", "a", "4"), + Map.of("k", "2", "a", "3"), + Map.of("k", "3", "a", "2"), + Map.of("k", "4", "a", "1")) ); var targetTable = "t2"; @@ -254,40 +254,40 @@ public void testMigration() { assertThrows( SwValidationException.class, () -> this.dataStore.migration(DataStoreMigrationRequest.builder() - .srcTableName(srcTable) - .targetTableName(targetTable) - .createNonExistingTargetTable(false) - .build()) + .srcTableName(srcTable) + .targetTableName(targetTable) + .createNonExistingTargetTable(false) + .build()) ); // case: target table doesn't exist but allow to create, migration with filter var length = this.dataStore.migration(DataStoreMigrationRequest.builder() - .srcTableName(srcTable) - .targetTableName(targetTable) - .filter(TableQueryFilter.builder() - .operator(Operator.OR) - .operands(List.of( - TableQueryFilter.builder() - .operator(Operator.EQUAL) - .operands(List.of( - new TableQueryFilter.Column("k"), - new TableQueryFilter.Constant( - ColumnType.STRING, "0") - )) - .build(), - TableQueryFilter.builder() - .operator(Operator.EQUAL) - .operands(List.of( - new TableQueryFilter.Column("k"), - new TableQueryFilter.Constant( - ColumnType.STRING, "2") - )) - .build() - )) - .build() - ) - .createNonExistingTargetTable(true) - .build()); + .srcTableName(srcTable) + .targetTableName(targetTable) + .filter(TableQueryFilter.builder() + .operator(Operator.OR) + .operands(List.of( + TableQueryFilter.builder() + .operator(Operator.EQUAL) + .operands(List.of( + new TableQueryFilter.Column("k"), + new TableQueryFilter.Constant( + ColumnType.STRING, "0") + )) + .build(), + TableQueryFilter.builder() + .operator(Operator.EQUAL) + .operands(List.of( + new TableQueryFilter.Column("k"), + new TableQueryFilter.Constant( + ColumnType.STRING, "2") + )) + .build() + )) + .build() + ) + .createNonExistingTargetTable(true) + .build()); assertEquals(2, length); var recordList = this.dataStore.query( @@ -300,10 +300,10 @@ public void testMigration() { // case: migration all records without filter length = this.dataStore.migration(DataStoreMigrationRequest.builder() - .srcTableName(srcTable) - .targetTableName(targetTable) - .createNonExistingTargetTable(true) - .build()); + .srcTableName(srcTable) + .targetTableName(targetTable) + .createNonExistingTargetTable(true) + .build()); assertEquals(5, length); recordList = this.dataStore.query( @@ -338,10 +338,6 @@ public void testQuery() { .start(1) .limit(2) .build()); - assertThat("test", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("a", ColumnType.INT32))); assertThat("test", recordList.getRecords(), is(List.of(Map.of("a", "00000003"), @@ -373,10 +369,6 @@ public void testQuery() { .start(1) .limit(2) .build()); - assertThat("all columns", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", ColumnType.STRING, "a", ColumnType.INT32))); assertThat("all columns", recordList.getRecords(), is(List.of(Map.of("k", "2", "a", "00000003"), @@ -401,16 +393,12 @@ public void testQuery() { Map.of("k", "6", "x:link/url", "http://test.com/2.png", "x:link/mime_type", "image/png"))); recordList = this.dataStore.query(DataStoreQueryRequest.builder().tableName("t1").build()); assertThat("object type", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", - ColumnType.STRING, - "a", - ColumnType.INT32, - "x:link/url", - ColumnType.STRING, - "x:link/mime_type", - ColumnType.STRING))); + recordList.getColumnSchemaMap() + .keySet() + .stream() + .sorted(Comparator.naturalOrder()) + .collect(Collectors.toList()), + is(List.of("a", "k", "x:link/mime_type", "x:link/url"))); assertThat("object type", recordList.getRecords(), is(List.of(Map.of("k", "0", "a", "00000005"), @@ -443,9 +431,12 @@ public void testQuery() { .columns(Map.of("x", "y", "x:link/url", "url")) .build()); assertThat("object type alias", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("url", ColumnType.STRING, "y:link/mime_type", ColumnType.STRING))); + recordList.getColumnSchemaMap() + .keySet() + .stream() + .sorted(Comparator.naturalOrder()) + .collect(Collectors.toList()), + is(List.of("url", "y:link/mime_type"))); assertThat("object type alias", recordList.getRecords(), is(List.of(Map.of(), @@ -521,7 +512,7 @@ public Object apply(String str, Boolean rawResult) { } var encodeString = new EncodeString(); - var testParams = new boolean[] {true, false, true, false, true, false}; + var testParams = new boolean[]{true, false, true, false, true, false}; for (boolean rawResult : testParams) { var recordList = this.dataStore.query(DataStoreQueryRequest.builder() .tableName("t1") @@ -693,16 +684,16 @@ public void testScanTableRange() { assertThat("part range test with exactly param", rangeList.getRanges(), is(List.of( - KeyRangeList.Range.builder() - .start("1").startType("STRING").startInclusive(true) - .end("3").endType("STRING").endInclusive(false) - .size(2) - .build(), - KeyRangeList.Range.builder() - .start("3").startType("STRING").startInclusive(true) - .end("3").endType("STRING").endInclusive(true) - .size(1) - .build() + KeyRangeList.Range.builder() + .start("1").startType("STRING").startInclusive(true) + .end("3").endType("STRING").endInclusive(false) + .size(2) + .build(), + KeyRangeList.Range.builder() + .start("3").startType("STRING").startInclusive(true) + .end("3").endType("STRING").endInclusive(true) + .size(1) + .build() )) ); // only one item(for the range test) @@ -719,9 +710,8 @@ public void testScanTableRange() { .keepNone(true) .build()); assertThat("scan one item", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("a", ColumnType.INT32))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a"))); assertThat("scan one item", recordList.getRecords(), is(List.of(Map.of("a", "00000002")))); @@ -745,16 +735,16 @@ public void testScanTableRange() { assertThat("all range test with exactly param", rangeList.getRanges(), is(List.of( - KeyRangeList.Range.builder() - .start("0").startType("STRING").startInclusive(true) - .end("2").endType("STRING").endInclusive(false) - .size(2) - .build(), - KeyRangeList.Range.builder() - .start("2").startType("STRING").startInclusive(true) - .end("3").endType("STRING").endInclusive(true) - .size(2) - .build() + KeyRangeList.Range.builder() + .start("0").startType("STRING").startInclusive(true) + .end("2").endType("STRING").endInclusive(false) + .size(2) + .build(), + KeyRangeList.Range.builder() + .start("2").startType("STRING").startInclusive(true) + .end("3").endType("STRING").endInclusive(true) + .size(2) + .build() )) ); @@ -775,21 +765,21 @@ public void testScanTableRange() { assertThat("all range test with exactly param", rangeList.getRanges(), is(List.of( - KeyRangeList.Range.builder() - .start("0").startType("STRING").startInclusive(true) - .end("2").endType("STRING").endInclusive(false) - .size(2) - .build(), - KeyRangeList.Range.builder() - .start("2").startType("STRING").startInclusive(true) - .end("4").endType("STRING").endInclusive(false) - .size(2) - .build(), - KeyRangeList.Range.builder() - .start("4").startType("STRING").startInclusive(true) - .end("4").endType("STRING").endInclusive(true) - .size(1) - .build() + KeyRangeList.Range.builder() + .start("0").startType("STRING").startInclusive(true) + .end("2").endType("STRING").endInclusive(false) + .size(2) + .build(), + KeyRangeList.Range.builder() + .start("2").startType("STRING").startInclusive(true) + .end("4").endType("STRING").endInclusive(false) + .size(2) + .build(), + KeyRangeList.Range.builder() + .start("4").startType("STRING").startInclusive(true) + .end("4").endType("STRING").endInclusive(true) + .size(1) + .build() )) ); @@ -806,21 +796,21 @@ public void testScanTableRange() { assertThat("all range test without param", rangeList.getRanges(), is(List.of( - KeyRangeList.Range.builder() - .start("0").startType("STRING").startInclusive(true) - .end("2").endType("STRING").endInclusive(false) - .size(2) - .build(), - KeyRangeList.Range.builder() - .start("2").startType("STRING").startInclusive(true) - .end("4").endType("STRING").endInclusive(false) - .size(2) - .build(), - KeyRangeList.Range.builder() - .start("4").startType("STRING").startInclusive(true) - .end(null).endType("STRING").endInclusive(false) - .size(1) - .build() + KeyRangeList.Range.builder() + .start("0").startType("STRING").startInclusive(true) + .end("2").endType("STRING").endInclusive(false) + .size(2) + .build(), + KeyRangeList.Range.builder() + .start("2").startType("STRING").startInclusive(true) + .end("4").endType("STRING").endInclusive(false) + .size(2) + .build(), + KeyRangeList.Range.builder() + .start("4").startType("STRING").startInclusive(true) + .end(null).endType("STRING").endInclusive(false) + .size(1) + .build() )) ); } @@ -855,9 +845,8 @@ public void testScanOneTable() { .keepNone(true) .build()); assertThat("test", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("a", ColumnType.INT32))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a"))); assertThat("test", recordList.getRecords(), is(List.of(Map.of("a", "00000004"), @@ -880,9 +869,8 @@ public void testScanOneTable() { .limit(2) .build()); assertThat("test", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("a", ColumnType.INT32))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a"))); assertThat("test", recordList.getRecords(), is(List.of(Map.of("a", "00000004"), Map.of()))); @@ -900,9 +888,8 @@ public void testScanOneTable() { .limit(3) .build()); assertThat("all columns", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", ColumnType.STRING, "a", ColumnType.INT32, "l", ColumnType.FLOAT64))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a", "k", "l"))); assertThat("all columns", recordList.getRecords(), is(List.of( @@ -920,9 +907,8 @@ public void testScanOneTable() { .limit(0) .build()); assertThat("schema only", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", ColumnType.STRING, "a", ColumnType.INT32, "l", ColumnType.FLOAT64))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a", "k", "l"))); assertThat("schema only", recordList.getRecords(), empty()); desc.setColumnSchemaList(new ArrayList<>(desc.getColumnSchemaList())); desc.getColumnSchemaList().addAll( @@ -938,18 +924,13 @@ public void testScanOneTable() { .build())) .build()); assertThat("object type", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", - ColumnType.STRING, - "a", - ColumnType.INT32, + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a", + "k", "l", - ColumnType.FLOAT64, - "x:link/url", - ColumnType.STRING, "x:link/mime_type", - ColumnType.STRING))); + "x:link/url" + ))); assertThat("object type", recordList.getRecords(), is(List.of(Map.of("k", "0", "a", "00000005", "l", "3ff8000000000000"), @@ -968,9 +949,8 @@ public void testScanOneTable() { .build())) .build()); assertThat("object type alias", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("url", ColumnType.STRING, "y:link/mime_type", ColumnType.STRING))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("url", "y:link/mime_type"))); assertThat("object type alias", recordList.getRecords(), is(List.of(Map.of(), @@ -1036,14 +1016,8 @@ public void testScanMultipleTables() { .keepNone(true) .build()); assertThat("test", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", - ColumnType.STRING, - "a", - ColumnType.INT32, - "b", - ColumnType.INT32))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a", "b", "k"))); assertThat("test", recordList.getRecords(), is(List.of(Map.of("k", "0", "a", "00000005", "b", "00000015"), @@ -1085,14 +1059,8 @@ public void testScanMultipleTables() { .build())) .build()); assertThat("test", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", - ColumnType.STRING, - "a", - ColumnType.INT32, - "b", - ColumnType.INT32))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a", "b", "k"))); assertThat("test", recordList.getRecords(), is(List.of(Map.of("k", "0", "a", "00000005", "b", "00000015"), @@ -1118,9 +1086,8 @@ public void testScanMultipleTables() { .start("7") .build()); assertThat("empty", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("a", ColumnType.INT32, "b", ColumnType.INT32, "k", ColumnType.STRING))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a", "b", "k"))); assertThat("empty", recordList.getRecords(), empty()); assertThat("empty", recordList.getLastKey(), nullValue()); @@ -1141,9 +1108,8 @@ public void testScanMultipleTables() { .keepNone(true) .build()); assertThat("alias", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", ColumnType.STRING, "a", ColumnType.INT32))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a", "k"))); assertThat("alias", recordList.getRecords(), is(List.of(Map.of("k", "0", "a", "00000015"), @@ -1160,13 +1126,6 @@ public void testScanMultipleTables() { .tables(List.of(DataStoreScanRequest.TableInfo.builder().tableName("t1").build(), DataStoreScanRequest.TableInfo.builder().tableName("t4").build())) .build())); - assertThrows(SwValidationException.class, () -> this.dataStore.scan(DataStoreScanRequest.builder() - .tables(List.of(DataStoreScanRequest.TableInfo.builder().tableName("t1").build(), - DataStoreScanRequest.TableInfo.builder() - .tableName("t2") - .columns(Map.of("k", "a")) - .build())) - .build())); // scan non exist table final String tableNonExist = "tableNonExist"; @@ -1178,9 +1137,8 @@ public void testScanMultipleTables() { recordList = this.dataStore.scan(builder.ignoreNonExistingTable(true).build()); assertThat("result of non exist table", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("k", ColumnType.STRING, "a", ColumnType.INT32))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("a", "k"))); assertThat("result of non exist table", recordList.getRecords(), is(List.of(Map.of("k", "0", "a", "00000005")))); @@ -1197,16 +1155,8 @@ public void testScanMultipleTables() { .build())) .build()); assertThat("column prefix", - recordList.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(Map.of("ak", - ColumnType.STRING, - "bk", - ColumnType.STRING, - "aa", - ColumnType.INT32, - "bb", - ColumnType.INT32))); + recordList.getColumnSchemaMap().keySet().stream().sorted().collect(Collectors.toList()), + is(List.of("aa", "ak", "bb", "bk"))); assertThat("column prefix", recordList.getRecords(), is(List.of(Map.of("ak", "0", "bk", "0", "aa", "00000005", "bb", "00000015"), @@ -1622,7 +1572,8 @@ public void testAllTypes() throws Exception { var schema = new TableSchemaDesc("key", columnSchemaList); var expected = new RecordList( columnSchemaList.stream() - .collect(Collectors.toMap(ColumnSchemaDesc::getName, col -> new ColumnSchema(col, 0))), + .collect( + Collectors.toMap(ColumnSchemaDesc::getName, col -> new ColumnSchema(col.getName(), 0))), new HashMap<>() { { put("key", ColumnHintsDesc.builder() @@ -1785,7 +1736,7 @@ public void testAllTypes() throws Exception { .keepNone(true) .build()); result.getColumnSchemaMap().entrySet() - .forEach(entry -> entry.setValue(new ColumnSchema(entry.getValue().toColumnSchemaDesc(), 0))); + .forEach(entry -> entry.setValue(new ColumnSchema(entry.getValue().getName(), 0))); var originFlatMap = records.get(0).get("o"); // hack flat map in record list @@ -1804,7 +1755,9 @@ public void testAllTypes() throws Exception { .build()); // restore flat map in record list records.get(0).put("o", originFlatMap); - var encoded = encodeResultWithType(expected); + var schemaMap = columnSchemaList.stream() + .collect(Collectors.toMap(ColumnSchemaDesc::getName, col -> new ColumnSchema(col, 0))); + var encoded = encodeResultWithType(schemaMap, expected); // encode of map value will return list of map, // the order of the list items is not guaranteed, @@ -1830,29 +1783,6 @@ public void testAllTypes() throws Exception { List.of(Map.of("key", "1"))); result = this.dataStore.scan(DataStoreScanRequest.builder() .tables(List.of(DataStoreScanRequest.TableInfo.builder().tableName("t").build())).build()); - assertThat(result.getColumnSchemaMap().entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getType())), - is(new HashMap<>() { - { - put("key", ColumnType.INT32); - put("a", ColumnType.BOOL); - put("b", ColumnType.INT8); - put("c", ColumnType.INT16); - put("d", ColumnType.INT32); - put("e", ColumnType.INT64); - put("f", ColumnType.FLOAT32); - put("g", ColumnType.FLOAT64); - put("h", ColumnType.BYTES); - put("i", ColumnType.UNKNOWN); - put("j", ColumnType.LIST); - put("k", ColumnType.OBJECT); - put("l", ColumnType.TUPLE); - put("m", ColumnType.MAP); - put("n", ColumnType.LIST); - put("o", ColumnType.MAP); - put("complex", ColumnType.OBJECT); - } - })); result = this.dataStore.scan(DataStoreScanRequest.builder() .tables(List.of(DataStoreScanRequest.TableInfo.builder() .tableName("t") @@ -1902,13 +1832,13 @@ public void testAllTypes() throws Exception { assertEquals(encoded, result); } - private static RecordList encodeResultWithType(RecordList records) { + private static RecordList encodeResultWithType(Map schemaMap, RecordList records) { return new RecordList(null, records.getColumnHints(), records.getRecords().stream() .map(r -> r.entrySet().stream() .collect(Collectors.toMap(Entry::getKey, - entry -> encodeValueWithType(records.getColumnSchemaMap().get(entry.getKey()), + entry -> encodeValueWithType(schemaMap.get(entry.getKey()), entry.getValue())))) .collect(Collectors.toList()), records.getLastKey(), diff --git a/server/controller/src/test/java/ai/starwhale/mlops/datastore/TableSchemaTest.java b/server/controller/src/test/java/ai/starwhale/mlops/datastore/TableSchemaTest.java index d105d49f76..2cf55a1048 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/datastore/TableSchemaTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/datastore/TableSchemaTest.java @@ -62,9 +62,6 @@ public void setUp() { @Test public void testConstructor() { assertThat(this.schema.getKeyColumn(), is("k")); - assertThat(this.schema.getColumnSchemaByName("list").toColumnSchemaDesc(), is(this.listSchemaDesc)); - assertThat(this.schema.getColumnSchemaByName("map").toColumnSchemaDesc(), is(this.mapSchemaDesc)); - assertThat(this.schema.getColumnSchemaByName("obj").toColumnSchemaDesc(), is(this.objectSchemaDesc)); new TableSchema(new TableSchemaDesc( "k", List.of(ColumnSchemaDesc.builder().name("k").type("STRING").build(), @@ -99,11 +96,11 @@ public void testGetDiffAndUpdateColumnType() { var newK = Wal.ColumnSchema.newBuilder() .setColumnName("k") .setColumnType("INT32") - .setColumnIndex(this.schema.getColumnSchemaByName("k").getIndex()); + .setColumnIndex(this.schema.getColumnIndexByName("k")); var newList = Wal.ColumnSchema.newBuilder() .setColumnName("list") .setColumnType("STRING") - .setColumnIndex(this.schema.getColumnSchemaByName("list").getIndex()); + .setColumnIndex(this.schema.getColumnIndexByName("list")); assertThat(diff, is(Wal.TableSchema.newBuilder() .addColumns(newK) .addColumns(newList) @@ -113,14 +110,10 @@ public void testGetDiffAndUpdateColumnType() { walMap.put("k", newK); walMap.put("list", newList); this.schema.update(diff); - walMap.forEach((k, v) -> assertThat(this.schema.getColumnSchemaByName(k).toWal().build(), is(v.build()))); - - var listCol = this.schema.getColumnSchemaByName("list"); - assertThat(listCol.getType(), is(ColumnType.STRING)); - assertThat(listCol.getElementSchema(), nullValue()); - assertThat(listCol.getKeySchema(), nullValue()); - assertThat(listCol.getValueSchema(), nullValue()); - assertThat(listCol.getPythonType(), nullValue()); + this.schema.toWal() + .build() + .getColumnsList() + .forEach(c -> assertThat(c, is(walMap.get(c.getColumnName()).build()))); } @Test @@ -175,10 +168,13 @@ public void testGetDiffAndUpdateNewColumn() { var walMap = this.schema.getColumnSchemaList().stream() .collect(Collectors.toMap(ColumnSchema::getName, ColumnSchema::toWal)); this.schema.update(diff); - assertThat(this.schema.getColumnSchemaByName("b").toWal().build(), is(newB.build())); + walMap.put("b", newB); newC.setElementType(newC.getElementTypeBuilder().setColumnName("element")); - assertThat(this.schema.getColumnSchemaByName("c").toWal().build(), is(newC.build())); - walMap.forEach((k, v) -> assertThat(this.schema.getColumnSchemaByName(k).toWal().build(), is(v.build()))); + walMap.put("c", newC); + this.schema.toWal() + .build() + .getColumnsList() + .forEach(c -> assertThat(c, is(walMap.get(c.getColumnName()).build()))); } @Test @@ -193,7 +189,7 @@ public void testGetDiffAndUpdateListElement() { var newList = Wal.ColumnSchema.newBuilder() .setColumnName("list") .setColumnType("LIST") - .setColumnIndex(this.schema.getColumnSchemaByName("list").getIndex()) + .setColumnIndex(this.schema.getColumnIndexByName("list")) .setElementType(Wal.ColumnSchema.newBuilder() .setColumnType("INT32") .setColumnName("element") @@ -204,7 +200,10 @@ public void testGetDiffAndUpdateListElement() { newList.setElementType(newList.getElementTypeBuilder().setColumnName("element")); walMap.put("list", newList); this.schema.update(diff); - walMap.forEach((k, v) -> assertThat(this.schema.getColumnSchemaByName(k).toWal().build(), is(v.build()))); + this.schema.toWal() + .build() + .getColumnsList() + .forEach(c -> assertThat(c, is(walMap.get(c.getColumnName()).build()))); } @Test @@ -220,7 +219,7 @@ public void testGetDiffAndUpdateMapKey() { var newMap = Wal.ColumnSchema.newBuilder() .setColumnName("map") .setColumnType("MAP") - .setColumnIndex(this.schema.getColumnSchemaByName("map").getIndex()) + .setColumnIndex(this.schema.getColumnIndexByName("map")) .setKeyType(Wal.ColumnSchema.newBuilder() .setColumnName("key") .setColumnType("INT8") @@ -232,7 +231,10 @@ public void testGetDiffAndUpdateMapKey() { newMap.setValueType(walMap.get("map").getValueType()); walMap.put("map", newMap); this.schema.update(diff); - walMap.forEach((k, v) -> assertThat(this.schema.getColumnSchemaByName(k).toWal().build(), is(v.build()))); + this.schema.toWal() + .build() + .getColumnsList() + .forEach(c -> assertThat(c, is(walMap.get(c.getColumnName()).build()))); } @Test @@ -248,7 +250,7 @@ public void testGetDiffAndUpdateMapValue() { var newMap = Wal.ColumnSchema.newBuilder() .setColumnName("map") .setColumnType("MAP") - .setColumnIndex(this.schema.getColumnSchemaByName("map").getIndex()) + .setColumnIndex(this.schema.getColumnIndexByName("map")) .setValueType(Wal.ColumnSchema.newBuilder() .setColumnName("value") .setColumnType("INT8") @@ -260,7 +262,10 @@ public void testGetDiffAndUpdateMapValue() { newMap.setValueType(newMap.getValueTypeBuilder().setColumnName("value")); walMap.put("map", newMap); this.schema.update(diff); - walMap.forEach((k, v) -> assertThat(this.schema.getColumnSchemaByName(k).toWal().build(), is(v.build()))); + this.schema.toWal() + .build() + .getColumnsList() + .forEach(c -> assertThat(c, is(walMap.get(c.getColumnName()).build()))); } @Test @@ -276,14 +281,17 @@ public void testGetDiffAndUpdateObjectPythonType() { .setColumnName("obj") .setColumnType("OBJECT") .setPythonType("tt") - .setColumnIndex(this.schema.getColumnSchemaByName("obj").getIndex()); + .setColumnIndex(this.schema.getColumnIndexByName("obj")); assertThat(diff, is(Wal.TableSchema.newBuilder().addColumns(newObj).build())); var walMap = this.schema.getColumnSchemaList().stream() .collect(Collectors.toMap(ColumnSchema::getName, ColumnSchema::toWal)); newObj.addAllAttributes(walMap.get("obj").getAttributesList()); walMap.put("obj", newObj); this.schema.update(diff); - walMap.forEach((k, v) -> assertThat(this.schema.getColumnSchemaByName(k).toWal().build(), is(v.build()))); + this.schema.toWal() + .build() + .getColumnsList() + .forEach(c -> assertThat(c, is(walMap.get(c.getColumnName()).build()))); } @Test @@ -301,7 +309,7 @@ public void testGetDiffAndUpdateObjectAttributes() { var newObj = Wal.ColumnSchema.newBuilder() .setColumnName("obj") .setColumnType("OBJECT") - .setColumnIndex(this.schema.getColumnSchemaByName("obj").getIndex()) + .setColumnIndex(this.schema.getColumnIndexByName("obj")) .addAttributes(Wal.ColumnSchema.newBuilder() .setColumnName("a") .setColumnType("INT64") @@ -317,7 +325,10 @@ public void testGetDiffAndUpdateObjectAttributes() { newObj.addAttributes(1, walMap.get("obj").getAttributes(1)); walMap.put("obj", newObj); this.schema.update(diff); - walMap.forEach((k, v) -> assertThat(this.schema.getColumnSchemaByName(k).toWal().build(), is(v.build()))); + this.schema.toWal() + .build() + .getColumnsList() + .forEach(c -> assertThat(c, is(walMap.get(c.getColumnName()).build()))); } } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/datastore/impl/MemoryTableImplTest.java b/server/controller/src/test/java/ai/starwhale/mlops/datastore/impl/MemoryTableImplTest.java index af821ba1c7..dfaf82848b 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/datastore/impl/MemoryTableImplTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/datastore/impl/MemoryTableImplTest.java @@ -305,9 +305,6 @@ public void testUpdateCommon() { put("x", null); } })); - assertThat("unknown", - this.memoryTable.getSchema().getColumnSchemaByName("x").getType(), - is(ColumnType.UNKNOWN)); assertThat("unknown", scanAll(this.memoryTable, List.of("k", "a", "b", "c", "a-b/c/d:e_f"), false), contains(new RecordResult(BaseValue.valueOf("0"), false, @@ -325,9 +322,6 @@ public void testUpdateCommon() { desc.getColumnSchemaList().set(desc.getColumnSchemaList().size() - 1, ColumnSchemaDesc.builder().name("x").type("INT32").build()); this.memoryTable.update(desc, List.of(Map.of("k", "1", "x", "1"))); - assertThat("update unknown", - this.memoryTable.getSchema().getColumnSchemaByName("x").getType(), - is(ColumnType.INT32)); assertThat("update unknown", scanAll(this.memoryTable, List.of("k", "a", "b", "c", "a-b/c/d:e_f", "x"), false), contains(new RecordResult(BaseValue.valueOf("0"),