Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 49 additions & 19 deletions dlt/common/data_writers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Type,
NamedTuple,
TypeVar,
cast,
)

from dlt.common.json import json
Expand Down Expand Up @@ -569,24 +570,21 @@ def write_data(self, items: Sequence[TDataItem]) -> None:
for item in items:
if isinstance(item, (pyarrow.Table, pyarrow.RecordBatch)):
if not self.writer:
if self.quoting == "quote_needed":
quoting = "needed"
elif self.quoting == "quote_all":
quoting = "all_valid"
elif self.quoting == "quote_none":
quoting = "none"
else:
raise ValueError(self.quoting)
try:
self.writer = pyarrow.csv.CSVWriter(
self._f,
item.schema,
write_options=pyarrow.csv.WriteOptions(
include_header=self.include_header,
# set include_header to False to handle header separately until
# https://github.com/apache/arrow/issues/47575 is released
# see _make_csv_header() for details
include_header=False,
delimiter=self._delimiter_b,
quoting_style=quoting,
quoting_style=self._get_arrow_quoting_style(),
),
)
if self.include_header:
self._f.write(self._make_csv_header())
self._first_schema = item.schema
except pyarrow.ArrowInvalid as inv_ex:
if "Unsupported Type" in str(inv_ex):
Expand Down Expand Up @@ -636,16 +634,10 @@ def write_data(self, items: Sequence[TDataItem]) -> None:
self.items_count += item.num_rows

def write_footer(self) -> None:
    """Write the header line for a file that never received any data.

    Nothing is written when a CSV writer already exists (``self.writer`` is
    set) or when headers are disabled (``self.include_header`` is falsy).

    The header comes from ``_make_csv_header()`` so that it follows the
    configured quoting style and delimiter instead of always wrapping the
    column names in double quotes. It is written exactly once.
    """
    if self.writer is not None or not self.include_header:
        return

    default_arrow_line_terminator = b"\n"
    # empty file: emit only the header line (no data rows), so the trailing
    # line terminator produced by the header helper is stripped
    self._f.write(self._make_csv_header().rstrip(default_arrow_line_terminator))

def close(self) -> None:
if self.writer:
Expand All @@ -665,6 +657,44 @@ def writer_spec(cls) -> FileWriterSpec:
supports_compression=True,
)

def _get_arrow_quoting_style(self) -> str:
    """Translate ``self.quoting`` into the matching pyarrow quoting style name.

    Returns:
        ``"needed"``, ``"all_valid"`` or ``"none"``.

    Raises:
        ValueError: if ``self.quoting`` is not one of the known quoting modes;
            the offending value is passed as the exception argument.
    """
    known_styles = (
        ("quote_needed", "needed"),
        ("quote_all", "all_valid"),
        ("quote_none", "none"),
    )
    for configured, arrow_style in known_styles:
        if self.quoting == configured:
            return arrow_style
    raise ValueError(self.quoting)

def _make_csv_header(self) -> bytes:
    """Render the CSV header line for ``self._columns_schema`` as bytes.

    In pyarrow 21.0.0, the CSVWriter does not support specifying the header
    quote style. As a workaround the column names are written as a single
    *data* row of string values through a throwaway CSVWriter, so the header
    respects the configured quoting style and delimiter.
    See https://github.com/apache/arrow/issues/47575 for details.
    This needs to be removed once https://github.com/apache/arrow/issues/47575
    is released.

    Returns:
        The bytes the throwaway writer produced for the row of column names.

    Raises:
        ValueError: propagated from ``_get_arrow_quoting_style()`` for an
            unknown quoting mode.
    """
    from dlt.common.libs.pyarrow import pyarrow
    import pyarrow.csv

    names = [col["name"] for col in self._columns_schema.values()]
    schema = pyarrow.schema([pyarrow.field(n, pyarrow.string()) for n in names])
    table = pyarrow.Table.from_arrays([pyarrow.array([n]) for n in names], schema=schema)

    # Write into an in-memory Arrow sink so schema doesn't affect the real writer
    sink = pyarrow.BufferOutputStream()
    write_options = pyarrow.csv.WriteOptions(
        include_header=False,
        delimiter=self._delimiter_b,
        quoting_style=self._get_arrow_quoting_style(),
    )
    # the context manager closes the temporary writer even when write() raises
    with pyarrow.csv.CSVWriter(sink, schema, write_options=write_options) as header_writer:
        header_writer.write(table)
    return cast(bytes, sink.getvalue().to_pybytes())


class ArrowToObjectAdapter:
"""A mixin that will convert object writer into arrow writer."""
Expand Down
155 changes: 153 additions & 2 deletions tests/libs/test_csv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,12 @@ def test_arrow_csv_writer_quoting_parameters(quoting: CsvQuoting) -> None:
writer.write_header(mock_schema)
writer.write_data([test_data])

mock_csv_writer.assert_called_once()
mock_csv_writer.assert_called()
call_args = mock_csv_writer.call_args
write_options = call_args.kwargs["write_options"]
assert write_options.quoting_style == expected_quoting_mapping[quoting]

mock_writer_instance.write.assert_called_once_with(test_data)
mock_writer_instance.write.assert_called_with(test_data)


def test_arrow_csv_writer_quote_none_with_special_characters() -> None:
Expand Down Expand Up @@ -339,3 +339,154 @@ def test_csv_lineterminator(test_case: Dict[str, str]) -> None:
with open(writer.closed_files[0].file_path, "rb") as f:
content = f.read()
assert content == expected


