@@ -19,8 +19,11 @@
 
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Any, Protocol
 
+import pyarrow as pa
+
 try:
     from warnings import deprecated  # Python 3.13+
 except ImportError:
@@ -42,7 +45,6 @@
 
     import pandas as pd
     import polars as pl
-    import pyarrow as pa
 
     from datafusion.plan import ExecutionPlan, LogicalPlan
 
@@ -539,7 +541,7 @@ def register_listing_table(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
         file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -560,6 +562,7 @@ def register_listing_table(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         file_sort_order_raw = (
             [sort_list_to_raw_sort_list(f) for f in file_sort_order]
             if file_sort_order is not None
@@ -778,7 +781,7 @@ def register_parquet(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -806,6 +809,7 @@ def register_parquet(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_parquet(
             name,
             str(path),
@@ -869,7 +873,7 @@ def register_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> None:
         """Register a JSON file as a table.
@@ -890,6 +894,7 @@ def register_json(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_json(
             name,
             str(path),
@@ -906,7 +911,7 @@ def register_avro(
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
         file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     ) -> None:
         """Register an Avro file as a table.
 
@@ -922,6 +927,7 @@ def register_avro(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_avro(
             name, str(path), schema, file_extension, table_partition_cols
         )
@@ -981,7 +987,7 @@ def read_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -1001,6 +1007,7 @@ def read_json(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         return DataFrame(
             self.ctx.read_json(
                 str(path),
@@ -1020,7 +1027,7 @@ def read_csv(
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
         file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -1045,6 +1052,7 @@ def read_csv(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
 
         path = [str(p) for p in path] if isinstance(path, list) else str(path)
 
@@ -1064,7 +1072,7 @@ def read_parquet(
     def read_parquet(
         self,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -1093,6 +1101,7 @@ def read_parquet(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         file_sort_order = (
             [sort_list_to_raw_sort_list(f) for f in file_sort_order]
             if file_sort_order is not None
@@ -1114,7 +1123,7 @@ def read_avro(
         self,
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
-        file_partition_cols: list[tuple[str, str]] | None = None,
+        file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".avro",
     ) -> DataFrame:
         """Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1130,6 +1139,7 @@ def read_avro(
         """
         if file_partition_cols is None:
             file_partition_cols = []
+        file_partition_cols = self._convert_table_partition_cols(file_partition_cols)
         return DataFrame(
             self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
         )
@@ -1146,3 +1156,41 @@ def read_table(self, table: Table) -> DataFrame:
     def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
         return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+    @staticmethod
+    def _convert_table_partition_cols(
+        table_partition_cols: list[tuple[str, str | pa.DataType]],
+    ) -> list[tuple[str, pa.DataType]]:
+        warn = False
+        converted_table_partition_cols = []
+
+        for col, data_type in table_partition_cols:
+            if isinstance(data_type, str):
+                warn = True
+                if data_type == "string":
+                    converted_data_type = pa.string()
+                elif data_type == "int":
+                    converted_data_type = pa.int32()
+                else:
+                    message = (
+                        f"Unsupported literal data type '{data_type}' for partition "
+                        "column. Supported types are 'string' and 'int'"
+                    )
+                    raise ValueError(message)
+            else:
+                converted_data_type = data_type
+
+            converted_table_partition_cols.append((col, converted_data_type))
+
+        if warn:
+            message = (
+                "using literals for table_partition_cols data types is deprecated, "
+                "use pyarrow types instead"
+            )
+            warnings.warn(
+                message,
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+
+        return converted_table_partition_cols
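
# Usage sketch: a minimal example of the API change above, assuming a
# SessionContext named "ctx" and a hypothetical Parquet dataset under
# "data/events/" partitioned by "year" and "month". Callers can now pass
# pyarrow DataTypes for partition columns; the legacy string literals
# "int" and "string" still work but are routed through
# _convert_table_partition_cols and emit a DeprecationWarning.
import warnings

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# New style: pyarrow DataTypes for each partition column.
ctx.register_parquet(
    "events",
    "data/events/",
    table_partition_cols=[("year", pa.int32()), ("month", pa.string())],
)

# Legacy style: "string" is converted to pa.string() and "int" to pa.int32(),
# with a DeprecationWarning emitted for the call.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ctx.register_parquet(
        "events_legacy",
        "data/events/",
        table_partition_cols=[("year", "int"), ("month", "string")],
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)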