Skip to content

Commit 22acfdc

Browse files
committed
Ensure map-like row data preserves column order
1 parent 30cc9a0 commit 22acfdc

File tree

2 files changed

+144
-8
lines changed

2 files changed

+144
-8
lines changed

src/serializer.rs

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,9 @@ impl<'a, 'w, W: io::Write> SerializeMap for &'a mut SeRecord<'w, W> {
292292

293293
fn serialize_key<T: ?Sized + Serialize>(
294294
&mut self,
295-
_key: &T,
295+
key: &T,
296296
) -> Result<(), Self::Error> {
297-
Ok(())
297+
self.wtr.check_map_key(key)
298298
}
299299

300300
fn serialize_value<T: ?Sized + Serialize>(
@@ -305,6 +305,7 @@ impl<'a, 'w, W: io::Write> SerializeMap for &'a mut SeRecord<'w, W> {
305305
}
306306

307307
fn end(self) -> Result<Self::Ok, Self::Error> {
308+
self.wtr.on_map_end();
308309
Ok(())
309310
}
310311
}
@@ -742,6 +743,7 @@ impl<'a, 'w, W: io::Write> SerializeMap for &'a mut SeHeader<'w, W> {
742743
return Err(err);
743744
}
744745

746+
self.wtr.check_map_key(key)?;
745747
let mut key_serializer = SeRecord { wtr: self.wtr };
746748
key.serialize(&mut key_serializer)?;
747749
self.state = HeaderState::InStructField;
@@ -763,6 +765,7 @@ impl<'a, 'w, W: io::Write> SerializeMap for &'a mut SeHeader<'w, W> {
763765
}
764766

765767
fn end(self) -> Result<Self::Ok, Self::Error> {
768+
self.wtr.on_map_end();
766769
Ok(())
767770
}
768771
}
@@ -856,6 +859,18 @@ mod tests {
856859
s.serialize(&mut SeHeader::new(&mut wtr)).unwrap_err()
857860
}
858861

