Skip to content

Commit bb65358

Browse files
authored
Read/write nested dictionary in ipc stream reader/writer (#1566)
* Read dictionary inside dictionary * Fix clippy
1 parent 2bcc0cf commit bb65358

File tree

3 files changed

+97
-7
lines changed

3 files changed

+97
-7
lines changed

arrow/src/datatypes/field.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,25 @@ impl Field {
116116
/// Returns a (flattened) vector containing all fields contained within this field (including it self)
117117
pub(crate) fn fields(&self) -> Vec<&Field> {
118118
let mut collected_fields = vec![self];
119-
match &self.data_type {
119+
collected_fields.append(&mut self._fields(&self.data_type));
120+
121+
collected_fields
122+
}
123+
124+
fn _fields<'a>(&'a self, dt: &'a DataType) -> Vec<&Field> {
125+
let mut collected_fields = vec![];
126+
127+
match dt {
120128
DataType::Struct(fields) | DataType::Union(fields, _) => {
121129
collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
122130
}
123131
DataType::List(field)
124132
| DataType::LargeList(field)
125133
| DataType::FixedSizeList(field, _)
126134
| DataType::Map(field, _) => collected_fields.push(field),
135+
DataType::Dictionary(_, value_field) => {
136+
collected_fields.append(&mut self._fields(value_field.as_ref()))
137+
}
127138
_ => (),
128139
}
129140

arrow/src/ipc/reader.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,7 @@ mod tests {
10191019

10201020
use flate2::read::GzDecoder;
10211021

1022+
use crate::datatypes::Int8Type;
10221023
use crate::{datatypes, util::integration_util::*};
10231024

10241025
#[test]
@@ -1441,4 +1442,42 @@ mod tests {
14411442
let output_batch = roundtrip_ipc_stream(&input_batch);
14421443
assert_eq!(input_batch, output_batch);
14431444
}
1445+
1446+
#[test]
1447+
fn test_roundtrip_stream_nested_dict_dict() {
1448+
let values = StringArray::from_iter_values(["a", "b", "c"]);
1449+
let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1]);
1450+
let dict_array = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap();
1451+
let dict_data = dict_array.data();
1452+
1453+
let value_offsets = Buffer::from_slice_ref(&[0, 2, 4, 6]);
1454+
1455+
let list_data_type = DataType::List(Box::new(Field::new_dict(
1456+
"item",
1457+
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
1458+
false,
1459+
1,
1460+
false,
1461+
)));
1462+
let list_data = ArrayData::builder(list_data_type)
1463+
.len(3)
1464+
.add_buffer(value_offsets)
1465+
.add_child_data(dict_data.clone())
1466+
.build()
1467+
.unwrap();
1468+
let list_array = ListArray::from(list_data);
1469+
1470+
let dict_dict_array =
1471+
DictionaryArray::<Int8Type>::try_new(&keys, &list_array).unwrap();
1472+
1473+
let schema = Arc::new(Schema::new(vec![Field::new(
1474+
"f1",
1475+
dict_dict_array.data_type().clone(),
1476+
false,
1477+
)]));
1478+
let input_batch =
1479+
RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
1480+
let output_batch = roundtrip_ipc_stream(&input_batch);
1481+
assert_eq!(input_batch, output_batch);
1482+
}
14441483
}

arrow/src/ipc/writer.rs

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ use std::io::{BufWriter, Write};
2525

2626
use flatbuffers::FlatBufferBuilder;
2727

28-
use crate::array::{as_struct_array, as_union_array, ArrayData, ArrayRef};
28+
use crate::array::{
29+
as_list_array, as_struct_array, as_union_array, make_array, ArrayData, ArrayRef,
30+
};
2931
use crate::buffer::{Buffer, MutableBuffer};
3032
use crate::datatypes::*;
3133
use crate::error::{ArrowError, Result};
@@ -137,15 +139,14 @@ impl IpcDataGenerator {
137139
}
138140
}
139141

140-
fn encode_dictionaries(
142+
fn _encode_dictionaries(
141143
&self,
142-
field: &Field,
143144
column: &ArrayRef,
144145
encoded_dictionaries: &mut Vec<EncodedData>,
145146
dictionary_tracker: &mut DictionaryTracker,
146147
write_options: &IpcWriteOptions,
147148
) -> Result<()> {
148-
// TODO: Handle other nested types (map, list, etc)
149+
// TODO: Handle other nested types (map, etc)
149150
match column.data_type() {
150151
DataType::Struct(fields) => {
151152
let s = as_struct_array(column);
@@ -159,6 +160,16 @@ impl IpcDataGenerator {
159160
)?;
160161
}
161162
}
163+
DataType::List(field) => {
164+
let list = as_list_array(column);
165+
self.encode_dictionaries(
166+
field,
167+
&list.values(),
168+
encoded_dictionaries,
169+
dictionary_tracker,
170+
write_options,
171+
)?;
172+
}
162173
DataType::Union(fields, _) => {
163174
let union = as_union_array(column);
164175
for (field, ref column) in fields
@@ -175,13 +186,37 @@ impl IpcDataGenerator {
175186
)?;
176187
}
177188
}
189+
_ => (),
190+
}
191+
192+
Ok(())
193+
}
194+
195+
fn encode_dictionaries(
196+
&self,
197+
field: &Field,
198+
column: &ArrayRef,
199+
encoded_dictionaries: &mut Vec<EncodedData>,
200+
dictionary_tracker: &mut DictionaryTracker,
201+
write_options: &IpcWriteOptions,
202+
) -> Result<()> {
203+
match column.data_type() {
178204
DataType::Dictionary(_key_type, _value_type) => {
179205
let dict_id = field
180206
.dict_id()
181207
.expect("All Dictionary types have `dict_id`");
182208
let dict_data = column.data();
183209
let dict_values = &dict_data.child_data()[0];
184210

211+
let values = make_array(dict_data.child_data()[0].clone());
212+
213+
self._encode_dictionaries(
214+
&values,
215+
encoded_dictionaries,
216+
dictionary_tracker,
217+
write_options,
218+
)?;
219+
185220
let emit = dictionary_tracker.insert(dict_id, column)?;
186221

187222
if emit {
@@ -192,7 +227,12 @@ impl IpcDataGenerator {
192227
));
193228
}
194229
}
195-
_ => (),
230+
_ => self._encode_dictionaries(
231+
column,
232+
encoded_dictionaries,
233+
dictionary_tracker,
234+
write_options,
235+
)?,
196236
}
197237

198238
Ok(())
@@ -205,7 +245,7 @@ impl IpcDataGenerator {
205245
write_options: &IpcWriteOptions,
206246
) -> Result<(Vec<EncodedData>, EncodedData)> {
207247
let schema = batch.schema();
208-
let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len());
248+
let mut encoded_dictionaries = Vec::with_capacity(schema.all_fields().len());
209249

210250
for (i, field) in schema.fields().iter().enumerate() {
211251
let column = batch.column(i);

0 commit comments

Comments
 (0)