Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-43911: [C++] Compute Row: ListKeyEncoder Supports #43912

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 66 additions & 41 deletions cpp/src/arrow/compute/row/row_encoder_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,51 @@ using internal::FirstTimeBitmapWriter;
namespace compute {
namespace internal {

/// \brief Create the KeyEncoder used to row-encode a column of the given type.
///
/// Extension types are encoded through their storage type. List types get a
/// ListKeyEncoder wrapping the encoder of their (non-nested) element type.
///
/// \param[in] column_type type of the column to encode
/// \param[out] extension_type receives the column's ExtensionType when
///   `column_type` is an extension type; it is not written otherwise
/// \param[in] pool memory pool forwarded to DictionaryKeyEncoder and to the
///   element encoder of a list
/// \return the encoder, or Status::NotImplemented when the type is unsupported
///
/// Review note (@mapleFU, Sep 2, 2024): this could also return a unique_ptr; I
/// don't see the purpose of a shared_ptr being used here. Also, this function
/// is extracted from RowEncoder.
Result<std::shared_ptr<KeyEncoder>> MakeKeyEncoder(
    const TypeHolder& column_type, std::shared_ptr<ExtensionType>* extension_type,
    MemoryPool* pool) {
  const bool is_extension = column_type.id() == Type::EXTENSION;
  const TypeHolder& type =
      is_extension
          ? arrow::internal::checked_cast<const ExtensionType*>(column_type.type)
                ->storage_type()
          : column_type;

  if (is_extension) {
    *extension_type = arrow::internal::checked_pointer_cast<ExtensionType>(
        column_type.GetSharedPtr());
  }
  if (type.id() == Type::BOOL) {
    return std::make_shared<BooleanKeyEncoder>();
  }

  if (type.id() == Type::DICTIONARY) {
    return std::make_shared<DictionaryKeyEncoder>(type.GetSharedPtr(), pool);
  }

  if (is_fixed_width(type.id())) {
    return std::make_shared<FixedWidthKeyEncoder>(type.GetSharedPtr());
  }

  if (is_binary_like(type.id())) {
    return std::make_shared<VarLengthKeyEncoder<BinaryType>>(type.GetSharedPtr());
  }

  if (is_large_binary_like(type.id())) {
    return std::make_shared<VarLengthKeyEncoder<LargeBinaryType>>(type.GetSharedPtr());
  }

  if (is_list(type.id())) {
    // Cast to a pointer-to-const, matching the ExtensionType cast above; a
    // non-const target would have to drop the const qualifier of `type.type`.
    auto element_type =
        arrow::internal::checked_cast<const BaseListType*>(type.type)->value_type();
    // NOTE(review): an extension element type whose storage type is nested is
    // not caught by this check - confirm whether it needs to be rejected too.
    if (is_nested(element_type->id())) {
      return Status::NotImplemented("Unsupported nested type in List for row encoder: ",
                                    type.ToString());
    }
    std::shared_ptr<ExtensionType> element_extension_type;
    ARROW_ASSIGN_OR_RAISE(auto element_encoder,
                          MakeKeyEncoder(element_type, &element_extension_type, pool));
    auto list_encoder = std::make_shared<ListKeyEncoder>(std::move(element_type),
                                                         std::move(element_encoder));
    // Keep the element's extension type instead of dropping it on the floor;
    // it stays nullptr when the element is not an extension type.
    list_encoder->extension_type_ = std::move(element_extension_type);
    return std::shared_ptr<KeyEncoder>(std::move(list_encoder));
  }

  return Status::NotImplemented("Unsupported type for row encoder: ", type.ToString());
}

// extract the null bitmap from the leading nullity bytes of encoded keys
Status KeyEncoder::DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes,
std::shared_ptr<Buffer>* null_bitmap,
Expand Down Expand Up @@ -256,53 +301,32 @@ Result<std::shared_ptr<ArrayData>> DictionaryKeyEncoder::Decode(uint8_t** encode
return data;
}

void RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx) {
ctx_ = ctx;
encoders_.resize(column_types.size());
extension_types_.resize(column_types.size());
ListKeyEncoder::ListKeyEncoder(std::shared_ptr<DataType> element_type, std::shared_ptr<KeyEncoder> element_encoder)
: element_type_(std::move(element_type)), element_encoder_(std::move(element_encoder)) {}

for (size_t i = 0; i < column_types.size(); ++i) {
const bool is_extension = column_types[i].id() == Type::EXTENSION;
const TypeHolder& type =
is_extension
? arrow::internal::checked_cast<const ExtensionType*>(column_types[i].type)
->storage_type()
: column_types[i];

if (is_extension) {
extension_types_[i] = arrow::internal::checked_pointer_cast<ExtensionType>(
column_types[i].GetSharedPtr());
}
if (type.id() == Type::BOOL) {
encoders_[i] = std::make_shared<BooleanKeyEncoder>();
continue;
}
// Placeholder: the body is empty, so nothing is added to `lengths` for list values.
// TODO(review): presumably this should add the encoded size of every row (null byte,
// 4-byte element count and the encoded elements) - confirm against the intended layout.
void ListKeyEncoder::AddLength(const ExecValue& exec_value, int64_t batch_length, int32_t* lengths) {}

if (type.id() == Type::DICTIONARY) {
encoders_[i] =
std::make_shared<DictionaryKeyEncoder>(type.GetSharedPtr(), ctx->memory_pool());
continue;
}
// Placeholder: the body is empty, so `*length` is left unchanged for a null list.
// TODO(review): presumably a null list still occupies some bytes in the encoding -
// confirm and add them here.
void ListKeyEncoder::AddLengthNull(int32_t* length) {}

if (is_fixed_width(type.id())) {
encoders_[i] = std::make_shared<FixedWidthKeyEncoder>(type.GetSharedPtr());
continue;
}
// Encoding of list values is not implemented yet: none of the arguments are used
// and the call always returns Status::NotImplemented.
Status ListKeyEncoder::Encode(const ExecValue& data, int64_t batch_length,
uint8_t** encoded_bytes) {
return Status::NotImplemented("ListKeyEncoder::Encode");
}

if (is_binary_like(type.id())) {
encoders_[i] =
std::make_shared<VarLengthKeyEncoder<BinaryType>>(type.GetSharedPtr());
continue;
}
// Placeholder: the body is empty, so no bytes are written and `*encoded_bytes`
// is not advanced for a null list.
// TODO(review): presumably this must write the null marker and an empty element
// count - confirm against the intended layout.
void ListKeyEncoder::EncodeNull(uint8_t** encoded_bytes) {}

if (is_large_binary_like(type.id())) {
encoders_[i] =
std::make_shared<VarLengthKeyEncoder<LargeBinaryType>>(type.GetSharedPtr());
continue;
}
/// Decoding of list keys is not implemented yet.
///
/// Reports Status::NotImplemented, in line with ListKeyEncoder::Encode, rather
/// than a successful Result holding a null ArrayData pointer that a caller
/// could dereference.
///
/// \param encoded_bytes unused
/// \param length unused
/// \param pool unused
/// \return always Status::NotImplemented
Result<std::shared_ptr<ArrayData>> ListKeyEncoder::Decode(uint8_t** encoded_bytes,
                                                          int32_t length,
                                                          MemoryPool* pool) {
  return Status::NotImplemented("ListKeyEncoder::Decode");
}

Status RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx) {
ctx_ = ctx;
encoders_.resize(column_types.size());
extension_types_.resize(column_types.size());

// We should not get here
ARROW_DCHECK(false);
for (size_t i = 0; i < column_types.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(encoders_[i], MakeKeyEncoder(column_types[i], &extension_types_[i], ctx->memory_pool()));
}

int32_t total_length = 0;
Expand All @@ -314,6 +338,7 @@ void RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext*
for (size_t i = 0; i < column_types.size(); ++i) {
encoders_[i]->EncodeNull(&buf_ptr);
}
return Status::OK();
}

void RowEncoder::Clear() {
Expand Down
33 changes: 32 additions & 1 deletion cpp/src/arrow/compute/row/row_encoder_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,28 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder {
}
};

/// KeyEncoder for list-typed columns.
///
/// Holds the list's element type together with the KeyEncoder used for the
/// elements. The layout described for list types in this header is
/// [null byte, 4-byte element count, element 1, element 2, ...].
struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder should I put this into .cc since it requires a lot

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, please do. It would be nice to hide most contents from this file into the corresponding .cc

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment explaining how the encoding looks like?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see you added a comment below.

/// \param element_type type of the list's elements
/// \param element_encoder encoder used for the list's elements
explicit ListKeyEncoder(std::shared_ptr<DataType> element_type, std::shared_ptr<KeyEncoder> element_encoder);

// KeyEncoder override; presumably accumulates the encoded size of each row of
// the batch into `lengths` - confirm against the KeyEncoder contract.
void AddLength(const ExecValue&, int64_t batch_length, int32_t* lengths) override;

// KeyEncoder override; presumably adds the encoded size of a null value to
// `*length` - confirm against the KeyEncoder contract.
void AddLengthNull(int32_t* length) override;

// KeyEncoder override that encodes the values of `data` into `encoded_bytes`.
Status Encode(const ExecValue& data, int64_t batch_length,
uint8_t** encoded_bytes) override;

// KeyEncoder override that encodes a single null value into `encoded_bytes`.
void EncodeNull(uint8_t** encoded_bytes) override;

// KeyEncoder override that rebuilds array data from `length` encoded rows.
Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
MemoryPool* pool) override;

// Type of the list's elements, as passed to the constructor.
std::shared_ptr<DataType> element_type_;
// Encoder for the list's elements, as passed to the constructor.
std::shared_ptr<KeyEncoder> element_encoder_;
// extension_type_ is used to store the extension type of the list element.
// It would be nullptr if the list element is not an extension type.
// It is not a constructor argument, so it stays nullptr until assigned.
std::shared_ptr<ExtensionType> extension_type_;
};

/// RowEncoder encodes ExecSpan to a variable length byte sequence
/// created by concatenating the encoded form of each column. The encoding
/// for each column depends on its data type.
Expand Down Expand Up @@ -328,14 +350,23 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder {
/// Null string Would be encoded as:
/// 1 ( 1 byte for null) + 0 ( 4 bytes for length )
///
/// ## List Type
///
/// List Type is encoded as:
/// [null byte, list element count, [element 1, element 2, ...]]
/// Element count uses 4 bytes.
///
/// Currently, only non-nested element types (such as primitive, binary-like
/// and dictionary types) can be encoded inside a list; nested element types,
/// e.g. a list of lists, are not supported.
///
/// # Row Encoding
///
/// The row format is the concatenation of the encodings of each column.
class ARROW_EXPORT RowEncoder {
public:
static constexpr int kRowIdForNulls() { return -1; }

void Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx);
Status Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx);
void Clear();
Status EncodeAndAppend(const ExecSpan& batch);
Result<ExecBatch> Decode(int64_t num_rows, const int32_t* row_ids);
Expand Down
Loading