Merge remote-tracking branch 'apache/main' into spark35
andygrove committed Jun 19, 2024
2 parents 10ecc03 + a4b968e commit a6821b1
Showing 24 changed files with 3,021 additions and 71 deletions.
27 changes: 13 additions & 14 deletions .github/workflows/spark_sql_test_ansi.yml
@@ -22,17 +22,15 @@ concurrency:
cancel-in-progress: true

on:
# enable the following once Ansi support is completed
# push:
# paths-ignore:
# - "doc/**"
# - "**.md"
# pull_request:
# paths-ignore:
# - "doc/**"
# - "**.md"

# manual trigger ONLY
push:
paths-ignore:
- "docs/**"
- "**.md"
pull_request:
paths-ignore:
- "docs/**"
- "**.md"
# manual trigger
# https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow
workflow_dispatch:

@@ -44,8 +42,8 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
java-version: [11]
spark-version: [{short: '3.4', full: '3.4.2'}]
java-version: [17]
spark-version: [{short: '4.0', full: '4.0.0-preview1'}]
module:
- {name: "catalyst", args1: "catalyst/test", args2: ""}
- {name: "sql/core-1", args1: "", args2: sql/testOnly * -- -l org.apache.spark.tags.ExtendedSQLTest -l org.apache.spark.tags.SlowSQLTest}
@@ -75,7 +73,8 @@ jobs:
- name: Run Spark tests
run: |
cd apache-spark
ENABLE_COMET=true ENABLE_COMET_ANSI_MODE=true build/sbt ${{ matrix.module.args1 }} "${{ matrix.module.args2 }}"
rm -rf /root/.m2/repository/org/apache/parquet # somehow parquet cache requires cleanups
RUST_BACKTRACE=1 ENABLE_COMET=true ENABLE_COMET_ANSI_MODE=true build/sbt ${{ matrix.module.args1 }} "${{ matrix.module.args2 }}"
env:
LC_ALL: "C.UTF-8"

@@ -285,6 +285,7 @@ public void init() throws URISyntaxException, IOException {
missingColumns = new boolean[columns.size()];
List<String[]> paths = requestedSchema.getPaths();
StructField[] nonPartitionFields = sparkSchema.fields();
ShimFileFormat.findRowIndexColumnIndexInSchema(sparkSchema);
for (int i = 0; i < requestedSchema.getFieldCount(); i++) {
Type t = requestedSchema.getFields().get(i);
Preconditions.checkState(
1 change: 1 addition & 0 deletions common/src/main/java/org/apache/comet/parquet/Native.java
@@ -75,6 +75,7 @@ public static native long initColumnReader(
int precision,
int expectedPrecision,
int scale,
int expectedScale,
int tu,
boolean isAdjustedUtc,
int batchSize,
19 changes: 15 additions & 4 deletions common/src/main/java/org/apache/comet/parquet/TypeUtil.java
@@ -27,6 +27,7 @@
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;
import org.apache.spark.package$;
import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
import org.apache.spark.sql.types.*;

@@ -169,6 +170,7 @@ && isUnsignedIntTypeMatched(logicalTypeAnnotation, 64)) {
break;
case INT96:
if (sparkType == TimestampNTZType$.MODULE$) {
if (isSpark40Plus()) return; // Spark 4.0+ supports Timestamp NTZ with INT96
convertErrorForTimestampNTZ(typeName.name());
} else if (sparkType == DataTypes.TimestampType) {
return;
@@ -218,7 +220,8 @@ private static void validateTimestampType(
// Throw an exception if the Parquet type is TimestampLTZ and the Catalyst type is TimestampNTZ.
// This is to avoid mistakes in reading the timestamp values.
if (((TimestampLogicalTypeAnnotation) logicalTypeAnnotation).isAdjustedToUTC()
&& sparkType == TimestampNTZType$.MODULE$) {
&& sparkType == TimestampNTZType$.MODULE$
&& !isSpark40Plus()) {
convertErrorForTimestampNTZ("int64 time(" + logicalTypeAnnotation + ")");
}
}
@@ -232,12 +235,14 @@ private static void convertErrorForTimestampNTZ(String parquetType) {
}

private static boolean canReadAsIntDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!DecimalType.is32BitDecimalType(dt)) return false;
if (!DecimalType.is32BitDecimalType(dt) && !(isSpark40Plus() && dt instanceof DecimalType))
return false;
return isDecimalTypeMatched(descriptor, dt);
}

private static boolean canReadAsLongDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!DecimalType.is64BitDecimalType(dt)) return false;
if (!DecimalType.is64BitDecimalType(dt) && !(isSpark40Plus() && dt instanceof DecimalType))
return false;
return isDecimalTypeMatched(descriptor, dt);
}

@@ -261,7 +266,9 @@ private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataTyp
DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
// It's OK if the required decimal precision is larger than or equal to the physical decimal
// precision in the Parquet metadata, as long as the decimal scale is the same.
return decimalType.getPrecision() <= d.precision() && decimalType.getScale() == d.scale();
return decimalType.getPrecision() <= d.precision()
&& (decimalType.getScale() == d.scale()
|| (isSpark40Plus() && decimalType.getScale() <= d.scale()));
}
return false;
}
@@ -278,4 +285,8 @@ private static boolean isUnsignedIntTypeMatched(
&& !((IntLogicalTypeAnnotation) logicalTypeAnnotation).isSigned()
&& ((IntLogicalTypeAnnotation) logicalTypeAnnotation).getBitWidth() == bitWidth;
}

private static boolean isSpark40Plus() {
return package$.MODULE$.SPARK_VERSION().compareTo("4.0") >= 0;
}
}
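
Two details in the TypeUtil diff above are easy to miss: `isSpark40Plus()` is a plain lexicographic comparison against Spark's version string, and the relaxed decimal check it gates only requires the Parquet scale to be less than or equal to the Spark scale (before 4.0 the scales must match exactly). A minimal Scala sketch of the same two predicates, with illustrative names that are not part of the patch:

```scala
object TypeGateSketch {
  // Same idea as TypeUtil.isSpark40Plus(): a lexicographic comparison on the version string.
  def isSpark40Plus(sparkVersion: String): Boolean = sparkVersion.compareTo("4.0") >= 0

  // Mirrors the relaxed isDecimalTypeMatched(): equal scales always match; on Spark 4.0+
  // a smaller Parquet scale is also accepted, as long as the precision still fits.
  def decimalMatches(parquetPrecision: Int, parquetScale: Int,
                     sparkPrecision: Int, sparkScale: Int,
                     spark40Plus: Boolean): Boolean =
    parquetPrecision <= sparkPrecision &&
      (parquetScale == sparkScale || (spark40Plus && parquetScale <= sparkScale))

  def main(args: Array[String]): Unit = {
    println(isSpark40Plus("3.5.1"))          // false
    println(isSpark40Plus("4.0.0-preview1")) // true
    println(decimalMatches(9, 2, 10, 4, spark40Plus = true))  // true: scale can widen on 4.0+
    println(decimalMatches(9, 2, 10, 4, spark40Plus = false)) // false: exact scale required
  }
}
```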
11 changes: 9 additions & 2 deletions common/src/main/java/org/apache/comet/parquet/Utils.java
@@ -115,7 +115,7 @@ public static long initColumnReader(
promotionInfo = new TypePromotionInfo(readType);
} else {
// If type promotion is not enabled, we'll just use the Parquet primitive type and precision.
promotionInfo = new TypePromotionInfo(primitiveTypeId, precision);
promotionInfo = new TypePromotionInfo(primitiveTypeId, precision, scale);
}

return Native.initColumnReader(
@@ -131,6 +131,7 @@ public static long initColumnReader(
precision,
promotionInfo.precision,
scale,
promotionInfo.scale,
tu,
isAdjustedUtc,
batchSize,
@@ -144,10 +145,13 @@ static class TypePromotionInfo {
int physicalTypeId;
// Decimal precision from the Spark read schema, or -1 if it's not decimal type.
int precision;
// Decimal scale from the Spark read schema, or -1 if it's not decimal type.
int scale;

TypePromotionInfo(int physicalTypeId, int precision) {
TypePromotionInfo(int physicalTypeId, int precision, int scale) {
this.physicalTypeId = physicalTypeId;
this.precision = precision;
this.scale = scale;
}

TypePromotionInfo(DataType sparkReadType) {
@@ -159,13 +163,16 @@ static class TypePromotionInfo {
int physicalTypeId = getPhysicalTypeId(primitiveType.getPrimitiveTypeName());
LogicalTypeAnnotation annotation = primitiveType.getLogicalTypeAnnotation();
int precision = -1;
int scale = -1;
if (annotation instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) {
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalAnnotation =
(LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) annotation;
precision = decimalAnnotation.getPrecision();
scale = decimalAnnotation.getScale();
}
this.physicalTypeId = physicalTypeId;
this.precision = precision;
this.scale = scale;
}
}

@@ -19,6 +19,8 @@

package org.apache.comet.shims

import org.apache.spark.sql.types.{LongType, StructField, StructType}

object ShimFileFormat {

// TODO: remove after dropping Spark 3.3 support and directly use FileFormat.ROW_INDEX
@@ -29,4 +31,20 @@ object ShimFileFormat {
// TODO: remove after dropping Spark 3.3 support and directly use
// FileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME
val ROW_INDEX_TEMPORARY_COLUMN_NAME: String = s"_tmp_metadata_$ROW_INDEX"

// TODO: remove after dropping Spark 3.3 support and directly use
// RowIndexUtil.findRowIndexColumnIndexInSchema
def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = {
sparkSchema.fields.zipWithIndex.find { case (field: StructField, _: Int) =>
field.name == ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME
} match {
case Some((field: StructField, idx: Int)) =>
if (field.dataType != LongType) {
throw new RuntimeException(
s"${ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} must be of LongType")
}
idx
case _ => -1
}
}
}
@@ -19,13 +19,15 @@

package org.apache.comet.shims

import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetRowIndexUtil
import org.apache.spark.sql.types.StructType

object ShimFileFormat {
// A name for a temporary column that holds row indexes computed by the file format reader
// until they can be placed in the _metadata struct.
val ROW_INDEX_TEMPORARY_COLUMN_NAME = ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME

val OPTION_RETURNING_BATCH = FileFormat.OPTION_RETURNING_BATCH
def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int =
ParquetRowIndexUtil.findRowIndexColumnIndexInSchema(sparkSchema)
}
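
Both ShimFileFormat variants above expose the same `findRowIndexColumnIndexInSchema` helper: given a Spark read schema, it returns the position of the temporary row-index column, or -1 when that column was not requested (the older shim implements the lookup by hand, the other delegates to Spark's `ParquetRowIndexUtil`). A short usage sketch, assuming the hand-rolled shim shown above and Spark's `sql` types are on the classpath:

```scala
import org.apache.comet.shims.ShimFileFormat
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object RowIndexLookupSketch {
  def main(args: Array[String]): Unit = {
    // Spark appends this temporary LongType column to the read schema when a query
    // asks for the _metadata.row_index field.
    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("name", StringType),
      StructField(ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME, LongType)))

    println(ShimFileFormat.findRowIndexColumnIndexInSchema(schema)) // 2
    println(ShimFileFormat.findRowIndexColumnIndexInSchema(
      StructType(Seq(StructField("id", LongType)))))                // -1
  }
}
```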
2 changes: 1 addition & 1 deletion core/benches/parquet_read.rs
@@ -54,7 +54,7 @@ fn bench(c: &mut Criterion) {
);
b.iter(|| {
let cd = ColumnDescriptor::new(t.clone(), 0, 0, ColumnPath::from(Vec::new()));
let promition_info = TypePromotionInfo::new(PhysicalType::INT32, -1);
let promition_info = TypePromotionInfo::new(PhysicalType::INT32, -1, -1);
let mut column_reader = TestColumnReader::new(
cd,
promition_info,
50 changes: 41 additions & 9 deletions core/src/errors.rs
@@ -21,6 +21,7 @@ use arrow::error::ArrowError;
use datafusion_common::DataFusionError;
use jni::errors::{Exception, ToException};
use regex::Regex;

use std::{
any::Any,
convert,
@@ -37,6 +38,7 @@ use std::{
use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort};

use crate::execution::operators::ExecutionError;
use jni::objects::{GlobalRef, JThrowable};
use jni::JNIEnv;
use lazy_static::lazy_static;
use parquet::errors::ParquetError;
@@ -160,7 +162,11 @@ pub enum CometError {
},

#[error("{class}: {msg}")]
JavaException { class: String, msg: String },
JavaException {
class: String,
msg: String,
throwable: GlobalRef,
},
}

pub fn init() {
@@ -208,6 +214,15 @@ impl From<CometError> for ExecutionError {
fn from(value: CometError) -> Self {
match value {
CometError::Execution { source } => source,
CometError::JavaException {
class,
msg,
throwable,
} => ExecutionError::JavaException {
class,
msg,
throwable,
},
_ => ExecutionError::GeneralError(value.to_string()),
}
}
@@ -379,17 +394,34 @@ pub fn unwrap_or_throw_default<T: JNIDefault>(
}
}

fn throw_exception<E: ToException>(env: &mut JNIEnv, error: &E, backtrace: Option<String>) {
fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option<String>) {
// If there isn't already an exception?
if env.exception_check().is_ok() {
// ... then throw new exception
let exception = error.to_exception();
match backtrace {
Some(backtrace_string) => env.throw_new(
exception.class,
to_stacktrace_string(exception.msg, backtrace_string).unwrap(),
),
_ => env.throw_new(exception.class, exception.msg),
match error {
CometError::JavaException {
class: _,
msg: _,
throwable,
} => env.throw(<&JThrowable>::from(throwable.as_obj())),
CometError::Execution {
source:
ExecutionError::JavaException {
class: _,
msg: _,
throwable,
},
} => env.throw(<&JThrowable>::from(throwable.as_obj())),
_ => {
let exception = error.to_exception();
match backtrace {
Some(backtrace_string) => env.throw_new(
exception.class,
to_stacktrace_string(exception.msg, backtrace_string).unwrap(),
),
_ => env.throw_new(exception.class, exception.msg),
}
}
}
.expect("Thrown exception")
}
8 changes: 8 additions & 0 deletions core/src/execution/operators/mod.rs
@@ -25,6 +25,7 @@ use arrow::{

use arrow::compute::{cast_with_options, CastOptions};
use arrow_schema::ArrowError;
use jni::objects::GlobalRef;
use std::{fmt::Debug, sync::Arc};

mod scan;
@@ -52,6 +53,13 @@ pub enum ExecutionError {
/// DataFusion error
#[error("Error from DataFusion: {0}.")]
DataFusionError(String),

#[error("{class}: {msg}")]
JavaException {
class: String,
msg: String,
throwable: GlobalRef,
},
}

/// Copy an Arrow Array
1 change: 1 addition & 0 deletions core/src/jvm_bridge/mod.rs
@@ -385,5 +385,6 @@ pub(crate) fn convert_exception(
Ok(CometError::JavaException {
class: exception_class_name_str,
msg: message_str,
throwable: env.new_global_ref(throwable)?,
})
}
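
Taken together, the errors.rs, operators/mod.rs, and jvm_bridge/mod.rs changes above keep a JNI global reference to the original JVM Throwable inside the `JavaException` variants, so `throw_exception` can rethrow that exact object instead of building a fresh exception from its message string. A minimal Scala model of that dispatch — the types are illustrative stand-ins for the Rust enums above, not part of the patch:

```scala
sealed trait NativeError
// Stand-in for CometError::JavaException: the original Throwable is kept alive
// (via a JNI global reference in the real code) so it can be rethrown unchanged.
final case class JavaException(clazz: String, msg: String, throwable: Throwable) extends NativeError
final case class GeneralError(msg: String) extends NativeError

object ThrowSketch {
  // Stand-in for throw_exception(): prefer rethrowing the preserved Throwable;
  // otherwise surface a new exception carrying the formatted message.
  def rethrow(error: NativeError): Nothing = error match {
    case JavaException(_, _, t) => throw t
    case GeneralError(msg)      => throw new RuntimeException(msg)
  }

  def main(args: Array[String]): Unit = {
    val original = new IllegalStateException("boom")
    try rethrow(JavaException("java.lang.IllegalStateException", "boom", original))
    catch { case t: Throwable => println(t eq original) } // true: same object, stack trace intact
  }
}
```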
4 changes: 3 additions & 1 deletion core/src/parquet/mod.rs
@@ -72,6 +72,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader(
precision: jint,
read_precision: jint,
scale: jint,
read_scale: jint,
time_unit: jint,
is_adjusted_utc: jboolean,
batch_size: jint,
@@ -94,7 +95,8 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader(
is_adjusted_utc,
jni_path,
)?;
let promotion_info = TypePromotionInfo::new_from_jni(read_primitive_type, read_precision);
let promotion_info =
TypePromotionInfo::new_from_jni(read_primitive_type, read_precision, read_scale);
let ctx = Context {
column_reader: ColumnReader::get(
desc,
(diffs for the remaining changed files are not shown here)
