Skip to content
Merged
135 changes: 64 additions & 71 deletions datafusion-examples/examples/extension_types/temperature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow::array::{
};
use arrow::datatypes::{Float32Type, Float64Type};
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
use arrow_schema::extension::ExtensionType;
use arrow_schema::extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
Expand All @@ -30,8 +30,9 @@ use datafusion::prelude::SessionContext;
use datafusion_common::internal_err;
use datafusion_common::types::DFExtensionType;
use datafusion_expr::registry::{
DefaultExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry,
ExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry,
};
use std::collections::HashMap;
use std::fmt::{Display, Write};
use std::sync::Arc;

Expand All @@ -50,13 +51,15 @@ fn create_session_context() -> Result<SessionContext> {
let registry = MemoryExtensionTypeRegistry::new_empty();

// The registration creates a new instance of the extension type with the deserialized metadata.
let temp_registration =
DefaultExtensionTypeRegistration::new_arc(|storage_type, metadata| {
Ok(TemperatureExtensionType::new(
storage_type.clone(),
metadata,
))
});
let temp_registration = ExtensionTypeRegistration::new_arc(
TemperatureExtensionType::NAME,
|storage_type, metadata| {
Ok(Arc::new(TemperatureExtensionType::try_new(
storage_type,
TemperatureUnit::deserialize(metadata)?,
)?))
},
);
registry.add_extension_type_registration(temp_registration)?;

let state = SessionStateBuilder::default()
Expand Down Expand Up @@ -96,26 +99,15 @@ async fn register_temperature_table(ctx: &SessionContext) -> Result<DataFrame> {
fn example_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("city", DataType::Utf8, false),
Field::new("celsius", DataType::Float64, false).with_extension_type(
TemperatureExtensionType::new(DataType::Float64, TemperatureUnit::Celsius),
),
Field::new("fahrenheit", DataType::Float64, false).with_extension_type(
TemperatureExtensionType::new(DataType::Float64, TemperatureUnit::Fahrenheit),
),
Field::new("kelvin", DataType::Float32, false).with_extension_type(
TemperatureExtensionType::new(DataType::Float32, TemperatureUnit::Kelvin),
),
Field::new("celsius", DataType::Float64, false)
.with_metadata(create_metadata(TemperatureUnit::Celsius)),
Field::new("fahrenheit", DataType::Float64, false)
.with_metadata(create_metadata(TemperatureUnit::Fahrenheit)),
Field::new("kelvin", DataType::Float32, false)
.with_metadata(create_metadata(TemperatureUnit::Kelvin)),
]))
}

/// Represents the unit of a temperature reading.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemperatureUnit {
Celsius,
Fahrenheit,
Kelvin,
}

/// Represents a float that semantically represents a temperature. The temperature can be one of
/// the supported [`TemperatureUnit`]s.
///
Expand Down Expand Up @@ -143,46 +135,57 @@ pub struct TemperatureExtensionType {
}

impl TemperatureExtensionType {
/// The name of the extension type.
pub const NAME: &'static str = "custom.temperature";

/// Creates a new [`TemperatureExtensionType`].
pub fn new(storage_type: DataType, temperature_unit: TemperatureUnit) -> Self {
Self {
storage_type,
temperature_unit,
pub fn try_new(
storage_type: &DataType,
temperature_unit: TemperatureUnit,
) -> Result<Self, ArrowError> {
match storage_type {
DataType::Float32 | DataType::Float64 => {}
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid data type: {storage_type} for temperature type, expected Float32 or Float64",
)));
Comment thread
tobixdev marked this conversation as resolved.
}
}

let result = Self {
storage_type: storage_type.clone(),
temperature_unit,
};
Ok(result)
}
}

/// Implementation of [`ExtensionType`] for [`TemperatureExtensionType`].
///
/// This implements the arrow-rs trait for reading, writing, and validating extension types.
impl ExtensionType for TemperatureExtensionType {
/// Arrow extension type name that is stored in the `ARROW:extension:name` field.
const NAME: &'static str = "custom.temperature";
type Metadata = TemperatureUnit;

fn metadata(&self) -> &Self::Metadata {
&self.temperature_unit
}
/// Represents the unit of a temperature reading.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemperatureUnit {
Celsius,
Fahrenheit,
Kelvin,
}

