Skip to content

Commit

Permalink
refactor(controller): remove column schemas dependency (#3057)
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui authored Dec 20, 2023
1 parent ebc61e6 commit 807d59f
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 370 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,11 @@ public ColumnSchema(@NonNull Wal.ColumnSchema schema) {
public ColumnSchemaDesc toColumnSchemaDesc() {
var builder = ColumnSchemaDesc.builder()
.name(this.name)
.index(this.offset)
.type(this.type.name());
.index(this.offset);
if (this.type == null) {
return builder.build();
}
builder.type(this.type.name());
switch (this.type) {
case LIST:
case TUPLE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,12 @@ public RecordList query(DataStoreQueryRequest req) {
}
Map<String, ColumnSchema> columnSchemaMap;
if (!req.isEncodeWithType()) {
columnSchemaMap = table.getSchema().getColumnSchemaList().stream()
.filter(col -> columns.containsKey(col.getName()))
// columnSchemaMap is useless except column name
columnSchemaMap = table.getSchema().getColumnNames().stream()
.filter(columns::containsKey)
.map(col -> {
var ret = new ColumnSchema(col);
ret.setName(columns.get(col.getName()));
var ret = new ColumnSchema(col, 0);
ret.setName(columns.get(col));
return ret;
})
.collect(Collectors.toMap(ColumnSchema::getName, Function.identity()));
Expand Down Expand Up @@ -613,11 +614,11 @@ private <R> R scanRecords(DataStoreScanRequest req, ResultResolver<R> resultReso
.collect(Collectors.toMap(Entry::getKey,
entry -> info.getColumnPrefix() + entry.getValue()));
}
ret.columnSchemaMap = ret.schema.getColumnSchemaList().stream()
.filter(c -> ret.columns.containsKey(c.getName()))
ret.columnSchemaMap = ret.schema.getColumnNames().stream()
.filter(c -> ret.columns.containsKey(c))
.map(c -> {
var schema = new ColumnSchema(c);
schema.setName(ret.columns.get(c.getName()));
var schema = new ColumnSchema(c, 0);
schema.setName(ret.columns.get(c));
return schema;
})
.collect(Collectors.toMap(ColumnSchema::getName, Function.identity()));
Expand Down Expand Up @@ -647,25 +648,7 @@ private <R> R scanRecords(DataStoreScanRequest req, ResultResolver<R> resultReso
for (var entry : table.columnSchemaMap.entrySet()) {
var columnName = entry.getKey();
var columnSchema = entry.getValue();
var old = columnSchemaMap.putIfAbsent(columnName, columnSchema);
if (old != null && !old.isSameType(columnSchema)) {
for (var t : tables) {
if (t.columnSchemaMap.get(columnName) != null) {
throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE,
MessageFormat.format(
"conflicting column type. {0}: column={1}, alias={2}, type={3}, "
+ "{4}: column={5}, alias={6}, type={7}",
t.tableName,
columnName,
t.columns.get(columnName),
t.columnSchemaMap.get(columnName),
table.tableName,
columnName,
table.columns.get(columnName),
table.columnSchemaMap.get(columnName)));
}
}
}
columnSchemaMap.putIfAbsent(columnName, columnSchema);
}
}
}
Expand Down Expand Up @@ -762,14 +745,12 @@ private MemoryTable getTable(String tableName, boolean allowNull, boolean create

private Map<String, String> getColumnAliases(TableSchema schema, Map<String, String> columns) {
if (columns == null || columns.isEmpty()) {
return schema.getColumnSchemaList().stream()
.map(ColumnSchema::getName)
return schema.getColumnNames().stream()
.collect(Collectors.toMap(Function.identity(), Function.identity()));
} else {
var ret = new HashMap<String, String>();
var invalidColumns = new HashSet<>(columns.keySet());
for (var columnSchema : schema.getColumnSchemaList()) {
var columnName = columnSchema.getName();
for (var columnName : schema.getColumnNames()) {
var alias = columns.get(columnName);
if (alias != null) {
ret.put(columnName, alias);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ public class TableSchema {

private final Map<String, Integer> columnSchemaIndex = new HashMap<>();

/**
* Only used for parquet data encoding and decoding
*/
@Getter
private final List<ColumnSchema> columnSchemaList = new ArrayList<>();

Expand Down Expand Up @@ -161,15 +164,15 @@ public void update(@NonNull Wal.TableSchema schema) {
}
}

public ColumnSchema getColumnSchemaByName(@NonNull String name) {
var index = this.columnSchemaIndex.get(name);
if (index == null) {
return null;
}
return this.getColumnSchemaByIndex(index);
public String getColumnNameByIndex(int index) {
return this.columnSchemaList.get(index).getName();
}

public Integer getColumnIndexByName(@NonNull String name) {
return this.columnSchemaIndex.get(name);
}

public ColumnSchema getColumnSchemaByIndex(int index) {
return this.columnSchemaList.get(index);
public List<String> getColumnNames() {
return this.columnSchemaList.stream().map(ColumnSchema::getName).collect(Collectors.toList());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,8 @@ public Iterator<RecordResult> query(
throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE,
"order by column should not be null");
}
var colSchema = this.schema.getColumnSchemaByName(col.getColumnName());
if (colSchema == null) {
var idx = this.schema.getColumnIndexByName(col.getColumnName());
if (idx == null) {
throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE,
"unknown orderBy column " + col);
}
Expand Down Expand Up @@ -566,7 +566,7 @@ public Iterator<RecordResult> scan(
} else {
ColumnType startKeyType;
if (startType == null) {
startKeyType = this.schema.getColumnSchemaByName(this.schema.getKeyColumn()).getType();
startKeyType = this.schema.getKeyColumnSchema().getType();
} else {
try {
startKeyType = ColumnType.valueOf(startType);
Expand All @@ -588,7 +588,7 @@ public Iterator<RecordResult> scan(
} else {
ColumnType endKeyType;
if (endType == null) {
endKeyType = this.schema.getColumnSchemaByName(this.schema.getKeyColumn()).getType();
endKeyType = this.schema.getKeyColumnSchema().getType();
} else {
try {
endKeyType = ColumnType.valueOf(endType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package ai.starwhale.mlops.datastore.impl;

import ai.starwhale.mlops.datastore.ColumnSchema;
import ai.starwhale.mlops.datastore.ColumnType;
import ai.starwhale.mlops.datastore.TableSchema;
import ai.starwhale.mlops.datastore.Wal;
Expand Down Expand Up @@ -51,9 +50,9 @@ public static Map<String, BaseValue> decodeRecord(@NonNull TableSchema recordSch
ret.put(MemoryTableImpl.DELETED_FLAG_COLUMN_NAME, BoolValue.TRUE);
continue;
}
var colSchema = recordSchema.getColumnSchemaByIndex(index);
var colName = recordSchema.getColumnNameByIndex(index);
try {
ret.put(colSchema.getName(), WalRecordDecoder.decodeValue(colSchema, col));
ret.put(colName, WalRecordDecoder.decodeValue(col));
} catch (Exception e) {
throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE,
MessageFormat.format("failed to decode wal {0}", col.toString()),
Expand All @@ -63,16 +62,11 @@ public static Map<String, BaseValue> decodeRecord(@NonNull TableSchema recordSch
return ret;
}

public static BaseValue decodeValue(ColumnSchema columnSchema, Wal.Column col) {
public static BaseValue decodeValue(Wal.Column col) {
if (col.getNullValue()) {
return null;
}
ColumnType type;
if (columnSchema != null) {
type = columnSchema.getType();
} else {
type = ColumnType.getTypeByIndex(col.getType());
}
var type = ColumnType.getTypeByIndex(col.getType());
switch (type) {
case UNKNOWN:
return null;
Expand All @@ -95,66 +89,47 @@ public static BaseValue decodeValue(ColumnSchema columnSchema, Wal.Column col) {
case BYTES:
return new BytesValue(ByteBuffer.wrap(col.getBytesValue().toByteArray()));
case LIST:
return WalRecordDecoder.decodeList(columnSchema, col);
return WalRecordDecoder.decodeList(col);
case TUPLE:
return WalRecordDecoder.decodeTuple(columnSchema, col);
return WalRecordDecoder.decodeTuple(col);
case MAP:
return WalRecordDecoder.decodeMap(columnSchema, col);
return WalRecordDecoder.decodeMap(col);
case OBJECT:
return WalRecordDecoder.decodeObject(columnSchema, col);
return WalRecordDecoder.decodeObject(col);
default:
throw new IllegalArgumentException("invalid type " + type);
}
}

private static BaseValue decodeList(ColumnSchema columnSchema, @NonNull Wal.Column col) {
private static BaseValue decodeList(@NonNull Wal.Column col) {
var ret = new ListValue();
var values = col.getListValueList();
for (var i = 0; i < values.size(); i++) {
ColumnSchema schema = null;
if (columnSchema != null) {
var sparse = columnSchema.getSparseElementSchema();
if (sparse != null) {
schema = sparse.get(i);
}
if (schema == null) {
schema = columnSchema.getElementSchema();
}
}
ret.add(WalRecordDecoder.decodeValue(schema, values.get(i)));
for (Wal.Column value : values) {
ret.add(WalRecordDecoder.decodeValue(value));
}
return ret;
}

private static BaseValue decodeTuple(ColumnSchema columnSchema, @NonNull Wal.Column col) {
private static BaseValue decodeTuple(@NonNull Wal.Column col) {
var ret = new TupleValue();
ret.addAll((ListValue) WalRecordDecoder.decodeList(columnSchema, col));
ret.addAll((ListValue) WalRecordDecoder.decodeList(col));
return ret;
}

private static BaseValue decodeMap(ColumnSchema columnSchema, @NonNull Wal.Column col) {
private static BaseValue decodeMap(@NonNull Wal.Column col) {
var ret = new MapValue();
// the type info in columnSchema is not reliable, do not use it
for (var entry : col.getMapValueList()) {
ret.put(WalRecordDecoder.decodeValue(null, entry.getKey()),
WalRecordDecoder.decodeValue(null, entry.getValue()));
ret.put(WalRecordDecoder.decodeValue(entry.getKey()),
WalRecordDecoder.decodeValue(entry.getValue()));
}
return ret;
}

private static BaseValue decodeObject(ColumnSchema columnSchema, @NonNull Wal.Column col) {
String pythonType;
Map<String, ColumnSchema> attrMap;
if (columnSchema != null) {
pythonType = columnSchema.getPythonType();
attrMap = columnSchema.getAttributesSchema();
} else {
pythonType = col.getStringValue();
attrMap = Map.of();
}
private static BaseValue decodeObject(@NonNull Wal.Column col) {
var pythonType = col.getStringValue();
var ret = new ObjectValue(pythonType);
col.getObjectValueMap().forEach(
(k, v) -> ret.put(k, WalRecordDecoder.decodeValue(columnSchema == null ? null : attrMap.get(k), v)));
(k, v) -> ret.put(k, WalRecordDecoder.decodeValue(v)));
return ret;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public static Wal.Record.Builder encodeRecord(@NonNull TableSchema schema, @NonN
} else {
col = BaseValue.encodeWal(entry.getValue());
}
col.setIndex(schema.getColumnSchemaByName(name).getIndex());
col.setIndex(schema.getColumnIndexByName(name));
ret.addColumns(col);
} catch (Exception e) {
throw new SwValidationException(SwValidationException.ValidSubject.DATASTORE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ private void mergeList(ListValue value, Wal.Column wal) {
wal.getListValueList()
.forEach(col -> {
if (col.getIndex() >= 0) {
value.add(col.getIndex(), WalRecordDecoder.decodeValue(null, col));
value.add(col.getIndex(), WalRecordDecoder.decodeValue(col));
} else {
var element = value.get(-col.getIndex() - 1);
switch (element.getColumnType()) {
Expand All @@ -182,7 +182,7 @@ private void mergeList(ListValue value, Wal.Column wal) {
private void mergeMap(MapValue value, Wal.Column wal) {
for (var entry : wal.getMapValueList()) {
this.mergeMapEntry(value,
WalRecordDecoder.decodeValue(null, entry.getKey()),
WalRecordDecoder.decodeValue(entry.getKey()),
entry.getValue());
}
}
Expand All @@ -194,7 +194,7 @@ private void mergeObject(ObjectValue value, Wal.Column wal) {
private <T> void mergeMapEntry(Map<T, BaseValue> value, T k, Wal.Column v) {
var old = value.get(k);
if (old == null) {
value.put(k, WalRecordDecoder.decodeValue(null, v));
value.put(k, WalRecordDecoder.decodeValue(v));
} else {
switch (old.getColumnType()) {
case LIST:
Expand Down
Loading

0 comments on commit 807d59f

Please sign in to comment.