Skip to content

Commit

Permalink
Add support for getting field names (linkedin#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
phd3 authored May 11, 2023
1 parent b560364 commit 640747a
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
/**
 * A {@link StdType} representing a struct type.
 *
 * A struct is an ordered collection of named, typed fields; {@link #fieldNames()}
 * and {@link #fieldTypes()} return parallel lists aligned by field position.
 */
public interface StdStructType extends StdType {

  /**
   * Returns a {@link List} of the names of all the struct fields.
   *
   * @return the field names, in field-declaration order
   */
  List<String> fieldNames();

  /**
   * Returns a {@link List} of the types of all the struct fields.
   *
   * @return the field types, in the same order as {@link #fieldNames()}
   */
  List<? extends StdType> fieldTypes();
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ public Object underlyingType() {
return _schema;
}

/** Returns the name of every field declared on the wrapped Avro record schema, in order. */
@Override
public List<String> fieldNames() {
  return _schema.getFields()
      .stream()
      .map(field -> field.name())
      .collect(Collectors.toList());
}

@Override
public List<? extends StdType> fieldTypes() {
return _schema.getFields().stream().map(f -> AvroWrapper.createStdType(f.schema())).collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,29 +171,32 @@ public void testMapType() {

@Test
public void testRecordType() {
Schema field1 = createSchema("field1", "\"int\"");
Schema field2 = createSchema("field2", "\"double\"");
Schema field1 = createSchema("fieldOne", "\"int\"");
Schema field2 = createSchema("fieldTwo", "\"double\"");
Schema structSchema = Schema.createRecord(ImmutableList.of(
new Schema.Field("field1", field1, null, null),
new Schema.Field("field2", field2, null, null)
new Schema.Field("fieldOne", field1, null, null),
new Schema.Field("fieldTwo", field2, null, null)
));

StdType stdStructType = AvroWrapper.createStdType(structSchema);
assertTrue(stdStructType instanceof AvroStructType);
assertEquals(structSchema, stdStructType.underlyingType());
assertEquals(field1, ((AvroStructType) stdStructType).fieldTypes().get(0).underlyingType());
assertEquals(field2, ((AvroStructType) stdStructType).fieldTypes().get(1).underlyingType());

AvroStructType avroStructType = (AvroStructType) stdStructType;
assertEquals(field1, avroStructType.fieldTypes().get(0).underlyingType());
assertEquals(field2, avroStructType.fieldTypes().get(1).underlyingType());
assertEquals(avroStructType.fieldNames(), ImmutableList.of("fieldOne", "fieldTwo"));

GenericRecord value = new GenericData.Record(structSchema);
value.put("field1", 1);
value.put("field2", 2.0);
value.put("fieldOne", 1);
value.put("fieldTwo", 2.0);
StdData stdStructData = AvroWrapper.createStdData(value, structSchema);
assertTrue(stdStructData instanceof AvroStruct);
AvroStruct avroStruct = (AvroStruct) stdStructData;
assertEquals(2, avroStruct.fields().size());
assertEquals(value, avroStruct.getUnderlyingData());
assertEquals(1, ((PlatformData) avroStruct.getField("field1")).getUnderlyingData());
assertEquals(2.0, ((PlatformData) avroStruct.getField("field2")).getUnderlyingData());
assertEquals(1, ((PlatformData) avroStruct.getField("fieldOne")).getUnderlyingData());
assertEquals(2.0, ((PlatformData) avroStruct.getField("fieldTwo")).getUnderlyingData());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.linkedin.transport.hive.HiveWrapper;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;


Expand All @@ -26,6 +27,13 @@ public Object underlyingType() {
return _structObjectInspector;
}

/**
 * Returns the names of all struct fields, in the order reported by the
 * underlying {@link StructObjectInspector}.
 *
 * NOTE(review): Hive's standard inspectors appear to lower-case field names —
 * confirm callers do not depend on the original casing.
 */
@Override
public List<String> fieldNames() {
  return _structObjectInspector.getAllStructFieldRefs()
      .stream()
      .map(fieldRef -> fieldRef.getFieldName())
      .collect(Collectors.toList());
}


@Override
public List<? extends StdType> fieldTypes() {
return _structObjectInspector.getAllStructFieldRefs().stream()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/**
 * Copyright 2023 LinkedIn Corporation. All rights reserved.
 * Licensed under the BSD-2 Clause license.
 * See LICENSE in the project root for license information.
 */
package com.linkedin.transport.hive.typesystem;

import com.google.common.collect.ImmutableList;
import com.linkedin.transport.api.types.StdType;
import com.linkedin.transport.hive.types.HiveStructType;
import java.util.stream.Collectors;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.testng.Assert;
import org.testng.annotations.Test;


/** Unit tests for the Hive implementations of the standard type interfaces. */
public class TestHiveTypes {

  /** Verifies field-name and field-type extraction for a two-field struct. */
  @Test
  public void testStructType() {
    StructObjectInspector inspector = ObjectInspectorFactory.getStandardStructObjectInspector(
        ImmutableList.of("fieldOne", "fieldTwo"),
        ImmutableList.of(
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector));
    HiveStructType structType = new HiveStructType(inspector);
    // Hive treats field names case-insensitively, so they come back lower-cased.
    Assert.assertEquals(structType.fieldNames(), ImmutableList.of("fieldone", "fieldtwo"));
    Assert.assertEquals(
        structType.fieldTypes().stream().map(StdType::underlyingType).collect(Collectors.toList()),
        ImmutableList.of(
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ case class SparkStructType(structType: StructType) extends StdStructType {

override def underlyingType(): DataType = structType

/** Returns the name of each field of the underlying Spark StructType, in declaration order. */
override def fieldNames(): JavaList[String] = {
  val names = structType.fields.map(field => field.name)
  names.toSeq.asJava
}

/** Wraps each field's Spark DataType as a StdType, preserving field order. */
override def fieldTypes(): JavaList[_ <: StdType] = {
  val wrapped = structType.fields.map(field => SparkWrapper.createStdType(field.dataType))
  wrapped.toSeq.asJava
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ package com.linkedin.transport.spark

import java.nio.ByteBuffer
import java.nio.charset.Charset

import com.linkedin.transport.api.data.PlatformData
import com.linkedin.transport.spark.typesystem.{SparkBoundVariables, SparkTypeFactory}
import com.linkedin.transport.api.types.StdStructType
import com.linkedin.transport.spark.typesystem.SparkBoundVariables
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types._
Expand All @@ -20,7 +20,6 @@ import scala.collection.JavaConverters._

class TestSparkFactory {

val typeFactory: SparkTypeFactory = new SparkTypeFactory
val stdFactory = new SparkFactory(new SparkBoundVariables)

@Test
Expand Down Expand Up @@ -65,9 +64,11 @@ class TestSparkFactory {
"bytesField", "arrField")
val fieldTypes = Array("varchar", "integer", "bigint", "boolean", "real", "double", "varbinary", "array(integer)")

val stdStruct = stdFactory.createStruct(stdFactory.createStdType(fieldNames.zip(fieldTypes).map(x => x._1 + " " + x._2).mkString("row(", ", ", ")")))
val stdStructType = stdFactory.createStdType(fieldNames.zip(fieldTypes).map(x => x._1 + " " + x._2).mkString("row(", ", ", ")"));
val stdStruct = stdFactory.createStruct(stdStructType)
val internalRow = stdStruct.asInstanceOf[PlatformData].getUnderlyingData.asInstanceOf[InternalRow]
assertEquals(internalRow.numFields, fieldTypes.length)
assertEquals(stdStructType.asInstanceOf[StdStructType].fieldNames().toArray, fieldNames)
(0 until 8).foreach(idx => {
assertEquals(internalRow.get(idx, stdFactory.createStdType(fieldTypes(idx)).underlyingType().asInstanceOf[DataType]), null)
})
Expand Down
1 change: 1 addition & 0 deletions transportable-udfs-trino/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ dependencies {
// The io.airlift.slice dependency below has to match its counterpart in trino-root's pom.xml file
// If not specified, an older version is picked up transitively from another dependency
testImplementation(group: 'io.airlift', name: 'slice', version: project.ext.'airlift-slice-version')
testImplementation project(path: ':transportable-udfs-type-system', configuration: 'tests')
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ public TrinoStructType(RowType rowType) {
this.rowType = rowType;
}

/**
 * Returns the declared names of the row's fields, in order.
 *
 * NOTE(review): Trino row fields may be anonymous ({@code getName()} is an
 * {@code Optional}); such fields produce {@code null} entries in the returned
 * list — confirm callers tolerate null names.
 */
@Override
public List<String> fieldNames() {
  return rowType.getFields().stream()
      .map(RowType.Field::getName)
      .map(optionalName -> optionalName.orElse(null))
      .collect(Collectors.toList());
}

@Override
public List<? extends StdType> fieldTypes() {
return rowType.getFields().stream().map(f -> TrinoWrapper.createStdType(f.getType())).collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/**
 * Copyright 2023 LinkedIn Corporation. All rights reserved.
 * Licensed under the BSD-2 Clause license.
 * See LICENSE in the project root for license information.
 */
package com.linkedin.transport.trino;

import com.google.common.collect.ImmutableList;
import com.linkedin.transport.api.types.StdType;
import com.linkedin.transport.trino.types.TrinoStructType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RowType;
import java.util.stream.Collectors;
import org.testng.Assert;
import org.testng.annotations.Test;


/** Unit tests for the Trino implementations of the standard type interfaces. */
public class TestTrinoTypes {

  /** Verifies field-name and field-type extraction for a two-field row type. */
  @Test
  public void testStructType() {
    RowType.Field bigintField = RowType.field("fieldOne", BigintType.BIGINT);
    RowType.Field doubleField = RowType.field("fieldTwo", DoubleType.DOUBLE);
    TrinoStructType structType = new TrinoStructType(RowType.rowType(bigintField, doubleField));
    Assert.assertEquals(structType.fieldNames(), ImmutableList.of("fieldOne", "fieldTwo"));
    Assert.assertEquals(
        structType.fieldTypes().stream().map(StdType::underlyingType).collect(Collectors.toList()),
        ImmutableList.of(BigintType.BIGINT, DoubleType.DOUBLE));
  }
}

0 comments on commit 640747a

Please sign in to comment.