Skip to content

Commit 3566dfd

Browse files
committed
Cleanup for review
1 parent 34f546e commit 3566dfd

File tree

3 files changed

+104
-206
lines changed

3 files changed

+104
-206
lines changed

cpp/src/parquet/variant.cc

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ std::string VariantValue::typeDebugString() const {
326326

327327
bool VariantValue::getBool() const {
328328
if (getBasicType() != VariantBasicType::Primitive) {
329-
throw ParquetException("Not a primitive type");
329+
throw ParquetException("Expected primitive type, but got: " +
330+
variantBasicTypeToString(getBasicType()));
330331
}
331332

332333
int8_t primitive_type = static_cast<int8_t>(value[0]) >> 2;
@@ -341,11 +342,16 @@ bool VariantValue::getBool() const {
341342
std::to_string(primitive_type));
342343
}
343344

345+
void VariantValue::checkBasicType(VariantBasicType type) const {
346+
if (getBasicType() != type) {
347+
throw ParquetException("Expected basic type: " + variantBasicTypeToString(type) +
348+
", but got: " + variantBasicTypeToString(getBasicType()));
349+
}
350+
}
351+
344352
void VariantValue::checkPrimitiveType(VariantPrimitiveType type,
345353
size_t size_required) const {
346-
if (getBasicType() != VariantBasicType::Primitive) {
347-
throw ParquetException("Not a primitive type");
348-
}
354+
checkBasicType(VariantBasicType::Primitive);
349355

350356
auto primitive_type = static_cast<VariantPrimitiveType>(value[0] >> 2);
351357
if (primitive_type != type) {
@@ -354,17 +360,17 @@ void VariantValue::checkPrimitiveType(VariantPrimitiveType type,
354360
", but got: " + variantPrimitiveTypeToString(primitive_type));
355361
}
356362

357-
if (value.size() < 1 + size_required) {
363+
if (value.size() < size_required) {
358364
throw ParquetException("Invalid value: too short, expected at least " +
359-
std::to_string(1 + size_required) + " bytes for type " +
365+
std::to_string(size_required) + " bytes for type " +
360366
variantPrimitiveTypeToString(type) +
361367
", but got: " + std::to_string(value.size()) + " bytes");
362368
}
363369
}
364370

365371
template <typename PrimitiveType>
366372
PrimitiveType VariantValue::getPrimitiveType(VariantPrimitiveType type) const {
367-
checkPrimitiveType(type, sizeof(PrimitiveType));
373+
checkPrimitiveType(type, sizeof(PrimitiveType) + 1);
368374

369375
PrimitiveType primitive_value{};
370376
memcpy(&primitive_value, value.data() + 1, sizeof(PrimitiveType));
@@ -378,38 +384,27 @@ int8_t VariantValue::getInt8() const {
378384
}
379385

380386
int16_t VariantValue::getInt16() const {
381-
return getPrimitiveType<int8_t>(VariantPrimitiveType::Int16);
387+
return getPrimitiveType<int16_t>(VariantPrimitiveType::Int16);
382388
}
383389

384390
int32_t VariantValue::getInt32() const {
385-
return getPrimitiveType<int8_t>(VariantPrimitiveType::Int32);
391+
return getPrimitiveType<int32_t>(VariantPrimitiveType::Int32);
386392
}
387393

388394
int64_t VariantValue::getInt64() const {
389-
return getPrimitiveType<int8_t>(VariantPrimitiveType::Int64);
395+
return getPrimitiveType<int64_t>(VariantPrimitiveType::Int64);
390396
}
391397

392398
float VariantValue::getFloat() const {
393399
return getPrimitiveType<float>(VariantPrimitiveType::Float);
394400
}
395401

396402
double VariantValue::getDouble() const {
397-
return getPrimitiveType<float>(VariantPrimitiveType::Double);
403+
return getPrimitiveType<double>(VariantPrimitiveType::Double);
398404
}
399405

400406
std::string_view VariantValue::getPrimitiveBinaryType(VariantPrimitiveType type) const {
401-
VariantBasicType basic_type = getBasicType();
402-
if (basic_type != VariantBasicType::Primitive) {
403-
throw ParquetException("Not a primitive type");
404-
}
405-
auto primitive_type = static_cast<VariantPrimitiveType>(value[0] >> 2);
406-
if (primitive_type != VariantPrimitiveType::String) {
407-
throw ParquetException("Not a string type");
408-
}
409-
410-
if (value.size() < 5) {
411-
throw ParquetException("Invalid string value: too short");
412-
}
407+
checkPrimitiveType(type, /*size_required=*/5);
413408

414409
uint32_t length;
415410
memcpy(&length, value.data() + 1, sizeof(uint32_t));
@@ -468,7 +463,7 @@ DecimalValue<::arrow::Decimal64> VariantValue::getDecimal8() const {
468463

469464
DecimalValue<::arrow::Decimal128> VariantValue::getDecimal16() const {
470465
checkPrimitiveType(VariantPrimitiveType::Decimal16,
471-
/*size_required=*/sizeof(int64_t) * 2);
466+
/*size_required=*/sizeof(int64_t) * 2 + 2);
472467

473468
uint8_t scale = value[1];
474469

@@ -524,9 +519,7 @@ std::string VariantValue::ObjectInfo::toDebugString() const {
524519

525520

526521
VariantValue::ObjectInfo VariantValue::getObjectInfo() const {
527-
if (getBasicType() != VariantBasicType::Object) {
528-
throw ParquetException("Not an object type");
529-
}
522+
checkBasicType(VariantBasicType::Object);
530523
uint8_t value_header = value[0] >> 2;
531524
uint8_t field_offset_size = (value_header & 0b11) + 1;
532525
uint8_t field_id_size = ((value_header >> 2) & 0b11) + 1;
@@ -561,6 +554,7 @@ VariantValue::ObjectInfo VariantValue::getObjectInfo() const {
561554
memcpy(&final_offset,
562555
value.data() + info.offset_start_offset + num_elements * field_offset_size,
563556
field_offset_size);
557+
// It could be less than value size since it could be a sub-object.
564558
if (final_offset + info.data_start_offset > value.size()) {
565559
throw ParquetException("Invalid object value: final_offset=" +
566560
std::to_string(final_offset) +
@@ -591,12 +585,13 @@ std::optional<VariantValue> VariantValue::getObjectValueByKey(
591585
return std::nullopt;
592586
}
593587

594-
std::optional<VariantValue> VariantValue::getObjectFieldByFieldId(
595-
uint32_t variantId, std::string_view* key) const {
588+
VariantValue VariantValue::getObjectFieldByFieldId(uint32_t variantId,
589+
std::string_view* key) const {
596590
ObjectInfo info = getObjectInfo();
597591

598592
if (variantId >= info.num_elements) {
599-
throw ParquetException("Field ID out of range");
593+
throw ParquetException("Field ID out of range: " + std::to_string(variantId) +
594+
" >= " + std::to_string(info.num_elements));
600595
}
601596

602597
// Read the field ID
@@ -606,7 +601,7 @@ std::optional<VariantValue> VariantValue::getObjectFieldByFieldId(
606601
field_id = arrow::bit_util::FromLittleEndian(field_id);
607602

608603
// Get the key from metadata
609-
*key = metadata.getMetadataKey(field_id);
604+
*key = metadata.getMetadataKey(static_cast<int32_t>(field_id));
610605

611606
// Read the offset and next offset
612607
uint32_t offset = 0, next_offset = 0;
@@ -633,10 +628,7 @@ std::optional<VariantValue> VariantValue::getObjectFieldByFieldId(
633628
}
634629

635630
VariantValue::ArrayInfo VariantValue::getArrayInfo() const {
636-
if (getBasicType() != VariantBasicType::Array) {
637-
throw ParquetException("Expected array type, but got: " +
638-
variantBasicTypeToString(getBasicType()));
639-
}
631+
checkBasicType(VariantBasicType::Array);
640632
uint8_t value_header = value[0] >> 2;
641633
uint8_t field_offset_size = (value_header & 0b11) + 1;
642634
bool is_large = ((value_header >> 2) & 0b1);
@@ -649,7 +641,7 @@ VariantValue::ArrayInfo VariantValue::getArrayInfo() const {
649641
" for at least " + std::to_string(1 + num_elements_size));
650642
}
651643

652-
// 解析 num_elements
644+
// parse num_elements
653645
uint32_t num_elements = 0;
654646
{
655647
memcpy(&num_elements, value.data() + 1, num_elements_size);
@@ -663,14 +655,15 @@ VariantValue::ArrayInfo VariantValue::getArrayInfo() const {
663655
info.data_start_offset =
664656
info.offset_start_offset + (num_elements + 1) * field_offset_size;
665657

666-
// 检查边界
658+
// Boundary check
667659
if (info.data_start_offset > value.size()) {
668660
throw ParquetException("Invalid array value: data_start_offset=" +
669661
std::to_string(info.data_start_offset) +
670662
", value_size=" + std::to_string(value.size()));
671663
}
672664

673-
// 检查最终偏移量
665+
// Validate final offset is equal to the size of the value,
666+
// it would work since even empty array would have an offset of 0.
674667
{
675668
uint32_t final_offset = 0;
676669
memcpy(&final_offset,

cpp/src/parquet/variant.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,7 @@ struct VariantValue {
189189
};
190190
ObjectInfo getObjectInfo() const;
191191
std::optional<VariantValue> getObjectValueByKey(std::string_view key) const;
192-
std::optional<VariantValue> getObjectFieldByFieldId(uint32_t variantId,
193-
std::string_view* key) const;
192+
VariantValue getObjectFieldByFieldId(uint32_t variantId, std::string_view* key) const;
194193

195194
struct ArrayInfo {
196195
uint32_t num_elements;
@@ -202,6 +201,7 @@ struct VariantValue {
202201
// Would throw ParquetException if index is out of range.
203202
VariantValue getArrayValueByIndex(uint32_t index) const;
204203

204+
private:
205205
static constexpr uint8_t BASIC_TYPE_MASK = 0b00000011;
206206
static constexpr uint8_t PRIMITIVE_TYPE_MASK = 0b00111111;
207207
/** The inclusive maximum value of the type info value. It is the size limit of
@@ -216,6 +216,7 @@ struct VariantValue {
216216
DecimalValue<DecimalType> getPrimitiveDecimalType(VariantPrimitiveType type) const;
217217

218218
std::string_view getPrimitiveBinaryType(VariantPrimitiveType type) const;
219+
void checkBasicType(VariantBasicType type) const;
219220
void checkPrimitiveType(VariantPrimitiveType type, size_t size_required) const;
220221
};
221222

0 commit comments

Comments
 (0)