impl TemperatureUnit {
/// Arrow extension type metadata is encoded as a string and stored using the
/// `ARROW:extension:metadata` key. As we only store the name of the unit, a simple string
/// suffices. Extension types can store more complex metadata using serialization formats like
/// JSON.
fn serialize_metadata(&self) -> Option<String> {
let s = match self.temperature_unit {
pub fn serialize(self) -> String {
let result = match self {
TemperatureUnit::Celsius => "celsius",
TemperatureUnit::Fahrenheit => "fahrenheit",
TemperatureUnit::Kelvin => "kelvin",
};
Some(s.to_string())
result.to_owned()
}

/// Inverse operation of [`Self::serialize_metadata`]. This creates the [`TemperatureUnit`]
/// Inverse operation of [`TemperatureUnit::serialize`]. This creates the [`TemperatureUnit`]
/// value from the serialized string.
fn deserialize_metadata(
metadata: Option<&str>,
) -> std::result::Result<Self::Metadata, ArrowError> {
match metadata {
pub fn deserialize(value: Option<&str>) -> std::result::Result<Self, ArrowError> {
match value {
Some("celsius") => Ok(TemperatureUnit::Celsius),
Some("fahrenheit") => Ok(TemperatureUnit::Fahrenheit),
Some("kelvin") => Ok(TemperatureUnit::Kelvin),
Expand All @@ -194,28 +197,18 @@ impl ExtensionType for TemperatureExtensionType {
)),
}
}
}

/// Checks that the extension type supports a given [`DataType`].
fn supports_data_type(
&self,
data_type: &DataType,
) -> std::result::Result<(), ArrowError> {
match data_type {
DataType::Float32 | DataType::Float64 => Ok(()),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Invalid data type: {data_type} for temperature type, expected Float32 or Float64",
))),
}
}

fn try_new(
data_type: &DataType,
metadata: Self::Metadata,
) -> std::result::Result<Self, ArrowError> {
let instance = Self::new(data_type.clone(), metadata);
instance.supports_data_type(data_type)?;
Ok(instance)
}
/// This creates a metadata map for the temperature type. Another way of writing the metadata can be
/// implemented using arrow-rs' [`ExtensionType`](arrow_schema::extension::ExtensionType) trait.
fn create_metadata(unit: TemperatureUnit) -> HashMap<String, String> {
HashMap::from([
(
EXTENSION_TYPE_NAME_KEY.to_owned(),
TemperatureExtensionType::NAME.to_owned(),
),
(EXTENSION_TYPE_METADATA_KEY.to_owned(), unit.serialize()),
])
}

/// Implementation of [`DFExtensionType`] for [`TemperatureExtensionType`].
Expand All @@ -227,7 +220,7 @@ impl DFExtensionType for TemperatureExtensionType {
}

fn serialize_metadata(&self) -> Option<String> {
ExtensionType::serialize_metadata(self)
Some(self.temperature_unit.serialize())
}

fn create_array_formatter<'fmt>(
Expand Down
123 changes: 123 additions & 0 deletions datafusion/common/src/types/canonical_extensions/bool8.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::Result;
use crate::error::_internal_err;
use crate::types::extension::DFExtensionType;
use arrow::array::{Array, Int8Array};
use arrow::datatypes::DataType;
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
use arrow_schema::extension::{Bool8, ExtensionType};
use std::fmt::Write;

/// Defines the extension type logic for the canonical `arrow.bool8` extension type. This extension
/// type allows storing a Boolean value in a single byte, instead of a single bit.
///
/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also
/// [`Bool8`] for the implementation of arrow-rs, which this type uses internally.
///
/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#bit-boolean>
#[derive(Debug, Clone)]
pub struct DFBool8(Bool8);

