Skip to content

Commit

Permalink
enhance(datastore): add tuple type support (#1484)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchuan authored Nov 11, 2022
1 parent 5e4e011 commit faca792
Show file tree
Hide file tree
Showing 10 changed files with 380 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ public abstract class ColumnType {

public static ColumnType fromColumnSchemaDesc(ColumnSchemaDesc schema) {
var typeName = schema.getType().toUpperCase();
if (typeName.equals(ColumnTypeList.TYPE_NAME)) {
if (typeName.equals(ColumnTypeList.TYPE_NAME) || typeName.equals(ColumnTypeTuple.TYPE_NAME)) {
var elementType = schema.getElementType();
if (elementType == null) {
throw new IllegalArgumentException("elementType should not be null for LIST");
throw new IllegalArgumentException("elementType should not be null for " + typeName);
}
if (typeName.equals(ColumnTypeList.TYPE_NAME)) {
return new ColumnTypeList(ColumnType.fromColumnSchemaDesc(elementType));
} else {
return new ColumnTypeTuple(ColumnType.fromColumnSchemaDesc(elementType));
}
return new ColumnTypeList(ColumnType.fromColumnSchemaDesc(elementType));
} else if (typeName.equals(ColumnTypeObject.TYPE_NAME)) {
var attributes = schema.getAttributes();
if (attributes == null || attributes.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class ColumnTypeList extends ColumnType {

public static final String TYPE_NAME = "LIST";

private final ColumnType elementType;
protected final ColumnType elementType;

ColumnTypeList(ColumnType elementType) {
this.elementType = elementType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ public Object encode(Object value, boolean rawResult) {
} else if (this == STRING) {
return value;
} else if (this == BYTES) {
return Base64.getEncoder().encodeToString(((ByteBuffer) value).array());
var base64 = Base64.getEncoder().encode(((ByteBuffer) value).duplicate());
return StandardCharsets.UTF_8.decode(base64).toString();
}
}
throw new IllegalArgumentException("invalid type " + this);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.datastore;

import lombok.EqualsAndHashCode;
import lombok.Getter;

/**
 * Column type describing a tuple whose elements share a single element type.
 *
 * <p>Extends {@link ColumnTypeList}; this class only overrides the reported type name
 * ({@link #TYPE_NAME}) and the string form, which wraps the element type in parentheses.
 */
@Getter
@EqualsAndHashCode(callSuper = true)
public class ColumnTypeTuple extends ColumnTypeList {

    /** Type name reported by {@link #getTypeName()}. */
    public static final String TYPE_NAME = "TUPLE";

    /**
     * Creates a tuple type over the given element type.
     *
     * @param elementType the type of each tuple element; passed unchanged to the superclass
     */
    ColumnTypeTuple(ColumnType elementType) {
        super(elementType);
    }

    /**
     * Returns the type name of this column type.
     *
     * @return {@link #TYPE_NAME}
     */
    @Override
    public String getTypeName() {
        return TYPE_NAME;
    }

    /**
     * Renders this type as the element type's string form enclosed in parentheses.
     *
     * @return {@code "(" + elementType + ")"}
     */
    @Override
    public String toString() {
        return new StringBuilder()
                .append('(')
                .append(this.elementType)
                .append(')')
                .toString();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;

@Getter
@AllArgsConstructor
@EqualsAndHashCode
public class RecordList {

private Map<String, ColumnType> columnTypeMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ public void testIsComparableWith() {
assertThat(new ColumnTypeList(ColumnTypeScalar.INT32).isComparableWith(
new ColumnTypeList(ColumnTypeScalar.STRING)),
is(false));
assertThat(new ColumnTypeList(ColumnTypeScalar.INT32).isComparableWith(
new ColumnTypeTuple(ColumnTypeScalar.INT32)),
is(true));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ public void testFromColumnSchemaDesc() {
.build())
.build()),
is(new ColumnTypeList(new ColumnTypeList(ColumnTypeScalar.INT32))));
assertThat("simple tuple", ColumnType.fromColumnSchemaDesc(ColumnSchemaDesc.builder()
.type("TUPLE")
.elementType(ColumnSchemaDesc.builder().type("INT32").build())
.build()),
is(new ColumnTypeTuple(ColumnTypeScalar.INT32)));
assertThat("composite tuple", ColumnType.fromColumnSchemaDesc(ColumnSchemaDesc.builder()
.type("TUPLE")
.elementType(ColumnSchemaDesc.builder()
.type("LIST")
.elementType(ColumnSchemaDesc.builder().type("INT32").build())
.build())
.build()),
is(new ColumnTypeTuple(new ColumnTypeList(ColumnTypeScalar.INT32))));
assertThat("object", ColumnType.fromColumnSchemaDesc(ColumnSchemaDesc.builder()
.type("OBJECT")
.pythonType("t")
Expand Down Expand Up @@ -380,6 +393,69 @@ public void testCompareList() {
greaterThan(0));
}

/**
 * Checks tuple-vs-tuple comparison results: the left operand is always an INT32 tuple holding
 * (1, 2, 3) and only the INT8 tuple on the right-hand side varies.
 */
@Test
public void testCompareTuple() {
    var int32Tuple = new ColumnTypeTuple(ColumnTypeScalar.INT32);
    var int8Tuple = new ColumnTypeTuple(ColumnTypeScalar.INT8);
    var base = List.of(1, 2, 3);

    // right side is larger at the last position
    assertThat(ColumnType.compare(int32Tuple, base, int8Tuple, List.of(1, 2, 4)), lessThan(0));
    // right side has the same leading elements plus one extra
    assertThat(ColumnType.compare(int32Tuple, base, int8Tuple, List.of(1, 2, 3, 0)), lessThan(0));
    // identical contents
    assertThat(ColumnType.compare(int32Tuple, base, int8Tuple, List.of(1, 2, 3)), equalTo(0));
    // right side is smaller at the last position
    assertThat(ColumnType.compare(int32Tuple, base, int8Tuple, List.of(1, 2, 2)), greaterThan(0));
    // right side is empty
    assertThat(ColumnType.compare(int32Tuple, base, int8Tuple, List.of()), greaterThan(0));
}

/**
 * Checks comparison results when one operand is a tuple type and the other a list type, in both
 * orders. The left value is always (1, 2, 3) typed as INT32; the right value is typed as INT8.
 */
@Test
public void testCompareListTuple() {
    var tupleInt32 = new ColumnTypeTuple(ColumnTypeScalar.INT32);
    var tupleInt8 = new ColumnTypeTuple(ColumnTypeScalar.INT8);
    var listInt32 = new ColumnTypeList(ColumnTypeScalar.INT32);
    var listInt8 = new ColumnTypeList(ColumnTypeScalar.INT8);
    var base = List.of(1, 2, 3);

    // right side is larger at the last position
    assertThat(ColumnType.compare(tupleInt32, base, listInt8, List.of(1, 2, 4)), lessThan(0));
    assertThat(ColumnType.compare(listInt32, base, tupleInt8, List.of(1, 2, 4)), lessThan(0));

    // identical contents
    assertThat(ColumnType.compare(tupleInt32, base, listInt8, List.of(1, 2, 3)), equalTo(0));
    assertThat(ColumnType.compare(listInt32, base, tupleInt8, List.of(1, 2, 3)), equalTo(0));

    // right side is smaller at the last position
    assertThat(ColumnType.compare(tupleInt32, base, listInt8, List.of(1, 2, 2)), greaterThan(0));
    assertThat(ColumnType.compare(listInt32, base, tupleInt8, List.of(1, 2, 2)), greaterThan(0));
}

@Test
public void testCompareObject() {
var type = new ColumnTypeObject("t",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.datastore;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;

import ai.starwhale.mlops.exception.SwValidationException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;

/**
 * Unit tests for {@link ColumnTypeTuple}: type name, schema description, string form,
 * comparability, encode/decode and WAL round trip.
 */
public class ColumnTypeTupleTest {

    /** The tuple type reports "TUPLE" as its type name. */
    @Test
    public void testGetTypeName() {
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).getTypeName(), is("TUPLE"));
    }

    /** The schema description carries the column name, the TUPLE type and the element type. */
    @Test
    public void testToColumnSchemaDesc() {
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).toColumnSchemaDesc("t"),
                is(ColumnSchemaDesc.builder()
                        .name("t")
                        .type("TUPLE")
                        .elementType(ColumnSchemaDesc.builder()
                                .type("INT32")
                                .build())
                        .build()));
    }

    /**
     * The tuple's string form wraps the element type in parentheses.
     *
     * <p>This previously asserted on {@code ColumnTypeList} ("[INT32]"), which left
     * {@link ColumnTypeTuple#toString()} untested.
     */
    @Test
    public void testToString() {
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).toString(), is("(INT32)"));
    }

    /** Comparability against UNKNOWN, a scalar, other tuples and a list. */
    @Test
    public void testIsComparableWith() {
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).isComparableWith(ColumnTypeScalar.UNKNOWN), is(true));
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).isComparableWith(ColumnTypeScalar.INT32), is(false));
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).isComparableWith(
                        new ColumnTypeTuple(ColumnTypeScalar.INT32)),
                is(true));
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).isComparableWith(
                        new ColumnTypeTuple(ColumnTypeScalar.FLOAT64)),
                is(true));
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).isComparableWith(
                        new ColumnTypeTuple(ColumnTypeScalar.STRING)),
                is(false));
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).isComparableWith(
                        new ColumnTypeList(ColumnTypeScalar.INT32)),
                is(true));
    }

    /** Encoding of scalar and object elements, with and without raw results, including null elements. */
    @Test
    public void testEncode() {
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).encode(List.of(9, 10, 11), false),
                is(List.of("00000009", "0000000a", "0000000b")));
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).encode(List.of(9, 10, 11), true),
                is(List.of("9", "10", "11")));
        // List.of rejects nulls, so the null-element case builds ArrayLists explicitly.
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).encode(new ArrayList<Integer>() {
                    {
                        add(0);
                        add(null);
                        add(1);
                    }
                }, false),
                is(new ArrayList<String>() {
                    {
                        add("00000000");
                        add(null);
                        add("00000001");
                    }
                }));
        var composite = new ColumnTypeTuple(
                new ColumnTypeObject("t", Map.of("a", ColumnTypeScalar.INT32, "b", ColumnTypeScalar.INT32)));
        assertThat(composite.encode(List.of(Map.of("a", 9, "b", 10), Map.of("a", 10, "b", 11)), false),
                is(List.of(Map.of("a", "00000009", "b", "0000000a"), Map.of("a", "0000000a", "b", "0000000b"))));
        assertThat(composite.encode(List.of(Map.of("a", 9, "b", 10), Map.of("a", 10, "b", 11)), true),
                is(List.of(Map.of("a", "9", "b", "10"), Map.of("a", "10", "b", "11"))));
    }

    /** Decoding of scalar and object elements, plus rejection of non-list input and invalid element text. */
    @Test
    public void testDecode() {
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).decode(List.of("9", "a", "b")),
                is(List.of(9, 10, 11)));
        // List.of rejects nulls, so the null-element case builds ArrayLists explicitly.
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).decode(new ArrayList<String>() {
                    {
                        add("0");
                        add(null);
                        add("1");
                    }
                }),
                is(new ArrayList<Integer>() {
                    {
                        add(0);
                        add(null);
                        add(1);
                    }
                }));
        var composite = new ColumnTypeTuple(
                new ColumnTypeObject("t", Map.of("a", ColumnTypeScalar.INT32, "b", ColumnTypeScalar.INT32)));
        assertThat(composite.decode(List.of(Map.of("a", "9", "b", "a"), Map.of("a", "a", "b", "b"))),
                is(List.of(Map.of("a", 9, "b", 10), Map.of("a", 10, "b", 11))));

        assertThrows(SwValidationException.class, () -> new ColumnTypeTuple(ColumnTypeScalar.INT32).decode("9"));
        assertThrows(SwValidationException.class,
                () -> new ColumnTypeTuple(ColumnTypeScalar.INT32).decode(List.of("z")));
    }

    /** WAL conversion keeps the given index and round-trips both null and non-null values. */
    @Test
    public void testFromAndToWal() {
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).toWal(-1, List.of(9, 10, 11)).getIndex(), is(-1));
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).toWal(10, List.of(9, 10, 11)).getIndex(), is(10));
        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).fromWal(
                        new ColumnTypeTuple(ColumnTypeScalar.INT32).toWal(0, null).build()),
                nullValue());

        assertThat(new ColumnTypeTuple(ColumnTypeScalar.INT32).fromWal(
                        new ColumnTypeTuple(ColumnTypeScalar.INT32).toWal(0, List.of(9, 10, 11)).build()),
                is(List.of(9, 10, 11)));
    }
}
Loading

0 comments on commit faca792

Please sign in to comment.