862+
#[derive(Debug)]
863+
struct CustomOrderMap(Vec<(&'static str, f64)>);
864+
865+
impl Serialize for CustomOrderMap {
866+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
867+
where
868+
S: serde::Serializer,
869+
{
870+
serializer.collect_map(self.0.iter().copied())
871+
}
872+
}
873+
859874
#[test]
860875
fn bool() {
861876
let got = serialize(true);
@@ -1126,6 +1141,40 @@ mod tests {
11261141
}
11271142
}
11281143

1144+
#[test]
1145+
fn ordered_map() {
1146+
let mut map = BTreeMap::new();
1147+
map.insert("a", 2.0);
1148+
map.insert("b", 1.0);
1149+
1150+
let got = serialize(&map);
1151+
assert_eq!(got, "2.0,1.0\n");
1152+
let (wrote, got) = serialize_header(map);
1153+
assert!(wrote);
1154+
assert_eq!(got, "a,b");
1155+
}
1156+
1157+
#[test]
1158+
fn unordered_map() {
1159+
let mut writer = Writer::from_writer(vec![]);
1160+
writer
1161+
.serialize(CustomOrderMap(vec![("a", 2.0), ("b", 1.0)]))
1162+
.unwrap();
1163+
writer
1164+
.serialize(CustomOrderMap(vec![("a", 3.0), ("b", 4.0)]))
1165+
.unwrap();
1166+
writer.flush().unwrap();
1167+
let csv = String::from_utf8(writer.get_ref().clone()).unwrap();
1168+
assert_eq!(csv, "a,b\n2.0,1.0\n3.0,4.0\n");
1169+
let error = writer
1170+
.serialize(CustomOrderMap(vec![("b", 2.0), ("a", 1.0)])) // Wrong key order
1171+
.unwrap_err();
1172+
assert!(
1173+
matches!(error.kind(), ErrorKind::Serialize(_)),
1174+
"Got unexpected error: {error}"
1175+
)
1176+
}
1177+
11291178
#[test]
11301179
fn struct_no_headers() {
11311180
#[derive(Serialize)]
@@ -1391,6 +1440,36 @@ mod tests {
13911440
assert_eq!(got, format!("x,y,extra1,extra2"));
13921441
}
13931442

1443+
#[test]
1444+
fn flatten_map_with_different_key_order() {
1445+
#[derive(Serialize, Debug)]
1446+
struct Row {
1447+
x: f64,
1448+
y: f64,
1449+
#[serde(flatten)]
1450+
extra: CustomOrderMap,
1451+
}
1452+
let mut writer = Writer::from_writer(vec![]);
1453+
writer
1454+
.serialize(Row {
1455+
x: 1.0,
1456+
y: 2.0,
1457+
extra: CustomOrderMap(vec![("extra1", 3.0), ("extra2", 4.0)]),
1458+
})
1459+
.unwrap();
1460+
let error = writer
1461+
.serialize(Row {
1462+
x: 1.0,
1463+
y: 2.0,
1464+
extra: CustomOrderMap(vec![("extra2", 4.0), ("extra1", 3.0)]),
1465+
})
1466+
.unwrap_err();
1467+
assert!(
1468+
matches!(error.kind(), ErrorKind::Serialize(_)),
1469+
"Expected ErrorKind::Serialize but got '{error}'"
1470+
);
1471+
}
1472+
13941473
#[test]
13951474
fn flatten_map_different_num_entries() {
13961475
#[derive(Clone, Serialize, Debug, PartialEq)]
@@ -1400,20 +1479,20 @@ mod tests {
14001479
#[serde(flatten)]
14011480
extra: BTreeMap<&'static str, f64>,
14021481
}
1403-
let mut wtr = Writer::from_writer(vec![]);
1482+
let mut writer = Writer::from_writer(vec![]);
14041483

14051484
let mut extra = BTreeMap::new();
14061485
extra.insert("extra1", 3.0);
14071486
extra.insert("extra2", 4.0);
14081487
let row = Row { x: 1.0, y: 2.0, extra };
1409-
wtr.serialize(row).unwrap();
1488+
writer.serialize(row).unwrap();
14101489

14111490
let mut extra = BTreeMap::new();
1412-
extra.insert("extra3", 3.0);
1413-
extra.insert("extra4", 4.0);
1414-
extra.insert("extra5", 5.0);
1491+
extra.insert("extra1", 3.0);
1492+
extra.insert("extra2", 4.0);
1493+
extra.insert("extra3", 5.0);
14151494
let row = Row { x: 1.0, y: 2.0, extra };
1416-
let error = wtr.serialize(row).unwrap_err();
1495+
let error = writer.serialize(row).unwrap_err();
14171496
match *error.kind() {
14181497
ErrorKind::UnequalLengths {
14191498
pos: None,

src/writer.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,15 @@ pub struct Writer<W: io::Write> {
540540
state: WriterState,
541541
}
542542

543+
/// State for tracking headers while writing.
544+
#[derive(Debug)]
545+
struct HeaderTrackingState {
546+
/// The serialized headers in their expected order.
547+
expected_headers: Vec<Vec<u8>>,
548+
/// The index into the `expected_headers` list of the next expected header.
549+
next_expected_index: usize,
550+
}
551+
543552
#[derive(Debug)]
544553
struct WriterState {
545554
/// Whether the Serde serializer should attempt to write a header row.
@@ -557,6 +566,9 @@ struct WriterState {
557566
/// immediately after flushing the buffer. This avoids flushing the buffer
558567
/// twice if the inner writer panics.
559568
panicked: bool,
569+
/// Header tracking state for map like data, to ensure that column order
570+
/// is preserved across all rows.
571+
header_tracking: Option<HeaderTrackingState>,
560572
}
561573

562574
/// HeaderState encodes a small state machine for handling header writes.
@@ -638,6 +650,10 @@ impl<W: io::Write> Writer<W> {
638650
first_field_count: None,
639651
fields_written: 0,
640652
panicked: false,
653+
header_tracking: Some(HeaderTrackingState {
654+
expected_headers: Vec::new(),
655+
next_expected_index: 0,
656+
}),
641657
},
642658
}
643659
}
@@ -1180,6 +1196,47 @@ impl<W: io::Write> Writer<W> {
11801196
}
11811197
Ok(())
11821198
}
1199+
1200+
/// Track the `key` of a map entry. If this is not the first row, also verify that the
1201+
/// `key` matches the expected next key.
1202+
pub(crate) fn check_map_key<T: ?Sized + Serialize>(
1203+
&mut self,
1204+
key: &T,
1205+
) -> Result<()> {
1206+
let Some(tracking) = &mut self.state.header_tracking else {
1207+
return Ok(());
1208+
};
1209+
let mut encoded_key_serializer = Writer::from_writer(Vec::new());
1210+
serialize(&mut encoded_key_serializer, key)?;
1211+
let encoded_key =
1212+
encoded_key_serializer.into_inner().map_err(|error| {
1213+
Error::new(ErrorKind::Serialize(format!(
1214+
"Failed to serialize key to bytes: {error:?}"
1215+
)))
1216+
})?;
1217+
if let Some(expected_key) =
1218+
tracking.expected_headers.get(tracking.next_expected_index)
1219+
{
1220+
if expected_key != &encoded_key {
1221+
return Err(Error::new(ErrorKind::Serialize(format!(
1222+
"Out of order key `{}`",
1223+
String::from_utf8_lossy(&encoded_key)
1224+
))));
1225+
}
1226+
} else {
1227+
// Even if this is not the first row, accept more keys. If the writer is flexible then adding more fields is allowed.
1228+
tracking.expected_headers.push(encoded_key);
1229+
}
1230+
tracking.next_expected_index += 1;
1231+
Ok(())
1232+
}
1233+
1234+
/// Reset the map key tracking at the end of a row.
1235+
pub(crate) fn on_map_end(&mut self) {
1236+
if let Some(tracking) = &mut self.state.header_tracking {
1237+
tracking.next_expected_index = 0;
1238+
}
1239+
}
11831240
}
11841241

11851242
impl Buffer {

0 commit comments

Comments
 (0)