impl DFBool8 {
/// Creates a new [`DFBool8`], validating that the storage type is compatible with the
/// extension type.
///
/// Even though [`DFBool8`] only supports a single storage type ([`DataType::Int8`]), passing-in
/// the storage type allows conveniently validating whether this extension type is compatible
/// with a given [`DataType`].
pub fn try_new(
data_type: &DataType,
metadata: <Bool8 as ExtensionType>::Metadata,
) -> Result<Self> {
// Validates the storage type
Ok(Self(<Bool8 as ExtensionType>::try_new(
data_type, metadata,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the supported data types be checked here as well? Maybe add to the docstring what data_type represents (it seems DFExtensionType::storage_type is hardcoded below so it's not immediately obvious why this is passed into the constructor).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point that this is not obvious. The validation happens in <Bool8 as ExtensionType>::try_new.

)?))
}
}

impl DFExtensionType for DFBool8 {
fn storage_type(&self) -> DataType {
DataType::Int8
}

fn serialize_metadata(&self) -> Option<String> {
self.0.serialize_metadata()
}

fn create_array_formatter<'fmt>(
&self,
array: &'fmt dyn Array,
options: &FormatOptions<'fmt>,
) -> Result<Option<ArrayFormatter<'fmt>>> {
if array.data_type() != &DataType::Int8 {
return _internal_err!("Wrong array type for Bool8");
}

let display_index = Bool8ValueDisplayIndex {
array: array.as_any().downcast_ref().unwrap(),
null_str: options.null(),
};
Ok(Some(ArrayFormatter::new(
Box::new(display_index),
options.safe(),
)))
}
}

/// Pretty printer for binary bool8 values.
#[derive(Debug, Clone, Copy)]
struct Bool8ValueDisplayIndex<'a> {
array: &'a Int8Array,
null_str: &'a str,
}

impl DisplayIndex for Bool8ValueDisplayIndex<'_> {
fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult {
if self.array.is_null(idx) {
write!(f, "{}", self.null_str)?;
return Ok(());
}

let bytes = self.array.value(idx);
write!(f, "{}", bytes != 0)?;
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
pub fn test_pretty_bool8() {
let values = Int8Array::from_iter([Some(0), Some(1), Some(-20), None]);

let extension_type = DFBool8(Bool8 {});
let formatter = extension_type
.create_array_formatter(&values, &FormatOptions::default().with_null("NULL"))
.unwrap()
.unwrap();

assert_eq!(formatter.value(0).to_string(), "false");
assert_eq!(formatter.value(1).to_string(), "true");
assert_eq!(formatter.value(2).to_string(), "true");
assert_eq!(formatter.value(3).to_string(), "NULL");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::Result;
use crate::types::extension::DFExtensionType;
use arrow::datatypes::DataType;
use arrow_schema::extension::{ExtensionType, FixedShapeTensor};

/// Defines the extension type logic for the canonical `arrow.fixed_shape_tensor` extension type.
Comment thread
tobixdev marked this conversation as resolved.
/// This extension type can be used to store a [tensor](https://en.wikipedia.org/wiki/Tensor) of
/// a fixed shape.
///
/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also
/// [`FixedShapeTensor`] for the implementation of arrow-rs, which this type uses internally.
///
/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
#[derive(Debug, Clone)]
pub struct DFFixedShapeTensor {
inner: FixedShapeTensor,
/// The storage type of the tensor.
///
/// While we could reconstruct the storage type from the inner [`FixedShapeTensor`], we may
/// choose a different name for the field within the [`DataType::FixedSizeList`] which can
/// cause problems down the line (e.g., checking for equality).
storage_type: DataType,
}

impl DFFixedShapeTensor {
/// Creates a new [`DFFixedShapeTensor`], validating that the storage type is compatible with
/// the extension type.
pub fn try_new(
data_type: &DataType,
metadata: <FixedShapeTensor as ExtensionType>::Metadata,
) -> Result<Self> {
Ok(Self {
inner: <FixedShapeTensor as ExtensionType>::try_new(data_type, metadata)?,
storage_type: data_type.clone(),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will say this is a bit unwieldy. I don't have a better proposal but wanted to note it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Users could implement their DFExtensionType without using the arrow-rs trait (like we did in our example) or provide a common function that can be used on both trait implementations. The problem in this case is that <FixedShapeTensor as ExtensionType>::try_new does some processing that I do not want to duplicate.

So I think users can work around that issue.

})
}
}

impl DFExtensionType for DFFixedShapeTensor {
fn storage_type(&self) -> DataType {
self.storage_type.clone()
}

fn serialize_metadata(&self) -> Option<String> {
self.inner.serialize_metadata()
}
Comment on lines +57 to +63
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there cases where the metadata depends on the storage type?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess one could think of such a type. E.g., a type that stores the serialized DataType in the metadata and then is only compatible with this exact DataType.

I am not sure how useful that is. Any concerns regarding that?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it does, self.storage_type() is available here (but I'm not aware of any extension types that duplicate information in the metadata and that would otherwise be inferrable from the storage type)

}
Loading
Loading