diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 0db2833d2c1aa..cd2311e614eec 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -33,6 +33,7 @@
     TimestampNTZType,
     DayTimeIntervalType,
     YearMonthIntervalType,
+    CalendarIntervalType,
     MapType,
     StringType,
     CharType,
@@ -169,6 +170,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
     elif isinstance(data_type, YearMonthIntervalType):
         ret.year_month_interval.start_field = data_type.startField
         ret.year_month_interval.end_field = data_type.endField
+    elif isinstance(data_type, CalendarIntervalType):
+        ret.calendar_interval.CopyFrom(pb2.DataType.CalendarInterval())
     elif isinstance(data_type, StructType):
         struct = pb2.DataType.Struct()
         for field in data_type.fields:
@@ -265,6 +268,8 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
             else None
         )
         return YearMonthIntervalType(startField=start, endField=end)
+    elif schema.HasField("calendar_interval"):
+        return CalendarIntervalType()
     elif schema.HasField("array"):
         return ArrayType(
             proto_schema_to_pyspark_data_type(schema.array.element_type),
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 533506c7d2743..44171fd61a35b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -86,7 +86,7 @@ def test_rdd_with_udt(self):
     def test_udt(self):
         super().test_udt()
 
-    @unittest.skip("SPARK-45018: should support CalendarIntervalType")
+    @unittest.skip("SPARK-45026: spark.sql should support datatypes not compatible with arrow")
     def test_calendar_interval_type(self):
         super().test_calendar_interval_type()
 
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index d45c4d7e808de..fb752b93a3308 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1284,6 +1284,10 @@ def test_calendar_interval_type(self):
         schema1 = self.spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)").schema
         self.assertEqual(schema1.fields[0].dataType, CalendarIntervalType())
 
+    def test_calendar_interval_type_with_sf(self):
+        schema1 = self.spark.range(1).select(F.make_interval(F.lit(1))).schema
+        self.assertEqual(schema1.fields[0].dataType, CalendarIntervalType())
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055