@pytest.mark.parametrize(
    "quoting,delimiter,schema,test_data_dict,expected_header,expected_data_rows",
    [
        pytest.param(
            "quote_none",
            ",",
            {
                "col1": {"name": "col1", "data_type": "text"},
                "col2": {"name": "col2", "data_type": "bigint"},
            },
            {
                "col1": ["test_value", "another_value"],
                "col2": [123, 456],
            },
            "col1,col2",
            ["test_value,123", "another_value,456"],
            id="quote_none_with_data",
        ),
        pytest.param(
            "quote_all",
            ",",
            {
                "col1": {"name": "col1", "data_type": "text"},
                "col2": {"name": "col2", "data_type": "bigint"},
            },
            {"col1": ["value1"], "col2": [123]},
            '"col1","col2"',
            ['"value1","123"'],
            id="quote_all_with_data",
        ),
        pytest.param(
            "quote_needed",
            ",",
            {
                "col1": {"name": "col1", "data_type": "text"},
                "2": {"name": "2", "data_type": "bigint"},
            },
            {"col1": ["value1"], "col2": [123]},
            '"col1","2"',
            ['"value1",123'],
            id="quote_needed_with_data",
        ),
        pytest.param(
            "quote_none",
            ",",
            {
                "col1": {"name": "col1", "data_type": "text"},
                "col2": {"name": "col2", "data_type": "bigint"},
            },
            None,
            "col1,col2",
            [],
            id="quote_none_empty_file",
        ),
        pytest.param(
            "quote_none",
            "|",
            {
                "col1": {"name": "col1", "data_type": "text"},
                "col2": {"name": "col2", "data_type": "bigint"},
            },
            {"col1": ["value1"], "col2": [123]},
            "col1|col2",
            ["value1|123"],
            id="quote_none_custom_delimiter",
        ),
    ],
)
def test_arrow_csv_writer(
    quoting: CsvQuoting,
    delimiter: str,
    schema: TTableSchemaColumns,
    test_data_dict,
    expected_header: str,
    expected_data_rows,
) -> None:
    """Check the header and data rows ArrowToCsvWriter emits for each quoting style.

    When ``test_data_dict`` is ``None`` no data is written and only
    ``write_footer()`` is called, which covers the empty-file header path.
    """
    import tempfile
    import pyarrow as pa

    with tempfile.NamedTemporaryFile(mode="w+b", suffix=".csv") as csv_file:
        writer = ArrowToCsvWriter(
            csv_file, quoting=quoting, include_header=True, delimiter=delimiter
        )
        writer.write_header(schema)

        if test_data_dict is None:
            writer.write_footer()
        else:
            writer.write_data([pa.table(test_data_dict)])

        csv_file.flush()
        csv_file.seek(0)
        written_lines = csv_file.read().decode("utf-8").splitlines()

    assert written_lines[0].strip() == expected_header

    # data rows follow the header, hence line numbers start at 1
    for line_no, expected_row in enumerate(expected_data_rows or [], start=1):
        assert written_lines[line_no].strip() == expected_row


def test_arrow_csv_writer_special_chars_in_column_names_quote_none() -> None:
    """Column names containing the delimiter must be rejected under ``quote_none``."""
    import tempfile
    import pyarrow as pa
    from pyarrow.lib import ArrowInvalid

    # Column names with commas should fail with quote_none
    column_name = "col,with,comma"
    mock_schema: TTableSchemaColumns = {
        "col,with,comma": {"name": "col,with,comma", "data_type": "text"},
    }
    test_data = pa.table({column_name: ["value1"]})

    with tempfile.NamedTemporaryFile(mode="wb", suffix=".csv") as csv_file:
        writer = ArrowToCsvWriter(csv_file, quoting="quote_none", include_header=True)
        writer.write_header(mock_schema)
        with pytest.raises(
            ArrowInvalid, match="CSV values may not contain structural characters"
        ):
            writer.write_data([test_data])


def test_arrow_csv_writer_empty_schema() -> None:
    """An empty column schema must produce an empty file, even with headers enabled."""
    import tempfile

    empty_schema: TTableSchemaColumns = {}

    with tempfile.NamedTemporaryFile(mode="w+b", suffix=".csv") as csv_file:
        writer = ArrowToCsvWriter(csv_file, quoting="quote_none", include_header=True)
        writer.write_header(empty_schema)
        # this triggers _make_csv_header() with an empty schema
        writer.write_footer()
        csv_file.flush()
        csv_file.seek(0)
        written = csv_file.read()
        assert written == b""


def test_arrow_csv_writer_invalid_quoting_parameter() -> None:
    """An unknown quoting value must surface as ``ValueError`` when the header is built."""
    import tempfile

    single_column_schema: TTableSchemaColumns = {
        "col1": {"name": "col1", "data_type": "text"},
    }

    with tempfile.NamedTemporaryFile(mode="wb", suffix=".csv") as csv_file:
        writer = ArrowToCsvWriter(csv_file, quoting="quote_none", include_header=True)
        writer.write_header(single_column_schema)
        # Manually set invalid quoting to trigger the error in _make_csv_header
        writer.quoting = "invalid_quoting_style"  # type: ignore[assignment]
        with pytest.raises(ValueError, match="invalid_quoting_style"):
            writer.write_footer()
Loading