Skip to content

Commit

Permalink
test: move common logic into TestBase (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarua authored Nov 2, 2023
1 parent 33ca926 commit 618d7ff
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 78 deletions.
22 changes: 22 additions & 0 deletions core/src/test/java/io/substrait/TestBase.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
package io.substrait;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.Rel;
import io.substrait.relation.RelProtoConverter;
import io.substrait.type.TypeCreator;
import java.io.IOException;

public abstract class TestBase {
Expand All @@ -14,4 +22,18 @@ public abstract class TestBase {
throw new RuntimeException(e);
}
}

protected TypeCreator R = TypeCreator.REQUIRED;

protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection);
protected ExtensionCollector functionCollector = new ExtensionCollector();
protected RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
protected ProtoRelConverter protoRelConverter =
new ProtoRelConverter(functionCollector, defaultExtensionCollection);

protected void verifyRoundTrip(Rel rel) {
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel);
Rel relReturned = protoRelConverter.from(protoRel);
assertEquals(rel, relReturned);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,14 @@
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.TestBase;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.extension.AdvancedExtension;
import io.substrait.extension.ExtensionCollector;
import io.substrait.relation.utils.StringHolder;
import java.util.Collections;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;

public class ProtoRelConverterTest extends TestBase {

final SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection);
final ExtensionCollector functionCollector = new ExtensionCollector();
final RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
final ProtoRelConverter protoRelConverter =
new ProtoRelConverter(functionCollector, defaultExtensionCollection);

final NamedScan commonTable =
b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import org.junit.jupiter.api.Test;

public class AggregateRoundtripTest extends TestBase {
static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(AggregateRoundtripTest.class);

private void assertAggregateRoundtrip(Expression.AggregationInvocation invocation) {
var expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.TestBase;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.Expression;
import io.substrait.extension.AdvancedExtension;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.ExtensionLeaf;
Expand All @@ -22,7 +19,6 @@
import io.substrait.relation.Project;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.Rel;
import io.substrait.relation.RelProtoConverter;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
Expand All @@ -45,15 +41,8 @@
*/
public class ExtensionRoundtripTest extends TestBase {

TypeCreator R = TypeCreator.REQUIRED;

final SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection;

final SubstraitBuilder b = new SubstraitBuilder(extensions);
final ExtensionCollector functionCollector = new ExtensionCollector();
final RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
final ProtoRelConverter protoRelConverter =
new StringHolderHandlingProtoRelConverter(functionCollector, extensions);
new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection);

final Rel commonTable =
b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
Expand All @@ -72,7 +61,8 @@ public class ExtensionRoundtripTest extends TestBase {
.optimization(new StringHolder("REL OPTIMIZATION"))
.build();

void verifyRoundTrip(Rel rel) {
@Override
protected void verifyRoundTrip(Rel rel) {
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel);
Rel relReturned = protoRelConverter.from(protoRel);
assertEquals(rel, relReturned);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
import org.junit.jupiter.params.provider.MethodSource;

public class GenericRoundtripTest extends TestBase {
static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(GenericRoundtripTest.class);

static Random rand = new Random(123);

Expand Down
26 changes: 0 additions & 26 deletions core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java
Original file line number Diff line number Diff line change
@@ -1,34 +1,14 @@
package io.substrait.type.proto;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.TestBase;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.Rel;
import io.substrait.relation.RelProtoConverter;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter;
import io.substrait.type.TypeCreator;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;

public class JoinRoundtripTest extends TestBase {

final SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection;

TypeCreator R = TypeCreator.REQUIRED;

final SubstraitBuilder b = new SubstraitBuilder(extensions);

final ExtensionCollector functionCollector = new ExtensionCollector();
final RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
final ProtoRelConverter protoRelConverter =
new StringHolderHandlingProtoRelConverter(functionCollector, extensions);

final Rel leftTable =
b.namedScan(
Arrays.asList("T1"),
Expand All @@ -41,12 +21,6 @@ public class JoinRoundtripTest extends TestBase {
Arrays.asList("d", "e", "f"),
Arrays.asList(R.FP64, R.STRING, R.I64));

void verifyRoundTrip(Rel rel) {
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel);
Rel relReturned = protoRelConverter.from(protoRel);
assertEquals(rel, relReturned);
}

@Test
void hashJoin() {
List<Integer> leftKeys = Arrays.asList(0, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,16 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.relation.ProtoRelConverter;
import java.math.BigDecimal;
import org.junit.jupiter.api.Test;

public class LiteralRoundtripTest extends TestBase {
static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(LiteralRoundtripTest.class);

@Test
void decimal() {
var val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2);
var to = new ExpressionProtoConverter(null, null);
var from =
new ProtoExpressionConverter(
null,
null,
EMPTY_TYPE,
new ProtoRelConverter(new ExtensionCollector(), defaultExtensionCollection));
var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter);
assertEquals(val, from.from(val.accept(to)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,20 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.ReadRel;
import io.substrait.relation.LocalFiles;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.RelProtoConverter;
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.files.ImmutableFileFormat;
import io.substrait.relation.files.ImmutableFileOrFiles;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.io.IOException;
import java.util.Arrays;
import org.junit.jupiter.api.Test;

public class LocalFilesRoundtripTest extends TestBase {

SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection;

public LocalFilesRoundtripTest() throws IOException {}

private void assertLocalFilesRoundtrip(FileOrFiles file) {
ExtensionCollector functionCollector = new ExtensionCollector();
RelProtoConverter to = new RelProtoConverter(functionCollector);
ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions);

var builder =
LocalFiles.builder()
.initialSchema(
Expand All @@ -47,7 +34,7 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) {
.build())
.addItems(file);

extensions.scalarFunctions().stream()
defaultExtensionCollection.scalarFunctions().stream()
.filter(s -> s.name().equalsIgnoreCase("equal"))
.findFirst()
.map(
Expand All @@ -63,9 +50,9 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) {
.ifPresent(builder::filter);

var localFiles = builder.build();
var protoFileRel = to.toProto(localFiles);
var protoFileRel = relProtoConverter.toProto(localFiles);
assertTrue(protoFileRel.getRead().hasFilter());
assertEquals(protoFileRel, to.toProto(from.from(protoFileRel)));
assertEquals(protoFileRel, relProtoConverter.toProto(protoRelConverter.from(protoFileRel)));
}

private ImmutableFileOrFiles.Builder setPath(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.junit.jupiter.params.provider.ValueSource;

public class TestTypeRoundtrip {
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(TestTypeRoundtrip.class);

@ParameterizedTest
@ValueSource(booleans = {true, false})
Expand Down

0 comments on commit 618d7ff

Please sign in to comment.