Skip to content

Commit

Permalink
enhance(datastore): make scan atomic (#1119)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchuan authored Sep 5, 2022
1 parent b77d4b7 commit de6f4f2
Show file tree
Hide file tree
Showing 7 changed files with 1,293 additions and 1,029 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import ai.starwhale.mlops.exception.SWValidationException;
import org.springframework.stereotype.Component;

import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;

@Component
public class DataStore {
Expand All @@ -51,69 +54,180 @@ public void update(String tableName,
TableSchemaDesc schema,
List<Map<String, String>> records) {
var table = this.tables.computeIfAbsent(tableName, k -> new MemoryTableImpl(tableName, this.walManager));
table.update(schema, records);
table.lock();
try {
table.update(schema, records);
} finally {
table.unlock();
}
}

public RecordList query(DataStoreQueryRequest req) {
var table = this.getTable(req.getTableName());
var columns = req.getColumns();

return table.query(columns,
req.getOrderBy(),
req.getFilter(),
req.getStart(),
req.getLimit(),
req.isKeepNone(),
req.isRawResult());
}

public RecordList scan(DataStoreScanRequest req) {
List<TableScanIterator> iters = new ArrayList<>();
for (var info : req.getTables()) {
var table = this.getTable(info.getTableName());
var iter = table.scan(info.getColumns(),
table.lock();
try {
var schema = table.getSchema();
var columns = this.getColumnAliases(schema, req.getColumns());
var columnTypeMap = schema.getColumnTypeMapping(columns);
var results = table.query(
columns,
req.getOrderBy(),
req.getFilter(),
req.getStart(),
req.isStartInclusive(),
req.getEnd(),
req.isEndInclusive(),
info.isKeepNone(),
req.getLimit(),
req.isKeepNone(),
req.isRawResult());
iter.next();
if (iter.getRecord() != null) {
iters.add(iter);
String lastKey;
if (results.isEmpty()) {
lastKey = null;
} else {
lastKey = schema.getKeyColumnType().encode(results.get(results.size() - 1).getKey(), req.isRawResult());
}
var records = results.stream()
.map(r -> this.encodeRecord(columnTypeMap, r.getValues(), req.isRawResult()))
.collect(Collectors.toList());
return new RecordList(columnTypeMap, records, lastKey);
} finally {
table.unlock();
}
if (iters.isEmpty()) {
return new RecordList(null, null, null);
}

public RecordList scan(DataStoreScanRequest req) {
var limit = req.getLimit();
if (limit > 1000) {
throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE,
"limit must be less or equal to 1000. request=" + req);
}
var columnTypeMap = new HashMap<String, ColumnType>();
for (var it : iters) {
columnTypeMap.putAll(it.getColumnTypeMapping());
if (limit < 0) {
limit = 1000;
}
var keyColumnType = iters.get(0).getKeyColumnType();
Object lastKey = null;
List<Map<String, String>> ret = new ArrayList<>();
while (!iters.isEmpty() && (req.getLimit() < 0 || ret.size() < req.getLimit())) {
lastKey = Collections.min(iters, (a, b) -> {
@SuppressWarnings("rawtypes") var x = (Comparable) a.getKey();
@SuppressWarnings("rawtypes") var y = (Comparable) b.getKey();
//noinspection unchecked
return x.compareTo(y);
}).getKey();
var record = new HashMap<String, String>();
for (var iter : iters) {
if (iter.getKey().equals(lastKey)) {
record.putAll(iter.getRecord());
iter.next();

var tablesToLock =
req.getTables()
.stream()
.map(DataStoreScanRequest.TableInfo::getTableName)
.sorted() // prevent deadlock
.map(this::getTable)
.collect(Collectors.toList());

for (var table : tablesToLock) {
table.lock();
}
try {
class TableMeta {
String tableName;
MemoryTable table;
TableSchema schema;
Map<String, String> columns;
Map<String, ColumnType> columnTypeMap;
boolean keepNone;
}

var tables = req.getTables().stream().map(info -> {
var ret = new TableMeta();
ret.tableName = info.getTableName();
ret.table = this.getTable(info.getTableName());
ret.schema = ret.table.getSchema();
ret.columns = this.getColumnAliases(ret.schema, info.getColumns());
ret.columnTypeMap = ret.schema.getColumnTypeMapping(ret.columns);
ret.keepNone = info.isKeepNone();
return ret;
}).collect(Collectors.toList());

var columnTypeMap = new HashMap<String, ColumnType>();
for (var table : tables) {
if (table.schema.getKeyColumnType() != tables.get(0).schema.getKeyColumnType()) {
throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE,
MessageFormat.format(
"conflicting key column type. {0}: key={1}, type={2}, {3}: key={4}, type={5}",
tables.get(0).tableName,
tables.get(0).schema.getKeyColumn(),
tables.get(0).schema.getKeyColumnType(),
table.tableName,
table.schema.getKeyColumn(),
table.schema.getKeyColumnType()));
}
for (var entry : table.columnTypeMap.entrySet()) {
var columnName = entry.getKey();
var columnType = entry.getValue();
var old = columnTypeMap.putIfAbsent(columnName, columnType);
if (old != null && old != columnType) {
for (var t : tables) {
if (t.columnTypeMap.get(columnName) != null) {
throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE,
MessageFormat.format(
"conflicting column type. {0}: column={1}, alias={2}, type={3}, "
+ "{4}: column={5}, alias={6}, type={7}",
t.tableName,
columnName,
t.columns.get(columnName),
t.columnTypeMap.get(columnName),
table.tableName,
columnName,
table.columns.get(columnName),
table.columnTypeMap.get(columnName)));
}
}
}
}
}
class TableRecords {
TableMeta meta;
List<MemoryTable.RecordResult> records;
int index;

public MemoryTable.RecordResult getRecord() {
return this.records.get(this.index);
}

public Object getKey() {
return this.getRecord().getKey();
}
}
if (!req.isKeepNone()) {
record.entrySet().removeIf(x -> x.getValue() == null);
var records = new ArrayList<TableRecords>();
for (var table : tables) {
var r = new TableRecords();
r.meta = table;
r.records = table.table.scan(table.columns,
req.getStart(),
req.isStartInclusive(),
req.getEnd(),
req.isEndInclusive(),
limit,
table.keepNone);
if (!r.records.isEmpty()) {
records.add(r);
}
}
var keyColumnType = tables.get(0).schema.getKeyColumnType();
Object lastKey = null;
List<Map<String, String>> ret = new ArrayList<>();
while (!records.isEmpty() && ret.size() < limit) {
lastKey = Collections.min(records, (a, b) -> {
@SuppressWarnings("rawtypes") var x = (Comparable) a.getKey();
@SuppressWarnings("rawtypes") var y = (Comparable) b.getKey();
//noinspection unchecked
return x.compareTo(y);
}).getKey();
var record = new HashMap<String, String>();
for (var r : records) {
if (r.getKey().equals(lastKey)) {
record.putAll(this.encodeRecord(r.meta.columnTypeMap, r.getRecord().values, req.isRawResult()));
++r.index;
}
}
if (!req.isKeepNone()) {
record.entrySet().removeIf(x -> x.getValue() == null);
}
ret.add(record);
records.removeIf(r -> r.index == r.records.size());
}
return new RecordList(columnTypeMap, ret, keyColumnType.encode(lastKey, false));
} finally {
for (var table : tablesToLock) {
table.unlock();
}
ret.add(record);
iters.removeIf(x -> x.getRecord() == null);
}
return new RecordList(columnTypeMap, ret, keyColumnType.encode(lastKey, false));
}

private MemoryTable getTable(String tableName) {
Expand All @@ -124,4 +238,25 @@ private MemoryTable getTable(String tableName) {
}
return table;
}

/**
 * Resolves the column-to-alias mapping to use for a request.
 *
 * @param schema  schema of the table being read; only consulted when no columns were requested
 * @param columns requested mapping of column name to alias, may be {@code null} or empty
 * @return {@code columns} itself when it is non-empty; otherwise a new map in which every
 *         column of {@code schema} is mapped to its own name
 */
private Map<String, String> getColumnAliases(TableSchema schema, Map<String, String> columns) {
    // An explicit selection from the caller always wins and is returned as-is.
    if (columns != null && !columns.isEmpty()) {
        return columns;
    }
    // No selection: expose every column of the schema under its own name.
    var allColumnNames = schema.getColumnSchemas().stream().map(ColumnSchema::getName);
    return allColumnNames.collect(Collectors.toMap(Function.identity(), Function.identity()));
}

/**
 * Converts the values of one record into their string representation.
 *
 * @param columnTypeMap type of each column, keyed by the same names used in {@code values};
 *                      every key of {@code values} is looked up here without a null check
 * @param values        values of the record, keyed by column name
 * @param rawResult     passed through unchanged to {@code ColumnType.encode}
 * @return a new mutable map with the same keys as {@code values} and the encoded values
 */
private Map<String, String> encodeRecord(Map<String, ColumnType> columnTypeMap,
        Map<String, Object> values,
        boolean rawResult) {
    // A plain HashMap filled with put() is kept on purpose: callers strip entries whose
    // encoded value is null, so null values must be storable here.
    var encoded = new HashMap<String, String>();
    values.forEach((name, value) ->
            encoded.put(name, columnTypeMap.get(name).encode(value, rawResult)));
    return encoded;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/
package ai.starwhale.mlops.datastore;

import lombok.AllArgsConstructor;
import lombok.Data;

import java.util.List;
import java.util.Map;

Expand All @@ -25,20 +28,31 @@ public interface MemoryTable {

void update(TableSchemaDesc schema, List<Map<String, String>> records);

RecordList query(Map<String, String> columns,
/**
 * One record as returned by {@code query} and {@code scan}: the record's key together with
 * its column values, both still as plain objects (not yet encoded to strings).
 */
@Data
@AllArgsConstructor
class RecordResult {
// Key of the record. NOTE(review): callers appear to order records by this value,
// so it is presumably expected to be Comparable - confirm against implementations.
Object key;
// Column values of the record, keyed by column name.
// NOTE(review): presumably keyed by the requested alias rather than the stored name - confirm.
Map<String, Object> values;
}

List<RecordResult> query(Map<String, String> columns,
List<OrderByDesc> orderBy,
TableQueryFilter filter,
int start,
int limit,
boolean keepNone,
boolean rawResult);

TableScanIterator scan(
List<RecordResult> scan(
Map<String, String> columns,
String start,
boolean startInclusive,
String end,
boolean endInclusive,
boolean keepNone,
boolean rawResult);
int limit,
boolean keepNone);

void lock();

void unlock();
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public Map<String, ColumnType> getColumnTypeMapping(@NonNull Map<String, String>
for (var entry : columnAliases.entrySet()) {
var columnSchema = this.columnSchemaMap.get(entry.getKey());
if (columnSchema == null) {
throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip(
throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE,
"column name " + entry.getKey() + " not found");
}
ret.put(entry.getValue(), columnSchema.getType());
Expand Down
Loading

0 comments on commit de6f4f2

Please sign in to comment.