Skip to content

Commit aa4f3c7

Browse files
authored
Extension DType vtables (#6081)
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent ebd2a56 commit aa4f3c7

File tree

117 files changed

+3285
-2794
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

117 files changed

+3285
-2794
lines changed

Cargo.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/datetime-parts/src/canonical.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use vortex_array::validity::Validity;
1111
use vortex_buffer::BufferMut;
1212
use vortex_dtype::DType;
1313
use vortex_dtype::PType;
14-
use vortex_dtype::datetime::TemporalMetadata;
1514
use vortex_dtype::datetime::TimeUnit;
15+
use vortex_dtype::datetime::Timestamp;
1616
use vortex_dtype::match_each_integer_ptype;
1717
use vortex_error::VortexExpect as _;
1818
use vortex_error::VortexResult;
@@ -31,11 +31,11 @@ pub fn decode_to_temporal(
3131
vortex_panic!(ComputeError: "expected dtype to be DType::Extension variant")
3232
};
3333

34-
let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else {
34+
let Some(options) = ext.metadata_opt::<Timestamp>() else {
3535
vortex_panic!(ComputeError: "must decode TemporalMetadata from extension metadata");
3636
};
3737

38-
let divisor = match temporal_metadata.time_unit() {
38+
let divisor = match options.unit {
3939
TimeUnit::Nanoseconds => 1_000_000_000,
4040
TimeUnit::Microseconds => 1_000_000,
4141
TimeUnit::Milliseconds => 1_000,
@@ -98,8 +98,8 @@ pub fn decode_to_temporal(
9898
Ok(TemporalArray::new_timestamp(
9999
PrimitiveArray::new(values.freeze(), Validity::copy_from_array(array.as_ref())?)
100100
.into_array(),
101-
temporal_metadata.time_unit(),
102-
temporal_metadata.time_zone().map(ToString::to_string),
101+
options.unit,
102+
options.tz.clone(),
103103
))
104104
}
105105

@@ -142,7 +142,7 @@ mod test {
142142
let date_times = DateTimePartsArray::try_from(TemporalArray::new_timestamp(
143143
milliseconds.clone().into_array(),
144144
TimeUnit::Milliseconds,
145-
Some("UTC".to_string()),
145+
Some("UTC".into()),
146146
))
147147
.unwrap();
148148

encodings/datetime-parts/src/compress.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,8 @@ mod tests {
100100
validity.clone(),
101101
)
102102
.into_array();
103-
let temporal_array = TemporalArray::new_timestamp(
104-
milliseconds,
105-
TimeUnit::Milliseconds,
106-
Some("UTC".to_string()),
107-
);
103+
let temporal_array =
104+
TemporalArray::new_timestamp(milliseconds, TimeUnit::Milliseconds, Some("UTC".into()));
108105
let TemporalParts {
109106
days,
110107
seconds,

encodings/datetime-parts/src/compute/cast.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ mod tests {
6666
)
6767
.into_array(),
6868
TimeUnit::Milliseconds,
69-
Some("UTC".to_string()),
69+
Some("UTC".into()),
7070
))
7171
.unwrap()
7272
.into_array()
@@ -99,7 +99,7 @@ mod tests {
9999
let result = cast(&array, &DType::Bool(Nullability::NonNullable));
100100
assert!(
101101
result.as_ref().is_err_and(|err| err.to_string().contains(
102-
"No compute kernel to cast array vortex.ext with dtype ext(vortex.timestamp, i64, ExtMetadata([2, 3, 0, 85, 84, 67]))? to bool"
102+
"No compute kernel to cast array vortex.ext with dtype vortex.timestamp[unit=ms, tz=UTC](i64?) to bool"
103103
)),
104104
"Got error: {result:?}"
105105
);
@@ -126,7 +126,7 @@ mod tests {
126126
345_600_000, // 4 days in ms
127127
].into_array(),
128128
TimeUnit::Milliseconds,
129-
Some("UTC".to_string())
129+
Some("UTC".into())
130130
)).unwrap())]
131131
#[case(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
132132
PrimitiveArray::from_option_iter([
@@ -137,12 +137,12 @@ mod tests {
137137
None,
138138
]).into_array(),
139139
TimeUnit::Milliseconds,
140-
Some("UTC".to_string())
140+
Some("UTC".into())
141141
)).unwrap())]
142142
#[case(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
143143
buffer![86_400_000_000_000i64].into_array(), // 1 day in ns
144144
TimeUnit::Nanoseconds,
145-
Some("UTC".to_string())
145+
Some("UTC".into())
146146
)).unwrap())]
147147
fn test_cast_datetime_parts_conformance(#[case] array: DateTimePartsArray) {
148148
use vortex_array::compute::conformance::cast::test_cast_conformance;

encodings/datetime-parts/src/compute/compare.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use vortex_array::compute::or;
1515
use vortex_array::register_kernel;
1616
use vortex_dtype::DType;
1717
use vortex_dtype::Nullability;
18-
use vortex_dtype::datetime::TemporalMetadata;
18+
use vortex_dtype::datetime::Timestamp;
1919
use vortex_error::VortexResult;
2020
use vortex_scalar::Scalar;
2121

@@ -50,8 +50,10 @@ impl CompareKernel for DateTimePartsVTable {
5050

5151
let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
5252

53-
let temporal_metadata = TemporalMetadata::try_from(ext_dtype.as_ref())?;
54-
let ts_parts = timestamp::split(timestamp, temporal_metadata.time_unit())?;
53+
let Some(options) = ext_dtype.metadata_opt::<Timestamp>() else {
54+
return Ok(None);
55+
};
56+
let ts_parts = timestamp::split(timestamp, options.unit)?;
5557

5658
match operator {
5759
Operator::Eq => compare_eq(lhs, &ts_parts, nullability),
@@ -218,7 +220,7 @@ mod test {
218220
DateTimePartsArray::try_from(TemporalArray::new_timestamp(
219221
PrimitiveArray::new(buffer![value], validity).into_array(),
220222
TimeUnit::Seconds,
221-
Some("UTC".to_string()),
223+
Some("UTC".into()),
222224
))
223225
.expect("Failed to construct DateTimePartsArray from TemporalArray")
224226
}
@@ -293,7 +295,7 @@ mod test {
293295
let temporal_array = TemporalArray::new_timestamp(
294296
PrimitiveArray::new(buffer![0i64], lhs_validity.clone()).into_array(),
295297
TimeUnit::Seconds,
296-
Some("UTC".to_string()),
298+
Some("UTC".into()),
297299
);
298300

299301
let lhs = DateTimePartsArray::try_new(

encodings/datetime-parts/src/compute/filter.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,8 @@ mod test {
4848
]
4949
.into_array();
5050

51-
let temporal = TemporalArray::new_timestamp(
52-
timestamps,
53-
TimeUnit::Milliseconds,
54-
Some("UTC".to_string()),
55-
);
51+
let temporal =
52+
TemporalArray::new_timestamp(timestamps, TimeUnit::Milliseconds, Some("UTC".into()));
5653

5754
let array = DateTimePartsArray::try_from(temporal).unwrap();
5855
test_filter_conformance(array.as_ref());
@@ -67,11 +64,8 @@ mod test {
6764
])
6865
.into_array();
6966

70-
let temporal = TemporalArray::new_timestamp(
71-
timestamps,
72-
TimeUnit::Milliseconds,
73-
Some("UTC".to_string()),
74-
);
67+
let temporal =
68+
TemporalArray::new_timestamp(timestamps, TimeUnit::Milliseconds, Some("UTC".into()));
7569

7670
let array = DateTimePartsArray::try_from(temporal).unwrap();
7771
test_filter_conformance(array.as_ref());

encodings/datetime-parts/src/compute/mask.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex_array::Array;
45
use vortex_array::ArrayRef;
56
use vortex_array::compute::MaskKernel;
67
use vortex_array::compute::MaskKernelAdapter;
@@ -23,17 +24,14 @@ impl MaskKernel for DateTimePartsVTable {
2324
// through the days component.
2425

2526
let masked_days = mask(array.days(), mask_array)?;
27+
assert!(masked_days.dtype().is_nullable());
2628

2729
// Keep seconds and subseconds unchanged since they must remain non-nullable
2830
let seconds = array.seconds().clone();
2931
let subseconds = array.subseconds().clone();
3032

3133
// Update the dtype to reflect the new nullability of days
32-
let new_dtype = if masked_days.dtype().is_nullable() {
33-
array.dtype().as_nullable()
34-
} else {
35-
array.dtype().clone()
36-
};
34+
let new_dtype = array.dtype().as_nullable();
3735

3836
DateTimePartsArray::try_new(new_dtype, masked_days, seconds, subseconds)
3937
.map(|a| a.to_array())

encodings/datetime-parts/src/compute/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,46 +26,46 @@ mod tests {
2626
#[case::datetime_seconds(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
2727
buffer![0i64, 86400, 172800, 259200, 345600].into_array(),
2828
TimeUnit::Seconds,
29-
Some("UTC".to_string()),
29+
Some("UTC".into()),
3030
)).unwrap())]
3131
#[case::datetime_millis(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
3232
buffer![0i64, 86400000, 172800000].into_array(),
3333
TimeUnit::Milliseconds,
34-
Some("UTC".to_string()),
34+
Some("UTC".into()),
3535
)).unwrap())]
3636
#[case::datetime_micros(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
3737
buffer![0i64, 86400000000, 172800000000].into_array(),
3838
TimeUnit::Microseconds,
39-
Some("UTC".to_string()),
39+
Some("UTC".into()),
4040
)).unwrap())]
4141
#[case::datetime_nanos(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
4242
buffer![0i64, 86400000000000].into_array(),
4343
TimeUnit::Nanoseconds,
44-
Some("UTC".to_string()),
44+
Some("UTC".into()),
4545
)).unwrap())]
4646
// Nullable arrays
4747
#[case::datetime_nullable_seconds(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
4848
PrimitiveArray::from_option_iter([Some(0i64), None, Some(86400), Some(172800), None]).into_array(),
4949
TimeUnit::Seconds,
50-
Some("UTC".to_string()),
50+
Some("UTC".into()),
5151
)).unwrap())]
5252
// Edge cases
5353
#[case::datetime_single(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
5454
buffer![1234567890i64].into_array(),
5555
TimeUnit::Seconds,
56-
Some("UTC".to_string()),
56+
Some("UTC".into()),
5757
)).unwrap())]
5858
// Large arrays (> 1024 elements)
5959
#[case::datetime_large(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
6060
PrimitiveArray::from_iter((0..1500).map(|i| i as i64 * 86400)).into_array(),
6161
TimeUnit::Seconds,
62-
Some("UTC".to_string()),
62+
Some("UTC".into()),
6363
)).unwrap())]
6464
// Different time patterns
6565
#[case::datetime_with_subseconds(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
6666
buffer![123456789i64, 234567890, 345678901, 456789012, 567890123].into_array(),
6767
TimeUnit::Milliseconds,
68-
Some("UTC".to_string()),
68+
Some("UTC".into()),
6969
)).unwrap())]
7070

