1 change: 1 addition & 0 deletions dataset/CMakeLists.txt
@@ -25,6 +25,7 @@ add_jar(arrow_java_jni_dataset_jar
src/main/java/org/apache/arrow/dataset/jni/JniLoader.java
src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
src/main/java/org/apache/arrow/dataset/file/ParquetWriterProperties.java
src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java
src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java
src/main/java/org/apache/arrow/dataset/substrait/JniWrapper.java
1 change: 1 addition & 0 deletions dataset/pom.xml
@@ -172,6 +172,7 @@ under the License.
<directory>${arrow.cpp.build.dir}</directory>
<includes>
<include>**/*arrow_dataset_jni.*</include>
<include>**/*arrow_cdata_jni.*</include>
</includes>
</resource>
</resources>
337 changes: 337 additions & 0 deletions dataset/src/main/cpp/jni_wrapper.cc
@@ -26,6 +26,7 @@
#include "arrow/compute/initialize.h"
#include "arrow/dataset/api.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_parquet.h"
#ifdef ARROW_CSV
#include "arrow/dataset/file_csv.h"
#endif
@@ -36,6 +37,11 @@
#include "arrow/engine/substrait/relation.h"
#include "arrow/ipc/api.h"
#include "arrow/util/iterator.h"
#include "arrow/util/compression.h"
#include "parquet/arrow/writer.h"
#include "parquet/file_writer.h"
#include "parquet/stream_writer.h"
#include "arrow/io/file.h"
#include "jni_util.h"
#include "org_apache_arrow_dataset_file_JniWrapper.h"
#include "org_apache_arrow_dataset_jni_JniWrapper.h"
@@ -231,6 +237,253 @@ class DisposableScannerAdaptor {
}
};

// Adapter to wrap Java OutputStream as Arrow OutputStream
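// Only the standard java.io.OutputStream methods write(byte[], int, int), flush() and close()
// are used, so any Java stream implementation can back the writer. Each Write() copies the
// native buffer into a freshly allocated Java byte[] before calling into the JVM.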
class JavaOutputStreamAdapter : public arrow::io::OutputStream {
public:
JavaOutputStreamAdapter(JNIEnv* env, jobject java_output_stream)
: java_output_stream_(env->NewGlobalRef(java_output_stream)),
position_(0) {
JavaVM* vm;
env->GetJavaVM(&vm);
vm_ = vm;

// Get method IDs (cache them as they're valid for the lifetime of the class)
JNIEnv* current_env = GetEnv();
if (current_env) {
jclass output_stream_class = current_env->GetObjectClass(java_output_stream);
write_method_ = current_env->GetMethodID(output_stream_class, "write", "([BII)V");
flush_method_ = current_env->GetMethodID(output_stream_class, "flush", "()V");
close_method_ = current_env->GetMethodID(output_stream_class, "close", "()V");
current_env->DeleteLocalRef(output_stream_class);
}
}

~JavaOutputStreamAdapter() override {
JNIEnv* env = GetEnv();
if (env && java_output_stream_) {
env->DeleteGlobalRef(java_output_stream_);
}
}

arrow::Status Close() override {
JNIEnv* env = GetEnv();
if (!env) {
return arrow::Status::IOError("Failed to get JNI environment");
}
if (java_output_stream_) {
env->CallVoidMethod(java_output_stream_, close_method_);
RETURN_NOT_OK(CheckJniException(env));
env->DeleteGlobalRef(java_output_stream_);
java_output_stream_ = nullptr;
}
return arrow::Status::OK();
}

bool closed() const override { return java_output_stream_ == nullptr; }

arrow::Result<int64_t> Tell() const override { return position_; }

arrow::Status Write(const void* data, int64_t nbytes) override {
JNIEnv* env = GetEnv();
if (!env) {
return arrow::Status::IOError("Failed to get JNI environment");
}
if (!java_output_stream_) {
return arrow::Status::IOError("OutputStream is closed");
}

// Create byte array
jbyteArray byte_array = env->NewByteArray(static_cast<jsize>(nbytes));
if (byte_array == nullptr) {
return arrow::Status::OutOfMemory("Failed to allocate byte array");
}

// Copy data to byte array
env->SetByteArrayRegion(byte_array, 0, static_cast<jsize>(nbytes),
reinterpret_cast<const jbyte*>(data));

// Call Java write method
env->CallVoidMethod(java_output_stream_, write_method_, byte_array, 0,
static_cast<jint>(nbytes));
env->DeleteLocalRef(byte_array);

RETURN_NOT_OK(CheckJniException(env));
position_ += nbytes;
return arrow::Status::OK();
}

arrow::Status Flush() override {
JNIEnv* env = GetEnv();
if (!env) {
return arrow::Status::IOError("Failed to get JNI environment");
}
if (!java_output_stream_) {
return arrow::Status::IOError("OutputStream is closed");
}
env->CallVoidMethod(java_output_stream_, flush_method_);
return CheckJniException(env);
}

private:
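// Returns the JNIEnv for the current thread, or nullptr if the thread is not attached to the
// JVM. GetEnv() does not attach the thread, so all writes must happen on an already-attached
// thread (e.g. the Java caller's thread).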
JNIEnv* GetEnv() const {
JNIEnv* env;
if (vm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
return nullptr;
}
return env;
}

arrow::Status CheckJniException(JNIEnv* env) {
if (env->ExceptionCheck()) {
jthrowable exception = env->ExceptionOccurred();
env->ExceptionClear();
std::string error_msg = "Java exception occurred in OutputStream";
// Try to get exception message
jclass exception_class = env->GetObjectClass(exception);
jmethodID get_message_method =
env->GetMethodID(exception_class, "toString", "()Ljava/lang/String;");
if (get_message_method) {
jstring message =
(jstring)env->CallObjectMethod(exception, get_message_method);
if (message) {
const char* msg_chars = env->GetStringUTFChars(message, nullptr);
error_msg = std::string(msg_chars);
env->ReleaseStringUTFChars(message, msg_chars);
env->DeleteLocalRef(message);
}
}
env->DeleteLocalRef(exception);
env->DeleteLocalRef(exception_class);
return arrow::Status::IOError(error_msg);
}
return arrow::Status::OK();
}

JavaVM* vm_;
jobject java_output_stream_;
jmethodID write_method_;
jmethodID flush_method_;
jmethodID close_method_;
int64_t position_;
};

struct ParquetWriterHolder {
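// Bundles the writer, its output sink and the schema under a single native handle; the schema
// is retained so that incoming ArrowArray batches can be imported against it in
// nativeWriteParquetBatch.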
std::unique_ptr<parquet::arrow::FileWriter> writer;
std::shared_ptr<arrow::io::OutputStream> output_stream;
std::shared_ptr<arrow::Schema> schema;
};

// Helper function to build WriterProperties from Java ParquetWriterProperties object
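// The Java object is expected to expose getMaxRowGroupLength()J, getWriteBatchSize()J,
// getDataPageSize()J, getCompressionCodec()Ljava/lang/String;, getCompressionLevel()I and
// getWritePageIndex()Z; getters that are absent, or that return non-positive values, leave the
// corresponding builder defaults untouched.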
std::shared_ptr<parquet::WriterProperties> BuildWriterProperties(JNIEnv* env,
jobject java_properties) {
parquet::WriterProperties::Builder builder;

if (java_properties == nullptr) {
return builder.build();
}

jclass props_class = env->GetObjectClass(java_properties);

// Get maxRowGroupLength
jmethodID get_max_row_group_method =
env->GetMethodID(props_class, "getMaxRowGroupLength", "()J");
if (get_max_row_group_method) {
jlong max_row_group = env->CallLongMethod(java_properties, get_max_row_group_method);
if (max_row_group > 0) {
builder.max_row_group_length(static_cast<int64_t>(max_row_group));
}
}

// Get writeBatchSize
jmethodID get_write_batch_size_method =
env->GetMethodID(props_class, "getWriteBatchSize", "()J");
if (get_write_batch_size_method) {
jlong write_batch_size = env->CallLongMethod(java_properties, get_write_batch_size_method);
if (write_batch_size > 0) {
builder.write_batch_size(static_cast<int64_t>(write_batch_size));
}
}

// Get dataPageSize
jmethodID get_data_page_size_method =
env->GetMethodID(props_class, "getDataPageSize", "()J");
if (get_data_page_size_method) {
jlong data_page_size = env->CallLongMethod(java_properties, get_data_page_size_method);
if (data_page_size > 0) {
builder.data_pagesize(static_cast<int64_t>(data_page_size));
}
}

// Get compressionCodec
jmethodID get_compression_codec_method =
env->GetMethodID(props_class, "getCompressionCodec", "()Ljava/lang/String;");
if (get_compression_codec_method) {
jstring codec_str = (jstring)env->CallObjectMethod(java_properties, get_compression_codec_method);
if (codec_str != nullptr) {
std::string codec_name = arrow::dataset::jni::JStringToCString(env, codec_str);
// Use Arrow's Codec::GetCompressionType to parse compression name
auto arrow_compression_result = arrow::util::Codec::GetCompressionType(codec_name);
if (arrow_compression_result.ok()) {
// Parquet WriterProperties::Builder can accept Arrow Compression::type directly
arrow::Compression::type compression = arrow_compression_result.ValueOrDie();
// Set compression for all columns (using Arrow compression type directly)
builder.compression(compression);
} else {
// If the codec name cannot be parsed, fall back silently to the builder default
// (UNCOMPRESSED) so that unsupported compression names do not fail the write.
}
env->DeleteLocalRef(codec_str);
}
}

// Get compressionLevel
jmethodID get_compression_level_method =
env->GetMethodID(props_class, "getCompressionLevel", "()I");
if (get_compression_level_method) {
jint comp_level = env->CallIntMethod(java_properties, get_compression_level_method);
if (comp_level > 0) {
builder.compression_level(comp_level);
}
}

// Get writePageIndex
jmethodID get_write_page_index_method =
env->GetMethodID(props_class, "getWritePageIndex", "()Z");
if (get_write_page_index_method) {
jboolean write_index = env->CallBooleanMethod(java_properties, get_write_page_index_method);
if (write_index) {
builder.enable_write_page_index();
}
}

env->DeleteLocalRef(props_class);
return builder.build();
}

// Helper function to build ArrowWriterProperties
std::shared_ptr<parquet::ArrowWriterProperties> BuildArrowWriterProperties(JNIEnv* env,
jobject java_properties) {
parquet::ArrowWriterProperties::Builder builder;

if (java_properties == nullptr) {
return builder.build();
}

jclass props_class = env->GetObjectClass(java_properties);

// Get useThreads (for ArrowWriterProperties)
jmethodID get_use_threads_method =
env->GetMethodID(props_class, "getUseThreads", "()Z");
if (get_use_threads_method) {
jboolean use_threads = env->CallBooleanMethod(java_properties, get_use_threads_method);
builder.set_use_threads(use_threads);
}

env->DeleteLocalRef(props_class);
return builder.build();
}


arrow::Result<std::shared_ptr<arrow::Schema>> SchemaFromColumnNames(
const std::shared_ptr<arrow::Schema>& input,
const std::vector<std::string>& column_names) {
@@ -1019,3 +1272,87 @@ JNIEXPORT void JNICALL
JniAssertOkOrThrow(arrow::ExportRecordBatchReader(reader_out, arrow_stream_out));
JNI_METHOD_END()
}

/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: nativeCreateParquetWriter
* Signature: (Ljava/io/OutputStream;JLorg/apache/arrow/dataset/file/ParquetWriterProperties;)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_nativeCreateParquetWriter(
JNIEnv* env, jobject, jobject java_output_stream, jlong arrow_schema_ptr,
jobject java_properties) {
JNI_METHOD_START
// Import Schema from Arrow C Data Interface
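// (arrow::ImportSchema consumes the ArrowSchema struct and marks it released, so the Java
// side may free its wrapper as soon as this call returns.)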
auto* c_schema = reinterpret_cast<ArrowSchema*>(arrow_schema_ptr);
std::shared_ptr<arrow::Schema> schema = JniGetOrThrow(arrow::ImportSchema(c_schema));

// Create output stream adapter from Java OutputStream
auto output_stream = std::make_shared<JavaOutputStreamAdapter>(env, java_output_stream);

// Build Parquet writer properties from Java configuration
std::shared_ptr<parquet::WriterProperties> writer_props =
BuildWriterProperties(env, java_properties);
std::shared_ptr<parquet::ArrowWriterProperties> arrow_writer_props =
BuildArrowWriterProperties(env, java_properties);

// Open Parquet file writer
auto writer_result = parquet::arrow::FileWriter::Open(*schema, arrow::default_memory_pool(), output_stream,
writer_props, arrow_writer_props);
std::unique_ptr<parquet::arrow::FileWriter> writer = JniGetOrThrow(std::move(writer_result));

// Return the writer wrapped in a shared_ptr
auto holder = std::make_shared<ParquetWriterHolder>();
holder->writer = std::move(writer);
holder->output_stream = output_stream;
holder->schema = schema;
return CreateNativeRef<ParquetWriterHolder>(holder);
JNI_METHOD_END(0L)
}

/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: nativeWriteParquetBatch
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_nativeWriteParquetBatch(
JNIEnv* env, jobject, jlong native_ptr, jlong arrow_array_ptr) {
JNI_METHOD_START
auto holder = RetrieveNativeInstance<ParquetWriterHolder>(native_ptr);
if (!holder->writer) {
JniThrow("ParquetWriter is already closed");
}

// Import RecordBatch from Arrow C Data Interface
auto* c_array = reinterpret_cast<ArrowArray*>(arrow_array_ptr);
std::shared_ptr<arrow::RecordBatch> batch = JniGetOrThrow(arrow::ImportRecordBatch(c_array, holder->schema));

// Write the RecordBatch
JniAssertOkOrThrow(holder->writer->WriteRecordBatch(*batch));
return 1; // Success
JNI_METHOD_END(0)
}

/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: nativeCloseParquetWriter
* Signature: (J)I
*/
JNIEXPORT jint JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_nativeCloseParquetWriter(
JNIEnv* env, jobject, jlong native_ptr) {
JNI_METHOD_START
auto holder = RetrieveNativeInstance<ParquetWriterHolder>(native_ptr);
if (holder->writer) {
JniAssertOkOrThrow(holder->writer->Close());
holder->writer.reset();
}
// Close output stream
if (holder->output_stream) {
JniAssertOkOrThrow(holder->output_stream->Close());
}
ReleaseNativeRef<ParquetWriterHolder>(native_ptr);
return 1; // Success
JNI_METHOD_END(0)
}
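For orientation, the JNI signatures above imply Java-side native declarations along the following lines. This is a minimal sketch only: the real declarations belong in org.apache.arrow.dataset.file.JniWrapper (per the exported symbol names), the ParquetWriterJniSketch class and its writeSingleBatch helper are hypothetical, and the driver assumes the standard org.apache.arrow.c.Data export API rather than anything defined in this PR.

```java
// Sketch of the Java-side counterpart implied by the JNI entry points above. The class and
// method names below are placeholders for illustration; only the native signatures come from
// jni_wrapper.cc.
package org.apache.arrow.dataset.file;

import java.io.OutputStream;

import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;

final class ParquetWriterJniSketch {

  // Instance native methods: the C++ functions take a jobject receiver, so these are non-static.
  native long nativeCreateParquetWriter(
      OutputStream out, long arrowSchemaAddress, ParquetWriterProperties properties);

  native int nativeWriteParquetBatch(long writerHandle, long arrowArrayAddress);

  native int nativeCloseParquetWriter(long writerHandle);

  // Hypothetical driver: export the schema and one batch through the C Data Interface and
  // hand the raw struct addresses to the native writer.
  void writeSingleBatch(BufferAllocator allocator, VectorSchemaRoot root, OutputStream out,
      ParquetWriterProperties properties) {
    long writerHandle;
    try (ArrowSchema cSchema = ArrowSchema.allocateNew(allocator)) {
      Data.exportSchema(allocator, root.getSchema(), /*provider=*/ null, cSchema);
      // The native side imports (and releases) the struct before returning, so the
      // ArrowSchema wrapper can be closed immediately afterwards.
      writerHandle = nativeCreateParquetWriter(out, cSchema.memoryAddress(), properties);
    }
    try (ArrowArray cArray = ArrowArray.allocateNew(allocator)) {
      Data.exportVectorSchemaRoot(allocator, root, /*provider=*/ null, cArray);
      nativeWriteParquetBatch(writerHandle, cArray.memoryAddress());
    } finally {
      nativeCloseParquetWriter(writerHandle);
    }
  }
}
```

The handle returned by nativeCreateParquetWriter owns the writer, the wrapped OutputStream adapter, and the schema together, so the Java caller only needs to keep the stream open until nativeCloseParquetWriter has been called.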