Skip to content

Commit

Permalink
add column types to from_csv to override auto inference (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Oct 14, 2024
1 parent cdab709 commit e0c654e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
18 changes: 17 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from datachain.utils import batched_it, inside_notebook

if TYPE_CHECKING:
from pyarrow import DataType as ArrowDataType
from typing_extensions import Concatenate, ParamSpec, Self

from datachain.lib.hf import HFDatasetType
Expand Down Expand Up @@ -1709,6 +1710,7 @@ def from_csv(
nrows=None,
session: Optional[Session] = None,
settings: Optional[dict] = None,
column_types: Optional[dict[str, "Union[str, ArrowDataType]"]] = None,
**kwargs,
) -> "DataChain":
"""Generate chain from csv files.
Expand All @@ -1727,6 +1729,9 @@ def from_csv(
nrows : Optional row limit.
session : Session to use for the chain.
settings : Settings to use for the chain.
column_types : Dictionary of column names and their corresponding types.
It is passed to CSV reader and for each column specified type auto
inference is disabled.
Example:
Reading a csv file:
Expand All @@ -1742,6 +1747,15 @@ def from_csv(
from pandas.io.parsers.readers import STR_NA_VALUES
from pyarrow.csv import ConvertOptions, ParseOptions, ReadOptions
from pyarrow.dataset import CsvFileFormat
from pyarrow.lib import type_for_alias

if column_types:
column_types = {
name: type_for_alias(typ) if isinstance(typ, str) else typ
for name, typ in column_types.items()
}
else:
column_types = {}

chain = DataChain.from_storage(
path, session=session, settings=settings, **kwargs
Expand All @@ -1767,7 +1781,9 @@ def from_csv(
parse_options = ParseOptions(delimiter=delimiter)
read_options = ReadOptions(column_names=column_names)
convert_options = ConvertOptions(
strings_can_be_null=True, null_values=STR_NA_VALUES
strings_can_be_null=True,
null_values=STR_NA_VALUES,
column_types=column_types,
)
format = CsvFileFormat(
parse_options=parse_options,
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,17 @@ def test_from_csv_nrows(tmp_dir, test_session):
assert df1.equals(df[:2])


def test_from_csv_column_types(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
path = tmp_dir / "test.csv"
df.to_csv(path, index=False)
dc = DataChain.from_csv(
path.as_uri(), column_types={"age": "str"}, session=test_session
)
df1 = dc.select("first_name", "age", "city").to_pandas()
assert df1["age"].dtype == pd.StringDtype


def test_to_csv_features(tmp_dir, test_session):
dc_to = DataChain.from_values(
f1=features, num=range(len(features)), session=test_session
Expand Down

0 comments on commit e0c654e

Please sign in to comment.