Skip to content

Commit

Permalink
feat: Support decoding Float16 in Parquet
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Oct 17, 2024
1 parent ed7e9d4 commit 91e97cf
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 1 deletion.
3 changes: 3 additions & 0 deletions crates/polars-io/src/parquet/read/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ fn assert_dtypes(dtype: &ArrowDataType) {
// These should all be casted to the BinaryView / Utf8View variants
D::Utf8 | D::Binary | D::LargeUtf8 | D::LargeBinary => unreachable!(),

// These should be casted to to Float32
D::Float16 => unreachable!(),

// This should have been converted to a LargeList
D::List(_) => unreachable!(),

Expand Down
30 changes: 29 additions & 1 deletion crates/polars-parquet/src/arrow/read/deserialize/simple.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use arrow::array::{Array, DictionaryArray, DictionaryKey, FixedSizeBinaryArray, PrimitiveArray};
use arrow::datatypes::{ArrowDataType, IntervalUnit, TimeUnit};
use arrow::match_integer_type;
use arrow::types::{days_ms, i256};
use arrow::types::{days_ms, i256, NativeType};
use ethnum::I256;
use polars_error::{polars_bail, PolarsResult};

Expand Down Expand Up @@ -275,6 +275,34 @@ pub fn page_iter_to_array(
primitive::IntDecoder::<i64, u64, _>::cast_as(),
)?
.collect_n(filter)?),

// Float16
(PhysicalType::FixedLenByteArray(2), Float32) => {
// @NOTE: To reduce code bloat, we just use the FixedSizeBinary decoder.

let mut fsb_array = PageDecoder::new(
pages,
ArrowDataType::FixedSizeBinary(2),
fixed_size_binary::BinaryDecoder { size: 2 },
)?.collect_n(filter)?;


let validity = fsb_array.take_validity();
let values = fsb_array.values().as_slice();
assert_eq!(values.len() % 2, 0);
let values = values.chunks_exact(2);
let values = values.map(|v| {
// SAFETY: We know that `v` is always of size two.
let le_bytes: [u8; 2] = unsafe { v.try_into().unwrap_unchecked() };
let v = arrow::types::f16::from_le_bytes(le_bytes);
v.to_f32()
}).collect();

let array = PrimitiveArray::<f32>::new(dtype, values, validity);

Box::new(array)
},

(PhysicalType::Float, Float32) => Box::new(PageDecoder::new(
pages,
dtype,
Expand Down
1 change: 1 addition & 0 deletions crates/polars-parquet/src/arrow/read/schema/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ fn convert_dtype(mut dtype: ArrowDataType) -> ArrowDataType {
convert_field(field);
}
},
Float16 => dtype = Float32,
Binary | LargeBinary => dtype = BinaryView,
Utf8 | LargeUtf8 => dtype = Utf8View,
Dictionary(_, ref mut dtype, _) | Extension(_, ref mut dtype, _) => {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-parquet/src/parquet/metadata/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ fn get_logical_sort_order(logical_type: &PrimitiveLogicalType) -> SortOrder {
Timestamp { .. } => SortOrder::Signed,
Unknown => SortOrder::Undefined,
Uuid => SortOrder::Unsigned,
Float16 => SortOrder::Unsigned,
}
}

Expand Down
3 changes: 3 additions & 0 deletions crates/polars-parquet/src/parquet/parquet_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ pub enum PrimitiveLogicalType {
Json,
Bson,
Uuid,
Float16,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -575,6 +576,7 @@ impl TryFrom<ParquetLogicalType> for PrimitiveLogicalType {
ParquetLogicalType::JSON(_) => PrimitiveLogicalType::Json,
ParquetLogicalType::BSON(_) => PrimitiveLogicalType::Bson,
ParquetLogicalType::UUID(_) => PrimitiveLogicalType::Uuid,
ParquetLogicalType::FLOAT16(_) => PrimitiveLogicalType::Float16,
_ => return Err(ParquetError::oos("LogicalType value out of range")),
})
}
Expand Down Expand Up @@ -629,6 +631,7 @@ impl From<PrimitiveLogicalType> for ParquetLogicalType {
PrimitiveLogicalType::Json => ParquetLogicalType::JSON(Default::default()),
PrimitiveLogicalType::Bson => ParquetLogicalType::BSON(Default::default()),
PrimitiveLogicalType::Uuid => ParquetLogicalType::UUID(Default::default()),
PrimitiveLogicalType::Float16 => ParquetLogicalType::FLOAT16(Default::default()),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/polars-parquet/src/parquet/schema/types/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ pub fn check_logical_invariants(
(String | Json | Bson, PhysicalType::ByteArray) => {},
// https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#uuid
(Uuid, PhysicalType::FixedLenByteArray(16)) => {},
(Float16, PhysicalType::FixedLenByteArray(2)) => {},
(a, b) => {
return Err(ParquetError::oos(format!(
"Cannot annotate {:?} from {:?} fields",
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2039,3 +2039,21 @@ def test_conserve_sortedness(
"Parquet conserved SortingColumn for column chunk of 'c' to Descending"
in captured
)


def test_decode_f16() -> None:
values = [float("nan"), 0.0, 0.5, 1.0, 1.5]

table = pa.Table.from_pydict(
{
"x": pa.array(np.array(values, dtype=np.float16), type=pa.float16()),
}
)

f = io.BytesIO()
pq.write_table(table, f)

f.seek(0)
df = pl.read_parquet(f)

assert_series_equal(df.get_column("x"), pl.Series("x", values, pl.Float32))

0 comments on commit 91e97cf

Please sign in to comment.