7171
fn test_datetime_parts_consistency(#[case] array: DateTimePartsArray) {

encodings/datetime-parts/src/compute/rules.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use vortex_array::optimizer::ArrayOptimizer;
1818
use vortex_array::optimizer::rules::ArrayParentReduceRule;
1919
use vortex_array::optimizer::rules::ParentRuleSet;
2020
use vortex_dtype::DType;
21-
use vortex_dtype::datetime::TemporalMetadata;
21+
use vortex_dtype::datetime::Timestamp;
2222
use vortex_error::VortexExpect;
2323
use vortex_error::VortexResult;
2424

@@ -161,8 +161,8 @@ fn try_extract_days_constant(array: &ArrayRef) -> Option<i64> {
161161
return None;
162162
};
163163

164-
let temporal_metadata = TemporalMetadata::try_from(ext_dtype.as_ref()).ok()?;
165-
let ts_parts = timestamp::split(timestamp, temporal_metadata.time_unit()).ok()?;
164+
let options = ext_dtype.metadata::<Timestamp>();
165+
let ts_parts = timestamp::split(timestamp, options.unit).ok()?;
166166

167167
// Only allow pushdown if seconds and subseconds are zero
168168
if ts_parts.seconds != 0 || ts_parts.subseconds != 0 {
@@ -190,8 +190,8 @@ mod tests {
190190
use vortex_array::optimizer::ArrayOptimizer;
191191
use vortex_array::validity::Validity;
192192
use vortex_buffer::Buffer;
193-
use vortex_buffer::buffer;
194193
use vortex_dtype::datetime::TimeUnit;
194+
use vortex_dtype::datetime::TimestampOptions;
195195
use vortex_scalar::Scalar;
196196

197197
use super::*;
@@ -230,12 +230,13 @@ mod tests {
230230
TimeUnit::Days => panic!("Days not supported"),
231231
};
232232
let timestamp = day * SECONDS_PER_DAY * multiplier;
233-
let temporal = TemporalArray::new_timestamp(
234-
PrimitiveArray::new(buffer![timestamp], Validity::NonNullable).into_array(),
235-
time_unit,
236-
None,
233+
let scalar = Scalar::extension::<Timestamp>(
234+
TimestampOptions {
235+
unit: time_unit,
236+
tz: None,
237+
},
238+
timestamp.into(),
237239
);
238-
let scalar = Scalar::extension(temporal.ext_dtype(), timestamp.into());
239240
ConstantArray::new(scalar, len).into_array()
240241
}
241242

@@ -249,12 +250,13 @@ mod tests {
249250
TimeUnit::Days => panic!("Days not supported"),
250251
};
251252
let timestamp = (day * SECONDS_PER_DAY + seconds) * multiplier;
252-
let temporal = TemporalArray::new_timestamp(
253-
PrimitiveArray::new(buffer![timestamp], Validity::NonNullable).into_array(),
254-
time_unit,
255-
None,
253+
let scalar = Scalar::extension::<Timestamp>(
254+
TimestampOptions {
255+
unit: time_unit,
256+
tz: None,
257+
},
258+
timestamp.into(),
256259
);
257-
let scalar = Scalar::extension(temporal.ext_dtype(), timestamp.into());
258260
ConstantArray::new(scalar, len).into_array()
259261
}
260262

encodings/datetime-parts/src/compute/take.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ mod tests {
117117
345_600_000, // 4 days in ms
118118
].into_array(),
119119
TimeUnit::Milliseconds,
120-
Some("UTC".to_string())
120+
Some("UTC".into())
121121
)).unwrap())]
122122
#[case(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
123123
PrimitiveArray::from_option_iter([
@@ -128,12 +128,12 @@ mod tests {
128128
None,
129129
]).into_array(),
130130
TimeUnit::Milliseconds,
131-
Some("UTC".to_string())
131+
Some("UTC".into())
132132
)).unwrap())]
133133
#[case(DateTimePartsArray::try_from(TemporalArray::new_timestamp(
134134
buffer![86_400_000i64].into_array(),
135135
TimeUnit::Milliseconds,
136-
Some("UTC".to_string())
136+
Some("UTC".into())
137137
)).unwrap())]
138138
fn test_take_datetime_parts_conformance(#[case] array: DateTimePartsArray) {
139139
test_take_conformance(array.as_ref());

0 commit comments

Comments
 (0)