diff --git a/server/controller/pom.xml b/server/controller/pom.xml index 26417997a3..004c365147 100644 --- a/server/controller/pom.xml +++ b/server/controller/pom.xml @@ -136,6 +136,16 @@ io.kubernetes client-java + + org.xerial.snappy + snappy-java + 1.1.8.4 + jar + + + io.github.resilience4j + resilience4j-retry + org.junit.jupiter junit-jupiter-engine @@ -216,6 +226,22 @@ maven-compiler-plugin 3.8.1 + + com.github.os72 + protoc-jar-maven-plugin + 3.11.4 + + + generate-sources + + run + + + false + + + + diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnSchema.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnSchema.java index feaaa7de13..2ca957970c 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnSchema.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnSchema.java @@ -25,10 +25,11 @@ @ToString @EqualsAndHashCode public class ColumnSchema { - private String name; - private ColumnType type; + private final String name; + private final ColumnType type; + private final int index; - public ColumnSchema(@NonNull ColumnSchemaDesc schema) { + public ColumnSchema(@NonNull ColumnSchemaDesc schema, int index) { if (schema.getName() == null) { throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( "column name should not be null"); @@ -44,5 +45,6 @@ public ColumnSchema(@NonNull ColumnSchemaDesc schema) { throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( "invalid column type " + schema.getType()); } + this.index = index; } } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnType.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnType.java index 3db76e6dd7..1e0ab243bb 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnType.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/ColumnType.java @@ -1,3 +1,18 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.starwhale.mlops.datastore; import ai.starwhale.mlops.exception.SWValidationException; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/DataStore.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/DataStore.java index 6ddba4e3a0..891297dc9a 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/DataStore.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/DataStore.java @@ -17,7 +17,7 @@ import ai.starwhale.mlops.datastore.impl.MemoryTableImpl; import ai.starwhale.mlops.exception.SWValidationException; -import org.springframework.stereotype.Service; +import org.springframework.stereotype.Component; import java.util.ArrayList; import java.util.Collections; @@ -26,14 +26,31 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -@Service +@Component public class DataStore { + private final WalManager walManager; + private final Map tables = new ConcurrentHashMap<>(); + public DataStore(WalManager walManager) { + this.walManager = walManager; + var it = this.walManager.readAll(); + while (it.hasNext()) { + var entry = it.next(); + var tableName = entry.getTableName(); + var table = this.tables.computeIfAbsent(tableName, k -> new MemoryTableImpl(tableName, this.walManager)); + table.updateFromWal(entry); + } + } + + public void terminate() { + this.walManager.terminate(); + } + public void update(String tableName, TableSchemaDesc schema, List> records) { - var table = this.tables.computeIfAbsent(tableName, k -> new MemoryTableImpl()); + var table = this.tables.computeIfAbsent(tableName, k -> new MemoryTableImpl(tableName, this.walManager)); table.update(schema, records); } @@ -75,8 +92,9 @@ public RecordList scan(DataStoreScanRequest req) { List> ret = new ArrayList<>(); while (!iters.isEmpty() && (req.getLimit() < 0 || ret.size() < req.getLimit())) { lastKey = Collections.min(iters, (a, b) -> { - var x = (Comparable) a.getKey(); - var y = (Comparable) b.getKey(); + @SuppressWarnings("rawtypes") var x = (Comparable) a.getKey(); + @SuppressWarnings("rawtypes") var y = (Comparable) b.getKey(); + //noinspection unchecked return x.compareTo(y); }).getKey(); var record = new HashMap(); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/MemoryTable.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/MemoryTable.java index ac77771973..2f1cf6938a 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/MemoryTable.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/MemoryTable.java @@ -21,6 +21,8 @@ public interface MemoryTable { TableSchema getSchema(); + void updateFromWal(Wal.WalEntry entry); + void update(TableSchemaDesc schema, List> records); RecordList query(Map columns, diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/RecordList.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/RecordList.java index 2bbff452f0..829ec3a847 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/RecordList.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/RecordList.java @@ -1,3 +1,18 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.starwhale.mlops.datastore; import lombok.AllArgsConstructor; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/TableSchema.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/TableSchema.java index b98b261fb7..db099f1da3 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/TableSchema.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/TableSchema.java @@ -38,6 +38,7 @@ public class TableSchema { @Getter private final ColumnType keyColumnType; private final Map columnSchemaMap; + private int maxColumnIndex; public TableSchema(@NonNull TableSchemaDesc schema) { this.keyColumn = schema.getKeyColumn(); @@ -48,7 +49,7 @@ public TableSchema(@NonNull TableSchemaDesc schema) { this.columnSchemaMap = new HashMap<>(); if (schema.getColumnSchemaList() != null) { for (var col : schema.getColumnSchemaList()) { - var colSchema = new ColumnSchema(col); + var colSchema = new ColumnSchema(col, this.maxColumnIndex++); if (!TableSchema.COLUMN_NAME_PATTERN.matcher(col.getName()).matches()) { throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( "invalid column name " + col.getName()); @@ -75,6 +76,7 @@ public TableSchema(@NonNull TableSchema schema) { this.keyColumn = schema.keyColumn; this.keyColumnType = schema.keyColumnType; this.columnSchemaMap = new HashMap<>(schema.columnSchemaMap); + this.maxColumnIndex = schema.maxColumnIndex; } public ColumnSchema getColumnSchemaByName(@NonNull String name) { @@ -85,7 +87,7 @@ public List getColumnSchemas() { return List.copyOf(this.columnSchemaMap.values()); } - public void merge(@NonNull TableSchemaDesc schema) { + public List merge(@NonNull TableSchemaDesc schema) { if (schema.getKeyColumn() != null && !this.keyColumn.equals(schema.getKeyColumn())) { throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( MessageFormat.format( @@ -94,9 +96,10 @@ public void merge(@NonNull TableSchemaDesc schema) { schema.getKeyColumn())); } var columnSchemaMap = new HashMap(); + var columnIndex = this.maxColumnIndex; for (var col : schema.getColumnSchemaList()) { var current = this.columnSchemaMap.get(col.getName()); - var colSchema = new ColumnSchema(col); + var colSchema = new ColumnSchema(col, current == null ? columnIndex++ : current.getIndex()); if (current != null && current.getType() != ColumnType.UNKNOWN && colSchema.getType() != ColumnType.UNKNOWN @@ -105,11 +108,14 @@ public void merge(@NonNull TableSchemaDesc schema) { MessageFormat.format("conflicting type for column {0}, expected {1}, actual {2}", col.getName(), current.getType(), col.getType())); } - if (current == null || colSchema.getType() != ColumnType.UNKNOWN) { + if (current == null + || current.getType() != colSchema.getType() && colSchema.getType() != ColumnType.UNKNOWN) { columnSchemaMap.put(col.getName(), colSchema); } } this.columnSchemaMap.putAll(columnSchemaMap); + this.maxColumnIndex = columnIndex; + return List.copyOf(columnSchemaMap.values()); } public Map getColumnTypeMapping() { diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/WalManager.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/WalManager.java new file mode 100644 index 0000000000..86dcaae1ef --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/WalManager.java @@ -0,0 +1,350 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.starwhale.mlops.datastore; + +import ai.starwhale.mlops.exception.SWProcessException; +import ai.starwhale.mlops.exception.SWValidationException; +import ai.starwhale.mlops.memory.SwBuffer; +import ai.starwhale.mlops.memory.SwBufferInputStream; +import ai.starwhale.mlops.memory.SwBufferManager; +import ai.starwhale.mlops.memory.SwBufferOutputStream; +import ai.starwhale.mlops.objectstore.ObjectStore; +import com.google.protobuf.CodedOutputStream; +import io.github.resilience4j.core.IntervalFunction; +import io.github.resilience4j.retry.Retry; +import io.github.resilience4j.retry.RetryConfig; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import org.xerial.snappy.Snappy; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.TreeMap; + +@Slf4j +@Component +public class WalManager extends Thread { + private final ObjectStore objectStore; + + private final SwBufferManager bufferManager; + + private final int walFileSize; + + private final int walMaxFileSize; + + private final String logFilePrefix; + + private final int walWaitIntervalMillis; + + private final LinkedList entries = new LinkedList<>(); + + private SwBufferOutputStream outputStream; + + private final SwBuffer outputBuffer; + + private final SwBuffer compressedBuffer; + + private boolean terminated; + + private int logFileIndex; + + private final List existedLogFiles = new ArrayList<>(); + + public WalManager(ObjectStore objectStore, + SwBufferManager bufferManager, + @Value("${sw.datastore.walFileSize}") int walFileSize, + @Value("${sw.datastore.walMaxFileSize}") int walMaxFileSize, + @Value("${sw.datastore.walPrefix}") String walPrefix, + @Value("${sw.datastore.walWaitIntervalMillis}") int walWaitIntervalMillis) throws IOException { + this.objectStore = objectStore; + this.bufferManager = bufferManager; + this.walFileSize = walFileSize; + this.walMaxFileSize = walMaxFileSize; + this.logFilePrefix = walPrefix + "wal.log."; + this.walWaitIntervalMillis = walWaitIntervalMillis; + this.outputBuffer = this.bufferManager.allocate(this.walMaxFileSize); + this.compressedBuffer = this.bufferManager.allocate(this.walMaxFileSize); + this.outputStream = new SwBufferOutputStream(this.outputBuffer); + var walMap = new TreeMap(); + var it = this.objectStore.list(this.logFilePrefix); + while (it.hasNext()) { + var fn = it.next(); + try { + var index = Integer.parseInt(fn.substring(this.logFilePrefix.length())); + walMap.put(index, fn); + } catch (NumberFormatException e) { + // ignore + } + } + if (!walMap.isEmpty()) { + this.logFileIndex = walMap.lastKey() + 1; + this.existedLogFiles.addAll(walMap.values()); + } + this.start(); + } + + public Iterator readAll() { + return new Iterator<>() { + private final List files = new ArrayList<>(WalManager.this.existedLogFiles); + private final SwBuffer buf = WalManager.this.bufferManager.allocate(WalManager.this.walMaxFileSize); + private SwBufferInputStream inputStream; + + @Override + public boolean hasNext() { + return !this.files.isEmpty() || this.inputStream != null; + } + + @Override + public Wal.WalEntry next() { + if (this.inputStream == null) { + if (this.files.isEmpty()) { + return null; + } else { + this.getNext(); + } + } + try { + var ret = Wal.WalEntry.parseDelimitedFrom(this.inputStream); + if (this.inputStream.remaining() == 0) { + this.inputStream = null; + } + return ret; + } catch (IOException e) { + log.error("failed to parse proto", e); + this.inputStream = null; + throw new SWProcessException(SWProcessException.ErrorType.DATASTORE); + } + } + + private void getNext() { + var fn = this.files.get(0); + this.files.remove(0); + SwBuffer data; + try { + data = objectStore.get(fn); + } catch (IOException e) { + log.error("fail to read from object store", e); + throw new SWProcessException(SWProcessException.ErrorType.DATASTORE); + } + int uncompressedSize; + try { + var inBuf = data.asByteBuffer(); + var outBuf = this.buf.asByteBuffer(); + if (inBuf.hasArray()) { + uncompressedSize = Snappy.uncompress(inBuf.array(), 0, inBuf.capacity(), outBuf.array(), 0); + } else { + uncompressedSize = Snappy.uncompress(inBuf, outBuf); + } + } catch (IOException e) { + log.error("fail to uncompress", e); + throw new SWProcessException(SWProcessException.ErrorType.DATASTORE); + } + WalManager.this.bufferManager.release(data); + this.inputStream = new SwBufferInputStream(this.buf.slice(0, uncompressedSize)); + } + }; + } + + public void append(Wal.WalEntry entry) { + if (entry.getSerializedSize() > this.walMaxFileSize) { + for (var e : this.splitEntry(entry)) { + this.append(e); + } + } else { + synchronized (this.entries) { + if (this.terminated) { + throw new SWProcessException(SWProcessException.ErrorType.DATASTORE, "terminated"); + } + this.entries.add(entry); + } + } + } + + public void terminate() { + synchronized (this.entries) { + if (this.terminated) { + return; + } + this.terminated = true; + this.entries.notifyAll(); + } + try { + this.join(); + } catch (InterruptedException e) { + log.warn("interrupted", e); + } + } + + @Override + public void run() { + for (; ; ) { + try { + var status = this.populateOutput(); + if (status == PopulationStatus.TERMINATED) { + return; + } + this.writeToObjectStore(status == PopulationStatus.BUFFER_FULL); + } catch (Throwable e) { + log.error("unexpected exception", e); + try { + //noinspection BusyWait + Thread.sleep(1000); + } catch (InterruptedException ex) { + log.warn("interrupted", e); + } + } + } + } + + private enum PopulationStatus { + TERMINATED, + BUFFER_FULL, + NO_MORE_ENTRIES + } + + private PopulationStatus populateOutput() { + synchronized (this.entries) { + while (!this.terminated && this.entries.isEmpty()) { + try { + this.entries.wait(this.walWaitIntervalMillis); + } catch (InterruptedException e) { + log.warn("interrupted", e); + } + } + if (this.terminated && this.entries.isEmpty()) { + return PopulationStatus.TERMINATED; + } + } + for (; ; ) { + Wal.WalEntry entry; + synchronized (this.entries) { + if (this.entries.isEmpty()) { + break; + } + entry = this.entries.getFirst(); + } + if (CodedOutputStream.computeMessageSizeNoTag(entry) + this.outputStream.getOffset() + > this.walMaxFileSize) { + if (this.outputStream.getOffset() == 0) { + // huge single entry + log.error( + "data loss: discard unexpected huge entry. size={} table={} schema={} records count={}", + entry.getSerializedSize(), + entry.getTableName(), + entry.getTableSchema(), + entry.getRecordsCount()); + } else { + return PopulationStatus.BUFFER_FULL; + } + } else { + try { + entry.writeDelimitedTo(this.outputStream); + } catch (IOException e) { + log.error("data loss: unexpected exception", e); + } + } + synchronized (this.entries) { + this.entries.removeFirst(); + } + } + if (this.outputStream.getOffset() >= this.walFileSize) { + return PopulationStatus.BUFFER_FULL; + } + return PopulationStatus.NO_MORE_ENTRIES; + } + + private void writeToObjectStore(boolean clearOutput) { + int compressedSize; + try { + var inBuf = this.outputBuffer.asByteBuffer(); + var outBuf = this.compressedBuffer.asByteBuffer(); + if (inBuf.hasArray()) { + compressedSize = Snappy.compress(inBuf.array(), 0, this.outputStream.getOffset(), outBuf.array(), 0); + } else { + inBuf.limit(this.outputStream.getOffset()); + compressedSize = Snappy.compress(inBuf, outBuf); + } + } catch (IOException e) { + log.error("data loss: failed to compress", e); + return; + } + try { + Retry.decorateCheckedRunnable( + Retry.of("put", RetryConfig.custom() + .maxAttempts(10000) + .intervalFunction(IntervalFunction.ofExponentialRandomBackoff(100, 2.0, 0.5, 10000)) + .retryOnException(e -> !terminated) + .build()), + () -> this.objectStore.put(this.logFilePrefix + this.logFileIndex, + this.compressedBuffer.slice(0, compressedSize))) + .run(); + } catch (Throwable e) { + log.error("data loss: failed to write wal log", e); + } + if (clearOutput) { + ++this.logFileIndex; + this.outputStream = new SwBufferOutputStream(this.outputBuffer); + } + } + + private List splitEntry(Wal.WalEntry entry) { + List ret = new ArrayList<>(); + var builder = Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName(entry.getTableName()); + int headerSize = builder.build().getSerializedSize(); + if (entry.hasTableSchema()) { + builder.setTableSchema(entry.getTableSchema()); + } + int currentEntrySize = builder.build().getSerializedSize(); + if (currentEntrySize > this.walMaxFileSize) { + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, + "schema is too large or walMaxFileSize is too small. size=" + currentEntrySize + + " walMaxFileSize=" + this.walMaxFileSize); + } + for (var record : entry.getRecordsList()) { + // field number is less than 128, so simply use 1 instead + var recordSize = CodedOutputStream.computeMessageSize(1, record); + currentEntrySize += recordSize; + if (currentEntrySize + CodedOutputStream.computeUInt32SizeNoTag(currentEntrySize) > this.walMaxFileSize) { + ret.add(builder.build()); + builder.clearTableSchema(); + builder.clearRecords(); + currentEntrySize = headerSize + recordSize; + if (currentEntrySize + CodedOutputStream.computeUInt32SizeNoTag(currentEntrySize) + > this.walMaxFileSize) { + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, + "huge single record. size=" + currentEntrySize); + } + } + builder.addRecords(record); + } + if (builder.getRecordsCount() > 0) { + ret.add(builder.build()); + } + for (var e : ret) { + if (e.getSerializedSize() > this.walMaxFileSize) { + throw new SWProcessException(SWProcessException.ErrorType.DATASTORE, + "invalid entry size " + e.getSerializedSize()); + } + } + return ret; + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/MemoryTableImpl.java b/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/MemoryTableImpl.java index 9b4e130a56..9b057c51ed 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/MemoryTableImpl.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/datastore/impl/MemoryTableImpl.java @@ -16,6 +16,7 @@ package ai.starwhale.mlops.datastore.impl; import ai.starwhale.mlops.datastore.ColumnSchema; +import ai.starwhale.mlops.datastore.ColumnSchemaDesc; import ai.starwhale.mlops.datastore.ColumnType; import ai.starwhale.mlops.datastore.MemoryTable; import ai.starwhale.mlops.datastore.OrderByDesc; @@ -24,12 +25,16 @@ import ai.starwhale.mlops.datastore.TableScanIterator; import ai.starwhale.mlops.datastore.TableSchema; import ai.starwhale.mlops.datastore.TableSchemaDesc; +import ai.starwhale.mlops.datastore.Wal; +import ai.starwhale.mlops.datastore.WalManager; import ai.starwhale.mlops.exception.SWValidationException; +import com.google.protobuf.ByteString; import lombok.Getter; import java.nio.ByteBuffer; import java.text.MessageFormat; import java.util.ArrayList; +import java.util.Comparator; import java.util.ConcurrentModificationException; import java.util.HashMap; import java.util.Iterator; @@ -41,59 +46,228 @@ import java.util.stream.Collectors; public class MemoryTableImpl implements MemoryTable { + private final String tableName; + + private final WalManager walManager; + private TableSchema schema = null; private final TreeMap> recordMap = new TreeMap<>(); + // used only for initialization from WAL + private final Map indexMap = new HashMap<>(); + + public MemoryTableImpl(String tableName, WalManager walManager) { + this.tableName = tableName; + this.walManager = walManager; + } + public TableSchema getSchema() { return this.schema == null ? null : new TableSchema(this.schema); } + @Override + public void updateFromWal(Wal.WalEntry entry) { + if (entry.hasTableSchema()) { + var schemaDesc = this.parseSchema(entry.getTableSchema()); + if (this.schema == null) { + this.schema = new TableSchema(schemaDesc); + } else { + this.schema.merge(schemaDesc); + } + } + var recordList = entry.getRecordsList(); + if (!recordList.isEmpty()) { + this.insertRecords(recordList.stream().map(this::parseRecord).collect(Collectors.toList())); + } + } + + private TableSchemaDesc parseSchema(Wal.TableSchema tableSchema) { + var ret = new TableSchemaDesc(); + var keyColumn = tableSchema.getKeyColumn(); + if (!keyColumn.isEmpty()) { + ret.setKeyColumn(keyColumn); + } + var columnList = new ArrayList<>(tableSchema.getColumnsList()); + columnList.sort(Comparator.comparingInt(Wal.ColumnSchema::getColumnIndex)); + var columnSchemaList = new ArrayList(); + for (var col : columnList) { + var colDesc = new ColumnSchemaDesc(col.getColumnName(), col.getColumnType()); + columnSchemaList.add(colDesc); + this.indexMap.put(col.getColumnIndex(), new ColumnSchema(colDesc, col.getColumnIndex())); + } + ret.setColumnSchemaList(columnSchemaList); + return ret; + } + + private Map parseRecord(Wal.Record record) { + Map ret = new HashMap<>(); + for (var col : record.getColumnsList()) { + if (col.getIndex() == -1) { + ret.put("-", true); + } else { + var colSchema = this.indexMap.get(col.getIndex()); + ret.put(colSchema.getName(), this.parseValue(colSchema, col)); + } + } + return ret; + } + + private Object parseValue(ColumnSchema columnSchema, Wal.Column col) { + if (col.getNullValue()) { + return null; + } + switch (columnSchema.getType()) { + case UNKNOWN: + return null; + case BOOL: + return col.getBoolValue(); + case INT8: + return (byte) col.getIntValue(); + case INT16: + return (short) col.getIntValue(); + case INT32: + return (int) col.getIntValue(); + case INT64: + return col.getIntValue(); + case FLOAT32: + return col.getFloatValue(); + case FLOAT64: + return col.getDoubleValue(); + case STRING: + return col.getStringValue(); + case BYTES: + return ByteBuffer.wrap(col.getBytesValue().toByteArray()); + default: + throw new IllegalArgumentException("invalid type " + this); + } + } + + @Override synchronized public void update(TableSchemaDesc schema, List> records) { + var logEntryBuilder = Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName(this.tableName); + TableSchema newSchema = this.schema; if (schema == null) { if (this.schema == null) { - throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, "schema should not be null for the first update"); } } else { + var logSchemaBuilder = Wal.TableSchema.newBuilder(); + List diff; if (this.schema == null) { - this.schema = new TableSchema(schema); + newSchema = new TableSchema(schema); + diff = newSchema.getColumnSchemas(); + logSchemaBuilder.setKeyColumn(newSchema.getKeyColumn()); } else { - this.schema.merge(schema); + newSchema = new TableSchema(this.schema); + diff = newSchema.merge(schema); } + for (var col : diff) { + logSchemaBuilder.addColumns(Wal.ColumnSchema.newBuilder() + .setColumnIndex(col.getIndex()) + .setColumnName(col.getName()) + .setColumnType(col.getType().toString())); + } + logEntryBuilder.setTableSchema(logSchemaBuilder); } + List> decodedRecords = null; if (records != null) { - var decodedRecords = new ArrayList>(); + decodedRecords = new ArrayList<>(); for (var record : records) { - var key = record.get(this.schema.getKeyColumn()); + var key = record.get(newSchema.getKeyColumn()); if (key == null) { - throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( - MessageFormat.format("key column {0} is null", this.schema.getKeyColumn())); + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, + MessageFormat.format("key column {0} is null", newSchema.getKeyColumn())); } if (record.get("-") != null) { - decodedRecords.add(Map.of(this.schema.getKeyColumn(), - this.schema.getKeyColumnType().decode(key), + decodedRecords.add(Map.of(newSchema.getKeyColumn(), + newSchema.getKeyColumnType().decode(key), "-", true)); } else { - decodedRecords.add(MemoryTableImpl.decodeRecord(this.schema, record)); + decodedRecords.add(MemoryTableImpl.decodeRecord(newSchema, record)); } } for (var record : decodedRecords) { - var key = record.get(this.schema.getKeyColumn()); - if (record.get("-") != null) { - this.recordMap.remove(key); - } else { - var old = this.recordMap.putIfAbsent(key, record); - if (old != null) { - old.putAll(record); - } + logEntryBuilder.addRecords(MemoryTableImpl.writeRecord(newSchema, record)); + } + } + this.walManager.append(logEntryBuilder.build()); + this.schema = newSchema; + if (decodedRecords != null) { + this.insertRecords(decodedRecords); + } + } + + private void insertRecords(List> records) { + for (var record : records) { + var key = record.get(this.schema.getKeyColumn()); + if (record.get("-") != null) { + this.recordMap.remove(key); + } else { + var old = this.recordMap.putIfAbsent(key, record); + if (old != null) { + old.putAll(record); } } } } + private static Wal.Record.Builder writeRecord(TableSchema schema, Map record) { + var ret = Wal.Record.newBuilder(); + for (var entry : record.entrySet()) { + ret.addColumns(MemoryTableImpl.writeColumn(schema, entry.getKey(), entry.getValue())); + } + return ret; + } + + private static Wal.Column.Builder writeColumn(TableSchema schema, String name, Object value) { + var ret = Wal.Column.newBuilder(); + if (name.equals("-")) { + ret.setIndex(-1); + } else { + var colSchema = schema.getColumnSchemaByName(name); + ret.setIndex(colSchema.getIndex()); + if (value == null) { + ret.setNullValue(true); + } else { + switch (colSchema.getType()) { + case UNKNOWN: + ret.setNullValue(true); + break; + case BOOL: + ret.setBoolValue((Boolean) value); + break; + case INT8: + case INT16: + case INT32: + case INT64: + ret.setIntValue(((Number) value).longValue()); + break; + case FLOAT32: + ret.setFloatValue((Float) value); + break; + case FLOAT64: + ret.setDoubleValue((Double) value); + break; + case STRING: + ret.setStringValue((String) value); + break; + case BYTES: + ret.setBytesValue(ByteString.copyFrom(((ByteBuffer) value).array())); + break; + default: + throw new IllegalArgumentException("invalid type " + colSchema.getType()); + } + } + } + return ret; + } + @Override synchronized public RecordList query( Map columns, @@ -112,11 +286,11 @@ synchronized public RecordList query( if (orderBy != null) { for (var col : orderBy) { if (col == null) { - throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, "order by column should not be null"); } if (this.schema.getColumnSchemaByName(col.getColumnName()) == null) { - throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, "unknown orderBy column " + col); } } @@ -219,6 +393,7 @@ public InternalIterator( } else { this.endKey = MemoryTableImpl.this.recordMap.lowerKey(endKey); } + //noinspection rawtypes,unchecked if (startKey != null && this.endKey != null && ((Comparable) startKey).compareTo(this.endKey) <= 0) { this.iterator = MemoryTableImpl.this.recordMap.subMap(startKey, true, this.endKey, true) .entrySet() @@ -370,11 +545,11 @@ private void checkSameType(List operands) { var col = (TableQueryFilter.Column) op; var colSchema = this.schema.getColumnSchemaByName(col.getName()); if (colSchema == null) { - throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, "invalid filter, unknown column " + col.getName()); } if (!type.getName().equals(colSchema.getType().getName())) { - throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, MessageFormat.format( "invalid filter, can not compare column {0} of type {1} with column {2} of type {3}", col.getName(), @@ -394,10 +569,10 @@ private void checkSameType(List operands) { checkFailed = !type.getName().equals(ColumnType.INT32.getName()) && !type.getName().equals(ColumnType.FLOAT32.getName()); } else { - throw new IllegalArgumentException("unexpectd operand class " + op.getClass()); + throw new IllegalArgumentException("unexpected operand class " + op.getClass()); } if (checkFailed) { - throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE).tip( + throw new SWValidationException(SWValidationException.ValidSubject.DATASTORE, MessageFormat.format( "invalid filter, can not compare column {0} of type {1} with value {2} of type {3}", firstCol.orElseThrow().getName(), @@ -466,13 +641,13 @@ private static Map decodeRecord(TableSchema schema, Map= this.buffer.capacity()) { + return -1; + } + return this.buffer.getByte(this.offset++); + } + + @Override + public int read(byte b[], int off, int len) { + var capacity = this.buffer.capacity(); + if (this.offset >= capacity) { + return -1; + } + var ret = this.buffer.getBytes(this.offset, b, off, len); + this.offset += ret; + return ret; + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/memory/BufferManager.java b/server/controller/src/main/java/ai/starwhale/mlops/memory/SwBufferManager.java similarity index 90% rename from server/controller/src/main/java/ai/starwhale/mlops/memory/BufferManager.java rename to server/controller/src/main/java/ai/starwhale/mlops/memory/SwBufferManager.java index 9a4ec89965..0533560a71 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/memory/BufferManager.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/memory/SwBufferManager.java @@ -15,6 +15,8 @@ */ package ai.starwhale.mlops.memory; -public interface BufferManager { +public interface SwBufferManager { SwBuffer allocate(int capacity); + + void release(SwBuffer buffer); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/memory/SwBufferOutputStream.java b/server/controller/src/main/java/ai/starwhale/mlops/memory/SwBufferOutputStream.java new file mode 100644 index 0000000000..cb1bfcdae2 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/memory/SwBufferOutputStream.java @@ -0,0 +1,50 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.starwhale.mlops.memory; + +import lombok.Getter; + +import javax.validation.constraints.NotNull; +import java.io.IOException; +import java.io.OutputStream; + +public class SwBufferOutputStream extends OutputStream { + private final SwBuffer buffer; + @Getter + private int offset; + + public SwBufferOutputStream(SwBuffer buffer) { + this.buffer = buffer; + } + + @Override + public void write(int b) throws IOException { + if (this.offset >= this.buffer.capacity()) { + throw new IOException("buffer size limit exceeded"); + } + this.buffer.setByte(this.offset, (byte) b); + ++this.offset; + } + + @Override + public void write(@NotNull byte[] b, int off, int len) throws IOException { + if (this.offset + len > this.buffer.capacity()) { + throw new IOException("buffer size limit exceeded"); + } + this.buffer.setBytes(this.offset, b, off, len); + this.offset += len; + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/BufferManagerConfig.java b/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/BufferManagerConfig.java deleted file mode 100644 index d794a5dd74..0000000000 --- a/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/BufferManagerConfig.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2022 Starwhale, Inc. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package ai.starwhale.mlops.memory.impl; - -import ai.starwhale.mlops.memory.BufferManager; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -@Configuration -public class BufferManagerConfig { - @Bean - BufferManager getBufferManager() { - return new ByteBufferManager(); - } -} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwBytesBuffer.java b/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwByteBuffer.java similarity index 65% rename from server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwBytesBuffer.java rename to server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwByteBuffer.java index 3ff2357219..489fbeead6 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwBytesBuffer.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwByteBuffer.java @@ -19,14 +19,19 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.text.MessageFormat; -public class SwBytesBuffer implements SwBuffer { - private ByteBuffer buf; +public class SwByteBuffer implements SwBuffer { + private final ByteBuffer buf; - protected SwBytesBuffer(int capacity) { + protected SwByteBuffer(int capacity) { this.buf = ByteBuffer.allocate(capacity); } + private SwByteBuffer(ByteBuffer buf) { + this.buf = buf; + } + @Override public byte getByte(int index) { return this.buf.get(index); @@ -89,26 +94,34 @@ public void setDouble(int index, double value) { @Override public String getString(int index, int count) { - return new String(this.getBytes(index, count), StandardCharsets.UTF_8); + var b = new byte[count]; + if (this.getBytes(index, b, 0, count) != count) { + throw new IllegalArgumentException( + MessageFormat.format("not enough data. index={0} count={1}", index, count)); + } + return new String(b, StandardCharsets.UTF_8); } @Override public void setString(int index, String value) { - this.setBytes(index, value.getBytes(StandardCharsets.UTF_8)); + var b = value.getBytes(StandardCharsets.UTF_8); + this.setBytes(index, b, 0, b.length); } @Override - public byte[] getBytes(int index, int count) { - byte[] data = new byte[count]; + public int getBytes(int index, byte[] b, int offset, int len) { this.buf.position(index); - this.buf.get(data); - return data; + if (len > this.buf.remaining()) { + len = this.buf.remaining(); + } + this.buf.get(b, offset, len); + return len; } @Override - public void setBytes(int index, byte[] value) { + public void setBytes(int index, byte[] b, int offset, int len) { this.buf.position(index); - this.buf.put(value); + this.buf.put(b, offset, len); } @Override @@ -116,13 +129,24 @@ public int capacity() { return this.buf.capacity(); } + @Override + public SwBuffer slice(int offset, int len) { + this.buf.position(offset); + this.buf.limit(offset + len); + var buf = new SwByteBuffer(this.buf.slice()); + this.buf.limit(this.buf.capacity()); + return buf; + } + @Override public void copyTo(SwBuffer buf) { - buf.setBytes(0, this.buf.array()); + buf.setBytes(0, this.buf.array(), 0, this.buf.limit()); } @Override public ByteBuffer asByteBuffer() { - return this.buf; + var buf = this.buf.duplicate(); + buf.clear(); + return buf; } } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/ByteBufferManager.java b/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwByteBufferManager.java similarity index 73% rename from server/controller/src/main/java/ai/starwhale/mlops/memory/impl/ByteBufferManager.java rename to server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwByteBufferManager.java index df51f6e0a0..0aeb11b51b 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/ByteBufferManager.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/memory/impl/SwByteBufferManager.java @@ -15,12 +15,18 @@ */ package ai.starwhale.mlops.memory.impl; -import ai.starwhale.mlops.memory.BufferManager; +import ai.starwhale.mlops.memory.SwBufferManager; import ai.starwhale.mlops.memory.SwBuffer; +import org.springframework.stereotype.Component; -public class ByteBufferManager implements BufferManager { +@Component +public class SwByteBufferManager implements SwBufferManager { @Override public SwBuffer allocate(int capacity) { - return new SwBytesBuffer(capacity); + return new SwByteBuffer(capacity); + } + + @Override + public void release(SwBuffer buffer) { } } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/objectstore/ObjectStore.java b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/ObjectStore.java index 94cd7e99c9..512041806f 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/objectstore/ObjectStore.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/ObjectStore.java @@ -18,9 +18,14 @@ import ai.starwhale.mlops.memory.SwBuffer; import java.io.IOException; +import java.util.Iterator; public interface ObjectStore { + Iterator list(String prefix) throws IOException; + void put(String name, SwBuffer buffer) throws IOException; SwBuffer get(String name) throws IOException; + + void delete(String name) throws IOException; } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/FileIterator.java b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/FileIterator.java new file mode 100644 index 0000000000..b2cce7581e --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/FileIterator.java @@ -0,0 +1,109 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.starwhale.mlops.objectstore.impl; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +public class FileIterator implements Iterator { + private final List> stack = new ArrayList<>(); + private final File rootDir; + private String next; + + public FileIterator(String rootDir, String prefix) { + var root = new File(rootDir); + if (!root.isDirectory()) { + throw new IllegalArgumentException(rootDir + " not found or is not a directory"); + } + this.rootDir = root; + var path = prefix.split("/", -1); + if (path.length > 1) { + for (int i = 0; i < path.length - 1; ++i) { + root = new File(root, path[i]); + if (!root.isDirectory()) { + return; + } + } + prefix = path[path.length - 1]; + } + var candidates = new ArrayList(); + for (var fn : Objects.requireNonNull(root.list())) { + if (fn.startsWith(prefix)) { + candidates.add(new File(root, fn)); + } + } + if (candidates.isEmpty()) { + return; + } + candidates.sort(FileIterator::compareFilePaths); + this.stack.add(candidates); + this.findNext(); + } + + @Override + public boolean hasNext() { + return this.next != null; + } + + @Override + public String next() { + if (this.next == null) { + return null; + } + var ret = this.next; + this.next = null; + this.findNext(); + return ret; + } + + private void findNext() { + while (!this.stack.isEmpty()) { + var last = this.stack.get(this.stack.size() - 1); + if (last.isEmpty()) { + this.stack.remove(this.stack.size() - 1); + continue; + } + var next = last.remove(last.size() - 1); + if (!next.isDirectory()) { + var names = new ArrayList(); + for (var path : this.rootDir.toPath().relativize(next.toPath())) { + names.add(path.toString()); + } + this.next = String.join("/", names); + return; + } + var candidates = new ArrayList<>(Arrays.asList(Objects.requireNonNull(next.listFiles()))); + candidates.sort(FileIterator::compareFilePaths); + this.stack.add(candidates); + } + } + + private static int compareFilePaths(File a, File b) { + var x = a.getPath(); + var y = b.getPath(); + if (a.isDirectory()) { + x += "/"; + } + if (b.isDirectory()) { + y += "/"; + } + return y.compareTo(x); + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/FileSystemObjectStore.java b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/FileSystemObjectStore.java index 23302930f6..87b762116c 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/FileSystemObjectStore.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/FileSystemObjectStore.java @@ -15,39 +15,79 @@ */ package ai.starwhale.mlops.objectstore.impl; -import ai.starwhale.mlops.memory.BufferManager; import ai.starwhale.mlops.memory.SwBuffer; +import ai.starwhale.mlops.memory.SwBufferManager; import ai.starwhale.mlops.objectstore.ObjectStore; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.stereotype.Component; +import java.io.File; import java.io.FileInputStream; +import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; import java.nio.file.Files; -import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; import java.security.InvalidParameterException; import java.text.MessageFormat; +import java.util.Iterator; @Slf4j +@Component +@ConditionalOnProperty(prefix = "sw.storage", name = "type", havingValue = "fs") public class FileSystemObjectStore implements ObjectStore { - @Autowired - private BufferManager bufferManager; + private final SwBufferManager bufferManager; + + private final String rootDir; + + public FileSystemObjectStore(SwBufferManager bufferManager, @Value("${sw.datastore.fsRootDir}") String rootDir) { + this.bufferManager = bufferManager; + this.rootDir = rootDir; + } + + @Override + public Iterator list(String prefix) throws IOException { + return new FileIterator(this.rootDir, prefix); + } @Override public void put(String name, SwBuffer buf) throws IOException { - new FileOutputStream(name).getChannel().write(buf.asByteBuffer()); + var f = new File(this.rootDir, name); + //noinspection ResultOfMethodCallIgnored + f.getParentFile().mkdirs(); + var temp = File.createTempFile("sw_tmp", null); + try (var channel = new FileOutputStream(temp).getChannel()) { + channel.write(buf.asByteBuffer()); + Files.move(temp.toPath(), f.toPath(), StandardCopyOption.ATOMIC_MOVE, StandardCopyOption.REPLACE_EXISTING); + } finally { + //noinspection ResultOfMethodCallIgnored + temp.delete(); + } } @Override public SwBuffer get(String name) throws IOException { - long size = Files.size(Paths.get(name)); + var f = new File(this.rootDir, name); + if (!f.exists()) { + throw new FileNotFoundException(f.getAbsolutePath()); + } + long size = Files.size(f.toPath()); if (size > 64 * 1024 * 1024) { throw new InvalidParameterException( MessageFormat.format("file {0} is too large, size {1}", name, size)); } var buf = this.bufferManager.allocate((int) size); - new FileInputStream(name).getChannel().read(buf.asByteBuffer()); + try (var channel = new FileInputStream(f).getChannel()) { + channel.read(buf.asByteBuffer()); + } return buf; } + + @Override + public void delete(String name) { + //noinspection ResultOfMethodCallIgnored + new File(this.rootDir, name).delete(); + } } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/ObjectStoreConfig.java b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/ObjectStoreConfig.java deleted file mode 100644 index d9d259d7b9..0000000000 --- a/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/ObjectStoreConfig.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2022 Starwhale, Inc. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package ai.starwhale.mlops.objectstore.impl; - -import ai.starwhale.mlops.objectstore.ObjectStore; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -@Configuration -public class ObjectStoreConfig { - @Bean - public ObjectStore getObjectStore() { - return new FileSystemObjectStore(); - } -} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/S3ObjectStore.java b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/S3ObjectStore.java new file mode 100644 index 0000000000..9cccd33c48 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/objectstore/impl/S3ObjectStore.java @@ -0,0 +1,68 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.starwhale.mlops.objectstore.impl; + +import ai.starwhale.mlops.memory.SwBuffer; +import ai.starwhale.mlops.memory.SwBufferInputStream; +import ai.starwhale.mlops.memory.SwBufferManager; +import ai.starwhale.mlops.objectstore.ObjectStore; +import ai.starwhale.mlops.storage.StorageAccessService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.stereotype.Component; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +import java.io.IOException; +import java.util.Iterator; + +@Slf4j +@Component +@ConditionalOnProperty(prefix = "sw.storage", name = "type", havingValue = "s3", matchIfMissing = true) +public class S3ObjectStore implements ObjectStore { + private final SwBufferManager bufferManager; + + private final StorageAccessService storageAccessService; + + public S3ObjectStore(SwBufferManager bufferManager, StorageAccessService storageAccessService) { + this.bufferManager = bufferManager; + this.storageAccessService = storageAccessService; + } + + @Override + public Iterator list(String prefix) throws IOException { + return this.storageAccessService.list(prefix).iterator(); + } + + @Override + public void put(String name, SwBuffer buf) throws IOException { + this.storageAccessService.put(name, new SwBufferInputStream(buf)); + } + + @Override + public SwBuffer get(String name) throws IOException { + @SuppressWarnings("unchecked") + var result = (ResponseInputStream) this.storageAccessService.get(name); + var ret = this.bufferManager.allocate(result.response().contentLength().intValue()); + assert result.read(ret.asByteBuffer().array()) == ret.capacity(); + return ret; + } + + @Override + public void delete(String name) throws IOException { + this.storageAccessService.delete(name); + } +} diff --git a/server/controller/src/main/protobuf/wal.proto b/server/controller/src/main/protobuf/wal.proto new file mode 100644 index 0000000000..f9a64ca79b --- /dev/null +++ b/server/controller/src/main/protobuf/wal.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; + +option java_package = "ai.starwhale.mlops.datastore"; + +message ColumnSchema { + string column_name = 1; + string column_type = 2; + int32 column_index = 3; +} + +message TableSchema { + string key_column = 1; + repeated ColumnSchema columns = 2; +} + +message Column { + int32 index = 1; + bool null_value = 2; + bool bool_value = 3; + int64 int_value = 4; + float float_value = 5; + double double_value = 6; + string string_value = 7; + bytes bytes_value = 8; +} + +message Record { + repeated Column columns = 1; +} + +message WalEntry { + enum Type { + UPDATE = 0; + } + Type entry_type = 1; + string table_name = 2; + TableSchema table_schema = 3; + repeated Record records = 4; +} \ No newline at end of file diff --git a/server/controller/src/main/resources/application.yaml b/server/controller/src/main/resources/application.yaml index adba2d2f68..4218827f7a 100644 --- a/server/controller/src/main/resources/application.yaml +++ b/server/controller/src/main/resources/application.yaml @@ -32,6 +32,12 @@ sw: controller: apiPrefix: /api/v1 whiteList: /api/v1/report + datastore: + fsRootDir: ${SW_DATA_STORE_FS_ROOT_DIR:.} + walFileSize: ${SW_DATASTORE_WAL_FILE_SIZE:65536} + walMaxFileSize: ${SW_DATASTORE_WAL_MAX_FILE_SIZE:67108864} + walPrefix: ${SW_DATASTORE_WAL_LOG_PREFIX:wal/} + walWaitIntervalMillis: ${SW_DATASTORE_WAL_WAIT_INTERVAL_MILLIS:500} --- #Development spring: diff --git a/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java index 4bc550e75e..64cf0e5954 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.starwhale.mlops.api; import ai.starwhale.mlops.api.protocol.datastore.ColumnDesc; @@ -15,13 +30,16 @@ import ai.starwhale.mlops.datastore.OrderByDesc; import ai.starwhale.mlops.datastore.TableQueryFilter; import ai.starwhale.mlops.datastore.TableSchemaDesc; +import ai.starwhale.mlops.datastore.WalManager; import ai.starwhale.mlops.exception.SWValidationException; import brave.internal.collect.Lists; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -30,6 +48,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.BDDMockito.given; public class DataStoreControllerTest { private DataStoreController controller; @@ -37,7 +56,9 @@ public class DataStoreControllerTest { @BeforeEach public void setUp() { this.controller = new DataStoreController(); - this.controller.setDataStore(new DataStore()); + var walManager = Mockito.mock(WalManager.class); + given(walManager.readAll()).willReturn(Collections.emptyIterator()); + this.controller.setDataStore(new DataStore(walManager)); } @Test @@ -744,7 +765,7 @@ public void setUp() { setValues(List.of(new RecordValueDesc() {{ setKey("k"); setValue("2"); - }},new RecordValueDesc() {{ + }}, new RecordValueDesc() {{ setKey("a"); }})); }})); diff --git a/server/controller/src/test/java/ai/starwhale/mlops/datastore/ColumnSchemaTest.java b/server/controller/src/test/java/ai/starwhale/mlops/datastore/ColumnSchemaTest.java index aaae1f9e65..06c5831008 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/datastore/ColumnSchemaTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/datastore/ColumnSchemaTest.java @@ -1,7 +1,20 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.starwhale.mlops.datastore; -import ai.starwhale.mlops.datastore.ColumnSchema; -import ai.starwhale.mlops.datastore.ColumnSchemaDesc; import ai.starwhale.mlops.exception.SWValidationException; import org.junit.jupiter.api.Test; @@ -13,23 +26,23 @@ public class ColumnSchemaTest { @Test public void testConstructor() { - new ColumnSchema(new ColumnSchemaDesc("k", "STRING")); + new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0); } @Test public void testConstructorException() { - assertThrows(NullPointerException.class, () -> new ColumnSchema(null), "null schema"); + assertThrows(NullPointerException.class, () -> new ColumnSchema(null, 0), "null schema"); assertThrows(SWValidationException.class, - () -> new ColumnSchema(new ColumnSchemaDesc(null, "STRING")), + () -> new ColumnSchema(new ColumnSchemaDesc(null, "STRING"), 0), "null column name"); assertThrows(SWValidationException.class, - () -> new ColumnSchema(new ColumnSchemaDesc("k", null)), + () -> new ColumnSchema(new ColumnSchemaDesc("k", null), 0), "null type"); assertThrows(SWValidationException.class, - () -> new ColumnSchema(new ColumnSchemaDesc("k", "invalid")), + () -> new ColumnSchema(new ColumnSchemaDesc("k", "invalid"), 0), "invalid type"); } } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/datastore/DataStoreTest.java b/server/controller/src/test/java/ai/starwhale/mlops/datastore/DataStoreTest.java index c119d05908..8f2bc631d3 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/datastore/DataStoreTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/datastore/DataStoreTest.java @@ -1,8 +1,28 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.starwhale.mlops.datastore; +import ai.starwhale.mlops.memory.impl.SwByteBufferManager; +import ai.starwhale.mlops.objectstore.impl.FileSystemObjectStore; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import java.io.File; +import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -12,16 +32,27 @@ import static org.hamcrest.Matchers.nullValue; public class DataStoreTest { + @TempDir + private File rootDir; private DataStore dataStore; + private SwByteBufferManager bufferManager; + + private FileSystemObjectStore objectStore; + + private WalManager walManager; + @BeforeEach - public void setUp() { - this.dataStore = new DataStore(); + public void setUp() throws IOException { + this.bufferManager = new SwByteBufferManager(); + this.objectStore = new FileSystemObjectStore(bufferManager, this.rootDir.getAbsolutePath()); + this.walManager = new WalManager(objectStore, bufferManager, 256, 4096, "test/", 10); + this.dataStore = new DataStore(this.walManager); } @Test - public void testUpdate() { + public void testUpdate() throws IOException { this.dataStore.update("t1", new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "STRING"), @@ -68,6 +99,28 @@ public void testUpdate() { .build()) .getRecords(), is(List.of(Map.of("k", "3", "x", "2")))); + + this.dataStore.terminate(); + this.walManager = new WalManager(this.objectStore, this.bufferManager, 256, 4096, "test/", 10); + this.dataStore = new DataStore(this.walManager); + assertThat("t1", + this.dataStore.scan(DataStoreScanRequest.builder() + .tables(List.of(DataStoreScanRequest.TableInfo.builder() + .tableName("t1") + .columns(Map.of("k", "k", "a", "a")) + .build())) + .build()) + .getRecords(), + is(List.of(Map.of("k", "1", "a", "2")))); + assertThat("t2", + this.dataStore.scan(DataStoreScanRequest.builder() + .tables(List.of(DataStoreScanRequest.TableInfo.builder() + .tableName("t2") + .columns(Map.of("k", "k", "x", "x")) + .build())) + .build()) + .getRecords(), + is(List.of(Map.of("k", "3", "x", "2")))); } @Test diff --git a/server/controller/src/test/java/ai/starwhale/mlops/datastore/TableSchemaTest.java b/server/controller/src/test/java/ai/starwhale/mlops/datastore/TableSchemaTest.java index 735022bc93..21039f85b3 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/datastore/TableSchemaTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/datastore/TableSchemaTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.starwhale.mlops.datastore; import ai.starwhale.mlops.exception.SWValidationException; @@ -9,6 +24,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -115,7 +131,7 @@ public void testCopyConstructor() { public void testGetColumnSchemaByName() { assertThat("common", this.schema.getColumnSchemaByName("k"), - is(new ColumnSchema(new ColumnSchemaDesc("k", "STRING")))); + is(new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0))); assertThat("null", this.schema.getColumnSchemaByName("x"), nullValue()); } @@ -123,8 +139,8 @@ public void testGetColumnSchemaByName() { public void testGetColumnSchemas() { var columnSchemas = this.schema.getColumnSchemas(); assertThat("equals", columnSchemas, containsInAnyOrder( - new ColumnSchema(new ColumnSchemaDesc("k", "STRING")), - new ColumnSchema(new ColumnSchemaDesc("a", "INT32")))); + new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0), + new ColumnSchema(new ColumnSchemaDesc("a", "INT32"), 1))); assertThrows(UnsupportedOperationException.class, columnSchemas::clear, "read only"); } @@ -146,41 +162,64 @@ public void testMerge() { new ColumnSchemaDesc("a", "STRING")))), "conflicting type 2"); - this.schema.merge(new TableSchemaDesc("k", List.of( + var diff = this.schema.merge(new TableSchemaDesc("k", List.of( new ColumnSchemaDesc("k", "STRING"), new ColumnSchemaDesc("b", "FLOAT32")))); assertThat("new column", this.schema.getColumnSchemas(), - containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING")), - new ColumnSchema(new ColumnSchemaDesc("a", "INT32")), - new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32")))); + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0), + new ColumnSchema(new ColumnSchemaDesc("a", "INT32"), 1), + new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32"), 2))); + assertThat("new column", diff, containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32"), 2))); - this.schema.merge(new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("x", "UNKNOWN")))); + diff = this.schema.merge(new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("x", "UNKNOWN")))); assertThat("new unknown column", this.schema.getColumnSchemas(), - containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING")), - new ColumnSchema(new ColumnSchemaDesc("a", "INT32")), - new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32")), - new ColumnSchema(new ColumnSchemaDesc("x", "UNKNOWN")))); - - this.schema.merge(new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("k", "UNKNOWN")))); + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0), + new ColumnSchema(new ColumnSchemaDesc("a", "INT32"), 1), + new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32"), 2), + new ColumnSchema(new ColumnSchemaDesc("x", "UNKNOWN"), 3))); + assertThat("new unknown column", + diff, + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("x", "UNKNOWN"), 3))); + + diff = this.schema.merge(new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("k", "UNKNOWN")))); assertThat("merge unknown to existing", this.schema.getColumnSchemas(), - containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING")), - new ColumnSchema(new ColumnSchemaDesc("a", "INT32")), - new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32")), - new ColumnSchema(new ColumnSchemaDesc("x", "UNKNOWN")))); + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0), + new ColumnSchema(new ColumnSchemaDesc("a", "INT32"), 1), + new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32"), 2), + new ColumnSchema(new ColumnSchemaDesc("x", "UNKNOWN"), 3))); + assertThat("merge unknown to existing", diff, empty()); - this.schema.merge(new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("x", "UNKNOWN")))); + diff = this.schema.merge(new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("x", "UNKNOWN")))); assertThat("merge unknown to unknown", this.schema.getColumnSchemas(), - containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING")), - new ColumnSchema(new ColumnSchemaDesc("a", "INT32")), - new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32")), - new ColumnSchema(new ColumnSchemaDesc("x", "UNKNOWN")))); + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0), + new ColumnSchema(new ColumnSchemaDesc("a", "INT32"), 1), + new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32"), 2), + new ColumnSchema(new ColumnSchemaDesc("x", "UNKNOWN"), 3))); + assertThat("merge unknown to unknown", diff, empty()); - this.schema.merge(new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("x", "INT32")))); + diff = this.schema.merge(new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("x", "INT32")))); assertThat("merge other to unknown", this.schema.getColumnSchemas(), - containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING")), - new ColumnSchema(new ColumnSchemaDesc("a", "INT32")), - new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32")), - new ColumnSchema(new ColumnSchemaDesc("x", "INT32")))); + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0), + new ColumnSchema(new ColumnSchemaDesc("a", "INT32"), 1), + new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32"), 2), + new ColumnSchema(new ColumnSchemaDesc("x", "INT32"), 3))); + assertThat("merge other to unknown", + diff, + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("x", "INT32"), 3))); + + diff = this.schema.merge(new TableSchemaDesc(null, + List.of(new ColumnSchemaDesc("y", "STRING"), new ColumnSchemaDesc("z", "BYTES")))); + assertThat("merge multiple", this.schema.getColumnSchemas(), + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("k", "STRING"), 0), + new ColumnSchema(new ColumnSchemaDesc("a", "INT32"), 1), + new ColumnSchema(new ColumnSchemaDesc("b", "FLOAT32"), 2), + new ColumnSchema(new ColumnSchemaDesc("x", "INT32"), 3), + new ColumnSchema(new ColumnSchemaDesc("y", "STRING"), 4), + new ColumnSchema(new ColumnSchemaDesc("z", "BYTES"), 5))); + assertThat("merge multiple", + diff, + containsInAnyOrder(new ColumnSchema(new ColumnSchemaDesc("y", "STRING"), 4), + new ColumnSchema(new ColumnSchemaDesc("z", "BYTES"), 5))); } @Test diff --git a/server/controller/src/test/java/ai/starwhale/mlops/datastore/WalManagerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/datastore/WalManagerTest.java new file mode 100644 index 0000000000..5dd25cdbed --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/datastore/WalManagerTest.java @@ -0,0 +1,295 @@ +package ai.starwhale.mlops.datastore; + +import ai.starwhale.mlops.exception.SWValidationException; +import ai.starwhale.mlops.memory.SwBufferManager; +import ai.starwhale.mlops.memory.impl.SwByteBufferManager; +import ai.starwhale.mlops.objectstore.ObjectStore; +import ai.starwhale.mlops.objectstore.impl.FileSystemObjectStore; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import com.google.protobuf.CodedOutputStream; +import org.apache.commons.lang3.tuple.Triple; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.Mockito; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class WalManagerTest { + @TempDir + private File rootDir; + + private SwBufferManager bufferManager; + + private FileSystemObjectStore objectStore; + + private WalManager walManager; + + @BeforeEach + public void setUp() throws IOException { + this.bufferManager = new SwByteBufferManager(); + this.objectStore = new FileSystemObjectStore(this.bufferManager, this.rootDir.getAbsolutePath()); + this.walManager = new WalManager(this.objectStore, this.bufferManager, 256, 4096, "test/", 10); + } + + + @AfterEach + public void tearDown() { + this.walManager.terminate(); + } + + private Wal.TableSchema createTableSchema(String keyColumn, List> columns) { + var builder = Wal.TableSchema.newBuilder(); + if (keyColumn != null) { + builder.setKeyColumn(keyColumn); + } + for (var triple : columns) { + builder.addColumns(Wal.ColumnSchema.newBuilder() + .setColumnIndex(triple.getLeft()) + .setColumnName(triple.getMiddle()) + .setColumnType(triple.getRight())); + } + return builder.build(); + } + + private List createRecords(List> records) { + var ret = new ArrayList(); + for (var record : records) { + var recordBuilder = Wal.Record.newBuilder(); + for (var entry : record.entrySet()) { + var columnBuilder = Wal.Column.newBuilder(); + columnBuilder.setIndex(entry.getKey()); + if (entry.getValue() == null) { + columnBuilder.setNullValue(true); + } else { + if (entry.getValue() instanceof Boolean) { + columnBuilder.setBoolValue((Boolean) entry.getValue()); + } else if (entry.getValue() instanceof Byte) { + columnBuilder.setIntValue((Byte) entry.getValue()); + } else if (entry.getValue() instanceof Short) { + columnBuilder.setIntValue((Short) entry.getValue()); + } else if (entry.getValue() instanceof Integer) { + columnBuilder.setIntValue((Integer) entry.getValue()); + } else if (entry.getValue() instanceof Long) { + columnBuilder.setIntValue((Long) entry.getValue()); + } else if (entry.getValue() instanceof Float) { + columnBuilder.setFloatValue((Float) entry.getValue()); + } else if (entry.getValue() instanceof Double) { + columnBuilder.setDoubleValue((Double) entry.getValue()); + } else if (entry.getValue() instanceof String) { + columnBuilder.setStringValue((String) entry.getValue()); + } else if (entry.getValue() instanceof ByteBuffer) { + columnBuilder.setBytesValue(ByteString.copyFrom((ByteBuffer) entry.getValue())); + } + } + recordBuilder.addColumns(columnBuilder); + } + ret.add(recordBuilder.build()); + } + return ret; + } + + @Test + public void testSimple() throws IOException, InterruptedException { + var entries = List.of( + Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t1") + .setTableSchema(this.createTableSchema("k", + List.of(Triple.of(1, "k", "STRING"), + Triple.of(2, "a", "INT32")))) + .addAllRecords(this.createRecords(List.of( + Map.of(1, "a", 2, 1), + Map.of(1, "b", 2, 2), + Map.of(1, "c", 2, 3) + ))) + .build(), + Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t2") + .setTableSchema(this.createTableSchema(null, + List.of(Triple.of(3, "x", "INT32")))) + .addAllRecords(this.createRecords(List.of( + Map.of(1, "a", 3, 1), + Map.of(1, "b", 3, 2), + Map.of(1, "c", 3, 3) + ))) + .build(), + Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t2") + .setTableSchema(this.createTableSchema(null, + List.of(Triple.of(4, "y", "STRING")))) + .addAllRecords(this.createRecords(List.of( + Map.of(1, "a", 4, "a".repeat(100)), + Map.of(1, "b", 4, "b".repeat(100)), + Map.of(1, "c", 4, "c".repeat(100)) + ))) + .build(), + Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t2") + .setTableSchema(this.createTableSchema(null, + List.of(Triple.of(4, "y", "STRING")))) + .addAllRecords(this.createRecords(List.of( + Map.of(1, "a", 4, "a".repeat(10)), + Map.of(1, "b", 4, "b".repeat(10)), + Map.of(1, "c", 4, "c".repeat(10)) + ))) + .build()); + this.walManager.append(entries.get(0)); + this.walManager.append(entries.get(1)); + Thread.sleep(50); + this.walManager.append(entries.get(2)); + Thread.sleep(50); + this.walManager.append(entries.get(3)); + this.walManager.terminate(); + assertThat(ImmutableList.copyOf(this.objectStore.list("")), is(List.of("test/wal.log.0", "test/wal.log.1"))); + this.walManager = new WalManager(this.objectStore, this.bufferManager, 256, 4096, "test/", 10); + assertThat(ImmutableList.copyOf(this.walManager.readAll()), is(entries)); + } + + @Test + public void testMany() throws IOException, InterruptedException { + List entries = new ArrayList<>(); + entries.add(Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t") + .setTableSchema(this.createTableSchema("k", List.of(Triple.of(1, "k", "STRING")))) + .build()); + for (int i = 0; i < 50000; ++i) { + entries.add(Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t") + .addAllRecords(this.createRecords(List.of(Map.of(1, "" + i)))) + .build()); + } + for (var entry : entries) { + this.walManager.append(entry); + } + this.walManager.terminate(); + this.walManager = new WalManager(this.objectStore, this.bufferManager, 256, 4096, "test/", 10); + assertThat(ImmutableList.copyOf(this.walManager.readAll()), is(entries)); + } + + @Test + public void testHugeEntry() throws IOException { + var entry = Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t") + .setTableSchema(this.createTableSchema("k", List.of(Triple.of(1, "k", "INT32")))) + .addAllRecords(this.createRecords(IntStream.range(1, 5000) + .mapToObj(i -> Map.of(1, (Object) i)) + .collect(Collectors.toList()))) + .build(); + this.walManager.append(entry); + this.walManager.terminate(); + this.walManager = new WalManager(this.objectStore, this.bufferManager, 256, 4096, "test/", 10); + var entries = ImmutableList.copyOf(this.walManager.readAll()); + assertThat(entries.size(), greaterThan(1)); + assertThat(entries.get(0).getTableSchema(), is(entry.getTableSchema())); + int index = 1; + for (var e : entries) { + for (var r : e.getRecordsList()) { + for (var c : r.getColumnsList()) { + assertThat("index", c.getIndex(), is(1)); + assertThat("value", c.getIntValue(), is((long) index)); + ++index; + } + } + } + assertThat("count", index, is(5000)); + } + + @Test + public void testAppendHugeSchema() { + assertThrows(SWValidationException.class, () -> this.walManager.append(Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t") + .setTableSchema(this.createTableSchema("k", + IntStream.range(1, 5000) + .mapToObj(i -> Triple.of(1, "" + i, "INT32")) + .collect(Collectors.toList()))) + .build())); + } + + @Test + public void testAppendHugeSingleRecord() { + assertThrows(SWValidationException.class, () -> this.walManager.append(Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t") + .addAllRecords(this.createRecords(List.of(IntStream.range(1, 5000) + .boxed() + .collect(Collectors.toMap(i -> i, i -> i))))) + .build())); + } + + @Test + public void testAppendSplitSizeCalculation() throws IOException { + var builder = Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t") + .setTableSchema(this.createTableSchema("k", List.of(Triple.of(1, "k", "INT32")))) + .addAllRecords(this.createRecords(IntStream.range(1, 200) + .mapToObj(i -> Map.of(1, (Object) i)) + .collect(Collectors.toList()))); + var entry1 = builder.build(); + var entry2 = Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t") + .addAllRecords(this.createRecords(List.of(Map.of(1, 1)))) + .build(); + // make max file size equal the message + this.walManager = new WalManager(this.objectStore, + this.bufferManager, + 256, + entry1.getSerializedSize() + CodedOutputStream.computeUInt32SizeNoTag(entry1.getSerializedSize()), + "test/", + 10); + builder.addAllRecords(entry2.getRecordsList()); + this.walManager.append(builder.build()); + this.walManager.terminate(); + this.walManager = new WalManager(this.objectStore, this.bufferManager, 256, 4096, "test/", 10); + assertThat(ImmutableList.copyOf(this.walManager.readAll()), is(List.of(entry1, entry2))); + } + + @Test + public void testWriteFailureAndRetry() throws Exception { + var objectStore = Mockito.mock(ObjectStore.class); + given(objectStore.list(anyString())).willReturn(Collections.emptyIterator()); + doThrow(new IOException()) + .doThrow(new IOException()) + .doNothing() + .when(objectStore).put(anyString(), any()); + var walManager = new WalManager(objectStore, this.bufferManager, 256, 4096, "test/", 10); + walManager.append(Wal.WalEntry.newBuilder() + .setEntryType(Wal.WalEntry.Type.UPDATE) + .setTableName("t") + .build()); + Thread.sleep(1000); + walManager.terminate(); + verify(objectStore, times(3)).put(eq("test/wal.log.0"), any()); + } +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/datastore/impl/MemoryTableImplTest.java b/server/controller/src/test/java/ai/starwhale/mlops/datastore/impl/MemoryTableImplTest.java index 229df2eea7..c0d87a641a 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/datastore/impl/MemoryTableImplTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/datastore/impl/MemoryTableImplTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.starwhale.mlops.datastore.impl; import ai.starwhale.mlops.datastore.ColumnSchema; @@ -7,12 +22,20 @@ import ai.starwhale.mlops.datastore.OrderByDesc; import ai.starwhale.mlops.datastore.TableQueryFilter; import ai.starwhale.mlops.datastore.TableScanIterator; +import ai.starwhale.mlops.datastore.TableSchema; import ai.starwhale.mlops.datastore.TableSchemaDesc; +import ai.starwhale.mlops.datastore.WalManager; import ai.starwhale.mlops.exception.SWValidationException; +import ai.starwhale.mlops.memory.SwBufferManager; +import ai.starwhale.mlops.memory.impl.SwByteBufferManager; +import ai.starwhale.mlops.objectstore.impl.FileSystemObjectStore; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import java.io.File; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -45,8 +68,8 @@ var record = it.getRecord(); return ret; } - private static List> scanAll(MemoryTable memoryTable) { - return MemoryTableImplTest.getRecords(memoryTable.scan(null, null, true, null, false, false)); + private static List> scanAll(MemoryTable memoryTable, boolean keepNone) { + return MemoryTableImplTest.getRecords(memoryTable.scan(null, null, true, null, false, keepNone)); } private static List> decodeRecords(Map columnTypeMap, @@ -62,13 +85,25 @@ var record = new HashMap(); .collect(Collectors.toList()); } + @TempDir + private File rootDir; + + private WalManager walManager; + + @BeforeEach + public void setUp() throws IOException { + SwBufferManager bufferManager = new SwByteBufferManager(); + FileSystemObjectStore objectStore = new FileSystemObjectStore(bufferManager, this.rootDir.getAbsolutePath()); + this.walManager = new WalManager(objectStore, bufferManager, 256, 4096, "test/", 10); + } + @Nested public class UpdateTest { private MemoryTableImpl memoryTable; @BeforeEach public void setUp() { - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); } @Test @@ -78,10 +113,10 @@ public void testUpdateCommon() { new ColumnSchemaDesc("k", "STRING"), new ColumnSchemaDesc("a", "INT32"))), List.of(Map.of("k", "0", "a", "a"))); - assertThat("init", scanAll(this.memoryTable), contains(Map.of("k", "0", "a", "a"))); + assertThat("init", scanAll(this.memoryTable, false), contains(Map.of("k", "0", "a", "a"))); this.memoryTable.update(null, List.of(Map.of("k", "1", "a", "b"))); - assertThat("insert", scanAll(this.memoryTable), contains( + assertThat("insert", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a", "a"), Map.of("k", "1", "a", "b"))); @@ -89,21 +124,21 @@ public void testUpdateCommon() { null, List.of(Map.of("k", "2", "a", "c"), Map.of("k", "3", "a", "d"))); - assertThat("insert multiple", scanAll(this.memoryTable), contains( + assertThat("insert multiple", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a", "a"), Map.of("k", "1", "a", "b"), Map.of("k", "2", "a", "c"), Map.of("k", "3", "a", "d"))); this.memoryTable.update(null, List.of(Map.of("k", "1", "a", "c"))); - assertThat("overwrite", scanAll(this.memoryTable), contains( + assertThat("overwrite", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a", "a"), Map.of("k", "1", "a", "c"), Map.of("k", "2", "a", "c"), Map.of("k", "3", "a", "d"))); this.memoryTable.update(null, List.of(Map.of("k", "2", "-", "1"))); - assertThat("delete", scanAll(this.memoryTable), contains( + assertThat("delete", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a", "a"), Map.of("k", "1", "a", "c"), Map.of("k", "3", "a", "d"))); @@ -111,7 +146,7 @@ public void testUpdateCommon() { this.memoryTable.update( new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("b", "INT32"))), List.of(Map.of("k", "1", "b", "0"))); - assertThat("new column", scanAll(this.memoryTable), contains( + assertThat("new column", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a", "a"), Map.of("k", "1", "a", "c", "b", "0"), Map.of("k", "3", "a", "d"))); @@ -130,7 +165,7 @@ public void testUpdateCommon() { put("k", "3"); put("b", null); }})); - assertThat("null value", scanAll(this.memoryTable), contains( + assertThat("null value", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a", "a"), Map.of("k", "1", "b", "0"), Map.of("k", "2"), @@ -149,7 +184,7 @@ public void testUpdateCommon() { }}, Map.of("k", "0", "-", "1"), Map.of("k", "2", "-", "1"))); - assertThat("mixed", scanAll(this.memoryTable), contains( + assertThat("mixed", scanAll(this.memoryTable, false), contains( Map.of("k", "1", "c", "1"), Map.of("k", "3", "a", "0"), Map.of("k", "4", "c", "0"))); @@ -157,7 +192,7 @@ public void testUpdateCommon() { this.memoryTable.update( new TableSchemaDesc(null, List.of(new ColumnSchemaDesc("a-b/c/d:e_f", "INT32"))), List.of(Map.of("k", "0", "a-b/c/d:e_f", "0"))); - assertThat("complex name", scanAll(this.memoryTable), contains( + assertThat("complex name", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a-b/c/d:e_f", "0"), Map.of("k", "1", "c", "1"), Map.of("k", "3", "a", "0"), @@ -172,7 +207,7 @@ public void testUpdateCommon() { assertThat("unknown", this.memoryTable.getSchema().getColumnSchemaByName("x").getType(), is(ColumnType.UNKNOWN)); - assertThat("unknown", scanAll(this.memoryTable), contains( + assertThat("unknown", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a-b/c/d:e_f", "0"), Map.of("k", "1", "c", "1"), Map.of("k", "3", "a", "0"), @@ -184,7 +219,7 @@ public void testUpdateCommon() { assertThat("update unknown", this.memoryTable.getSchema().getColumnSchemaByName("x").getType(), is(ColumnType.INT32)); - assertThat("update unknown", scanAll(this.memoryTable), contains( + assertThat("update unknown", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a-b/c/d:e_f", "0"), Map.of("k", "1", "c", "1", "x", "1"), Map.of("k", "3", "a", "0"), @@ -199,7 +234,7 @@ public void testUpdateCommon() { assertThat("unknown again", this.memoryTable.getSchema().getColumnSchemaByName("x").getType(), is(ColumnType.INT32)); - assertThat("unknown again", scanAll(this.memoryTable), contains( + assertThat("unknown again", scanAll(this.memoryTable, false), contains( Map.of("k", "0", "a-b/c/d:e_f", "0"), Map.of("k", "1", "c", "1"), Map.of("k", "3", "a", "0"), @@ -218,82 +253,85 @@ public void testUpdateAllColumnTypes() { new ColumnSchemaDesc("e", "INT64"), new ColumnSchemaDesc("f", "FLOAT32"), new ColumnSchemaDesc("g", "FLOAT64"), - new ColumnSchemaDesc("h", "BYTES"))), - List.of(Map.of("key", "x", - "a", "1", - "b", "10", - "c", "1000", - "d", "100000", - "e", "10000000", - "f", Integer.toHexString(Float.floatToIntBits(1.1f)), - "g", Long.toHexString(Double.doubleToLongBits(1.1)), - "h", Base64.getEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))))); - assertThat("all types", scanAll(this.memoryTable), contains( - Map.of( - "key", "x", - "a", "1", - "b", "10", - "c", "1000", - "d", "100000", - "e", "10000000", - "f", Integer.toHexString(Float.floatToIntBits(1.1f)), - "g", Long.toHexString(Double.doubleToLongBits(1.1)), - "h", Base64.getEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)))) - - ); + new ColumnSchemaDesc("h", "BYTES"), + new ColumnSchemaDesc("i", "UNKNOWN"))), + List.of(new HashMap<>() {{ + put("key", "x"); + put("a", "1"); + put("b", "10"); + put("c", "1000"); + put("d", "100000"); + put("e", "10000000"); + put("f", Integer.toHexString(Float.floatToIntBits(1.1f))); + put("g", Long.toHexString(Double.doubleToLongBits(1.1))); + put("h", Base64.getEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); + put("i", null); + }})); + assertThat("all types", scanAll(this.memoryTable, false), contains( + new HashMap<>() {{ + put("key", "x"); + put("a", "1"); + put("b", "10"); + put("c", "1000"); + put("d", "100000"); + put("e", "10000000"); + put("f", Integer.toHexString(Float.floatToIntBits(1.1f))); + put("g", Long.toHexString(Double.doubleToLongBits(1.1))); + put("h", Base64.getEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); + }})); } @Test public void testUpdateAllKeyColumnTypes() { - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "BOOL"))), List.of(Map.of("k", "1"))); - assertThat("bool", scanAll(this.memoryTable), contains(Map.of("k", "1"))); + assertThat("bool", scanAll(this.memoryTable, false), contains(Map.of("k", "1"))); - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "INT8"))), List.of(Map.of("k", "10"))); - assertThat("int8", scanAll(this.memoryTable), contains(Map.of("k", "10"))); + assertThat("int8", scanAll(this.memoryTable, false), contains(Map.of("k", "10"))); - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "INT16"))), List.of(Map.of("k", "1000"))); - assertThat("int16", scanAll(this.memoryTable), contains(Map.of("k", "1000"))); + assertThat("int16", scanAll(this.memoryTable, false), contains(Map.of("k", "1000"))); - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "INT32"))), List.of(Map.of("k", "100000"))); - assertThat("int32", scanAll(this.memoryTable), contains(Map.of("k", "100000"))); + assertThat("int32", scanAll(this.memoryTable, false), contains(Map.of("k", "100000"))); - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("", MemoryTableImplTest.this.walManager); this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "INT64"))), List.of(Map.of("k", "10000000"))); - assertThat("int64", scanAll(this.memoryTable), contains(Map.of("k", "10000000"))); + assertThat("int64", scanAll(this.memoryTable, false), contains(Map.of("k", "10000000"))); - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("", MemoryTableImplTest.this.walManager); this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "FLOAT32"))), List.of(Map.of("k", Integer.toHexString(Float.floatToIntBits(1.1f))))); - assertThat("float32", scanAll(this.memoryTable), contains( + assertThat("float32", scanAll(this.memoryTable, false), contains( Map.of("k", Integer.toHexString(Float.floatToIntBits(1.1f))))); - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("", MemoryTableImplTest.this.walManager); this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "FLOAT64"))), List.of(Map.of("k", Long.toHexString(Double.doubleToLongBits(1.1))))); - assertThat("float64", scanAll(this.memoryTable), contains( + assertThat("float64", scanAll(this.memoryTable, false), contains( Map.of("k", Long.toHexString(Double.doubleToLongBits(1.1))))); - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("", MemoryTableImplTest.this.walManager); this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "BYTES"))), List.of(Map.of("k", Base64.getEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))))); - assertThat("bytes", scanAll(this.memoryTable), contains( + assertThat("bytes", scanAll(this.memoryTable, false), contains( Map.of("k", Base64.getEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))))); } @@ -308,7 +346,7 @@ public void testUpdateExceptions() { new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("a", "INT32"))), List.of(Map.of("k", "0"), Map.of("k", "1"))), "no key column schema"); - assertThat("no key column schema", scanAll(this.memoryTable), empty()); + assertThat("no key column schema", scanAll(this.memoryTable, false), empty()); assertThrows(SWValidationException.class, () -> this.memoryTable.update( @@ -317,14 +355,14 @@ public void testUpdateExceptions() { new ColumnSchemaDesc("-", "INT32"))), List.of(Map.of("k", "0"))), "invalid column name"); - assertThat("invalid column name", scanAll(this.memoryTable), empty()); + assertThat("invalid column name", scanAll(this.memoryTable, false), empty()); assertThrows(SWValidationException.class, () -> this.memoryTable.update( new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "STRING"))), List.of(Map.of("k", "0"), Map.of("k", "1", "a", "1"))), "extra column data"); - assertThat("extra column data", scanAll(this.memoryTable), empty()); + assertThat("extra column data", scanAll(this.memoryTable, false), empty()); assertThrows(SWValidationException.class, () -> this.memoryTable.update( @@ -333,10 +371,74 @@ public void testUpdateExceptions() { new ColumnSchemaDesc("a", "INT32"))), List.of(Map.of("k", "0"), Map.of("k", "1", "a", "h"))), "fail to decode"); - assertThat("fail to decode", scanAll(this.memoryTable), empty()); + assertThat("fail to decode", scanAll(this.memoryTable, false), empty()); + } + + @Test + public void testUpdateWalError() { + var schema = new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "STRING"))); + assertThrows(SWValidationException.class, + () -> this.memoryTable.update( + schema, + List.of(Map.of("k", "a".repeat(5000)))), + "huge entry"); + assertThat("null", this.memoryTable.getSchema(), nullValue()); + this.memoryTable.update(schema, null); + assertThrows(SWValidationException.class, + () -> this.memoryTable.update( + null, + List.of(Map.of("k", "a".repeat(5000)))), + "huge entry"); + assertThat("schema", this.memoryTable.getSchema(), is(new TableSchema(schema))); + assertThat("records", scanAll(this.memoryTable, false), empty()); } - } + @Test + public void testUpdateFromWal() throws IOException { + this.memoryTable.update( + new TableSchemaDesc("key", List.of( + new ColumnSchemaDesc("key", "STRING"), + new ColumnSchemaDesc("a", "BOOL"), + new ColumnSchemaDesc("b", "INT8"), + new ColumnSchemaDesc("c", "INT16"), + new ColumnSchemaDesc("d", "INT32"), + new ColumnSchemaDesc("e", "INT64"), + new ColumnSchemaDesc("f", "FLOAT32"), + new ColumnSchemaDesc("g", "FLOAT64"), + new ColumnSchemaDesc("h", "BYTES"), + new ColumnSchemaDesc("i", "UNKNOWN"))), + null); + List> records = new ArrayList<>(); + for (int i = 0; i < 100; ++i) { + final int index = i; + records.add(new HashMap<>() {{ + put("key", String.format("%03d", index)); + put("a", "" + index % 2); + put("b", Integer.toHexString(index + 10)); + put("c", Integer.toHexString(index + 1000)); + put("d", Integer.toHexString(index + 100000)); + put("e", Integer.toHexString(index + 10000000)); + put("f", Integer.toHexString(Float.floatToIntBits(index + 0.1f))); + put("g", Long.toHexString(Double.doubleToLongBits(index + 0.1))); + put("h", Base64.getEncoder().encodeToString( + ("test" + index).getBytes(StandardCharsets.UTF_8))); + put("i", null); + }}); + } + this.memoryTable.update(null, records); + MemoryTableImplTest.this.walManager.terminate(); + SwBufferManager bufferManager = new SwByteBufferManager(); + FileSystemObjectStore objectStore = new FileSystemObjectStore(bufferManager, + MemoryTableImplTest.this.rootDir.getAbsolutePath()); + MemoryTableImplTest.this.walManager = new WalManager(objectStore, bufferManager, 256, 4096, "test/", 10); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); + var it = MemoryTableImplTest.this.walManager.readAll(); + while (it.hasNext()) { + this.memoryTable.updateFromWal(it.next()); + } + assertThat(scanAll(this.memoryTable, true), is(records)); + } + } @Nested public class QueryScanTest { @@ -402,13 +504,13 @@ public void setUp() { new ColumnSchemaDesc("h", "STRING"), new ColumnSchemaDesc("i", "BYTES"), new ColumnSchemaDesc("z", "UNKNOWN"))); - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); this.memoryTable.update(schema, records); } @Test public void testQueryInitialEmptyTable() { - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); var recordList = this.memoryTable.query(null, null, null, -1, -1); assertThat("empty", recordList.getColumnTypeMap(), nullValue()); assertThat("empty", recordList.getRecords(), empty()); @@ -416,7 +518,7 @@ public void testQueryInitialEmptyTable() { @Test public void testQueryEmptyTableWithSchema() { - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); this.memoryTable.update(new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "STRING"))), List.of(Map.of("k", "0", "-", "1"))); var recordList = this.memoryTable.query(null, null, null, -1, -1); @@ -1811,7 +1913,7 @@ public void testQueryUnknown() { @Test public void testScanInitialEmptyTable() { - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); var it = this.memoryTable.scan(null, null, false, null, false, false); var recordList = MemoryTableImplTest.getRecords(it); assertThat("empty", it.getColumnTypeMapping(), nullValue()); @@ -1820,7 +1922,7 @@ public void testScanInitialEmptyTable() { @Test public void testScanEmptyTableWithSchema() { - this.memoryTable = new MemoryTableImpl(); + this.memoryTable = new MemoryTableImpl("test", MemoryTableImplTest.this.walManager); this.memoryTable.update(new TableSchemaDesc("k", List.of(new ColumnSchemaDesc("k", "STRING"))), List.of(Map.of("k", "0", "-", "1"))); var it = this.memoryTable.scan(null, null, false, null, false, false); @@ -1950,4 +2052,4 @@ public void testScanUnknown() { Map.of()))); } } -} +} \ No newline at end of file diff --git a/server/controller/src/test/java/ai/starwhale/mlops/memory/SwBufferInputStreamTest.java b/server/controller/src/test/java/ai/starwhale/mlops/memory/SwBufferInputStreamTest.java new file mode 100644 index 0000000000..87a71c1a04 --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/memory/SwBufferInputStreamTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.starwhale.mlops.memory; + +import ai.starwhale.mlops.memory.impl.SwByteBufferManager; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class SwBufferInputStreamTest { + private final SwByteBufferManager bufferManager = new SwByteBufferManager(); + private SwBufferInputStream inputStream; + + @BeforeEach + public void setUp() { + SwBuffer buffer = this.bufferManager.allocate(10); + buffer.setString(0, "0123456789"); + this.inputStream = new SwBufferInputStream(buffer); + } + + @Test + public void testRead() { + for (int i = 0; i < 10; ++i) { + assertThat(this.inputStream.read(), is(i + (int) '0')); + } + assertThat(this.inputStream.read(), is(-1)); + assertThat(this.inputStream.read(), is(-1)); + } + + @Test + public void testReadBytes() { + var b = new byte[10]; + assertThat(this.inputStream.read(b, 1, 2), is(2)); + assertThat(Arrays.copyOfRange(b, 1, 3), is(new byte[]{'0', '1'})); + assertThat(this.inputStream.read(b, 0, 10), is(8)); + assertThat(Arrays.copyOfRange(b, 0, 8), is(new byte[]{'2', '3', '4', '5', '6', '7', '8', '9'})); + assertThat(this.inputStream.read(b, 0, 10), is(-1)); + assertThat(this.inputStream.read(b, 0, 10), is(-1)); + } +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/memory/SwBufferOutputStreamTest.java b/server/controller/src/test/java/ai/starwhale/mlops/memory/SwBufferOutputStreamTest.java new file mode 100644 index 0000000000..ba3fb16768 --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/memory/SwBufferOutputStreamTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.starwhale.mlops.memory; + +import ai.starwhale.mlops.memory.impl.SwByteBufferManager; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SwBufferOutputStreamTest { + private final SwByteBufferManager bufferManager = new SwByteBufferManager(); + private SwBufferInputStream inputStream; + private SwBufferOutputStream outputStream; + + @BeforeEach + public void setUp() { + SwBuffer buffer = this.bufferManager.allocate(10); + this.inputStream = new SwBufferInputStream(buffer); + this.outputStream = new SwBufferOutputStream(buffer); + } + + @Test + public void testWrite() throws IOException { + for (int i = 0; i < 10; ++i) { + this.outputStream.write(i); + } + assertThrows(IOException.class, () -> this.outputStream.write(0)); + assertThrows(IOException.class, () -> this.outputStream.write(0)); + } + + @Test + public void testWriteBytes() throws IOException { + var b = new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + this.outputStream.write(b, 0, 1); + this.outputStream.write(b, 1, 2); + this.outputStream.write(b, 3, 3); + assertThrows(IOException.class, () -> this.outputStream.write(b, 0, 5)); + this.outputStream.write(b, 6, 4); + assertThrows(IOException.class, () -> this.outputStream.write(0)); + assertThrows(IOException.class, () -> this.outputStream.write(b, 0, 1)); + var c = new byte[10]; + assertThat(this.inputStream.read(c, 0, 10), is(10)); + assertThat(c, is(b)); + } +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/memory/impl/SwByteBufferTest.java b/server/controller/src/test/java/ai/starwhale/mlops/memory/impl/SwByteBufferTest.java new file mode 100644 index 0000000000..104a334fda --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/memory/impl/SwByteBufferTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.starwhale.mlops.memory.impl; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class SwByteBufferTest { + private SwByteBuffer buffer; + + @BeforeEach + public void setUp() { + this.buffer = new SwByteBuffer(1024); + } + + @Test + public void testCapacity() { + assertThat("", this.buffer.asByteBuffer().capacity(), is(1024)); + } + + @Test + public void testGetAndSet() { + this.buffer.setByte(0, (byte) 1); + assertThat(this.buffer.getByte(0), is((byte) 1)); + + this.buffer.setShort(1, (short) 2); + assertThat(this.buffer.getShort(1), is((short) 2)); + + this.buffer.setInt(3, 3); + assertThat(this.buffer.getInt(3), is(3)); + + this.buffer.setLong(7, 4L); + assertThat(this.buffer.getLong(7), is(4L)); + + this.buffer.setFloat(15, 5.f); + assertThat(this.buffer.getFloat(15), is(5.f)); + + this.buffer.setDouble(19, 6.); + assertThat(this.buffer.getDouble(19), is(6.)); + + this.buffer.setString(27, "test"); + assertThat(this.buffer.getString(27, 4), is("test")); + + this.buffer.setBytes(29, "test".getBytes(StandardCharsets.UTF_8), 0, 4); + byte[] b = new byte[10]; + this.buffer.getBytes(29, b, 1, 4); + assertThat(Arrays.copyOfRange(b, 1, 5), is("test".getBytes(StandardCharsets.UTF_8))); + + this.buffer.getBytes(0, b, 0, 7); + assertThat(Arrays.copyOfRange(b, 0, 7), is(new byte[]{1, 0, 2, 0, 0, 0, 3})); + } + + @Test + public void testSlice() { + this.buffer.setString(0, "012345"); + var buf = this.buffer.slice(1, 5); + assertThat(buf.asByteBuffer(), is(ByteBuffer.wrap("12345".getBytes(StandardCharsets.UTF_8)))); + this.buffer.setString(0, "1234567890"); + assertThat(buf.asByteBuffer(), is(ByteBuffer.wrap("23456".getBytes(StandardCharsets.UTF_8)))); + } +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/objectstore/impl/FileIteratorTest.java b/server/controller/src/test/java/ai/starwhale/mlops/objectstore/impl/FileIteratorTest.java new file mode 100644 index 0000000000..73ba5938d6 --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/objectstore/impl/FileIteratorTest.java @@ -0,0 +1,73 @@ +package ai.starwhale.mlops.objectstore.impl; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class FileIteratorTest { + @TempDir + private File rootDir; + + @SuppressWarnings("ResultOfMethodCallIgnored") + @BeforeEach + public void setUp() throws IOException { + var a = new File(this.rootDir, "a"); + a.mkdir(); + new File(a, "b").createNewFile(); + new File(a, "b1").createNewFile(); + new File(a, "b2").createNewFile(); + + var a1 = new File(this.rootDir, "a1"); + a1.mkdir(); + new File(a1, "b1").mkdir(); + new File(a1, "b2").createNewFile(); + + var a1_b = new File(a1, "b"); + a1_b.mkdir(); + new File(a1_b, "c").createNewFile(); + new File(a1_b, "c1").createNewFile(); + + var a2 = new File(this.rootDir, "a2"); + a2.mkdir(); + new File(new File(a2, "b"), "c").mkdirs(); + new File(a2, "b1").mkdir(); + + new File(this.rootDir, "a10").createNewFile(); + new File(this.rootDir, "a11").createNewFile(); + new File(this.rootDir, "a.txt").createNewFile(); + } + + private List getAll(FileIterator fileIterator) { + var ret = new ArrayList(); + while (fileIterator.hasNext()) { + ret.add(fileIterator.next()); + } + return ret; + } + + @Test + public void testScan() { + assertThat(getAll(new FileIterator(this.rootDir.getAbsolutePath(), "a")), + is(List.of("a.txt", "a/b", "a/b1", "a/b2", "a1/b/c", "a1/b/c1", "a1/b2", "a10", "a11"))); + assertThat(getAll(new FileIterator(this.rootDir.getAbsolutePath(), "a/")), + is(List.of("a/b", "a/b1", "a/b2"))); + assertThat(getAll(new FileIterator(this.rootDir.getAbsolutePath(), "a/b")), + is(List.of("a/b", "a/b1", "a/b2"))); + assertThat(getAll(new FileIterator(this.rootDir.getAbsolutePath(), "a1/b/")), + is(List.of("a1/b/c", "a1/b/c1"))); + assertThat(getAll(new FileIterator(this.rootDir.getAbsolutePath(), "a1/b/c1")), + is(List.of("a1/b/c1"))); + assertThat(getAll(new FileIterator(this.rootDir.getAbsolutePath(), "a1/b/c1/")), + is(List.of())); + assertThat(getAll(new FileIterator(this.rootDir.getAbsolutePath(), "b")), + is(List.of())); + } +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/objectstore/impl/FileSystemObjectStoreTest.java b/server/controller/src/test/java/ai/starwhale/mlops/objectstore/impl/FileSystemObjectStoreTest.java new file mode 100644 index 0000000000..a28b27df68 --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/objectstore/impl/FileSystemObjectStoreTest.java @@ -0,0 +1,51 @@ +package ai.starwhale.mlops.objectstore.impl; + +import ai.starwhale.mlops.memory.SwBufferManager; +import ai.starwhale.mlops.memory.impl.SwByteBufferManager; +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class FileSystemObjectStoreTest { + @TempDir + private File rootDir; + + private SwBufferManager bufferManager; + + private FileSystemObjectStore objectStore; + + @BeforeEach + public void setUp() { + this.bufferManager = new SwByteBufferManager(); + this.objectStore = new FileSystemObjectStore(this.bufferManager, this.rootDir.getAbsolutePath()); + } + + @Test + public void testAll() throws IOException { + var buf = this.bufferManager.allocate(100); + buf.setString(0, "c:t1"); + this.objectStore.put("t1", buf.slice(0, 4)); + buf.setString(0, "c:t2"); + this.objectStore.put("t2", buf.slice(0, 4)); + buf.setString(0, "c:t/t3"); + this.objectStore.put("t/t3", buf.slice(0, 6)); + buf.setString(0, "c:d/a"); + this.objectStore.put("d/a", buf.slice(0, 5)); + assertThat(ImmutableList.copyOf(this.objectStore.list("t")), is(List.of("t/t3", "t1", "t2"))); + assertThat(this.objectStore.get("t1").asByteBuffer(), + is(ByteBuffer.wrap("c:t1".getBytes(StandardCharsets.UTF_8)))); + assertThat(this.objectStore.get("t/t3").asByteBuffer(), + is(ByteBuffer.wrap("c:t/t3".getBytes(StandardCharsets.UTF_8)))); + } + +} diff --git a/server/controller/t1 b/server/controller/t1 new file mode 100644 index 0000000000..c431a34630 --- /dev/null +++ b/server/controller/t1 @@ -0,0 +1 @@ +c:t1 \ No newline at end of file diff --git a/server/controller/t2 b/server/controller/t2 new file mode 100644 index 0000000000..9c5b7d7e35 --- /dev/null +++ b/server/controller/t2 @@ -0,0 +1 @@ +c:t2 \ No newline at end of file