@@ -698,3 +698,53 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
698698
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
    """Decoding is not provided by this coder.

    Raises
    ------
    NotImplementedError
        Always, regardless of ``variable`` or ``name``.
    """
    raise NotImplementedError
701+
702+
class IntervalCoder(VariableCoder):
    """
    Xarray-specific Interval Coder to roundtrip 1D pd.IntervalArray objects.

    Encoding stacks the left and right interval edges along a new leading
    dimension named ``encoded_bounds_dim`` and records the ``closed``,
    ``dtype`` and ``bounds_dim`` attributes on the variable. Decoding
    rebuilds the ``pd.arrays.IntervalArray`` from a variable that carries
    the ``dtype`` marker attribute and the bounds dimension.
    """

    # Marker stored in ``attrs["dtype"]`` by ``encode`` and checked by ``decode``.
    encoded_dtype = "pandas_interval"
    # Name of the dimension that holds the (left, right) interval edges.
    encoded_bounds_dim = "__xarray_bounds__"

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Encode an interval-typed variable as stacked left/right edges.

        Parameters
        ----------
        variable : Variable
            Variable to encode. Returned unchanged unless its dtype is a
            ``pd.IntervalDtype``.
        name : optional
            Variable name, forwarded to ``safe_setitem``.

        Returns
        -------
        Variable
            A variable with ``encoded_bounds_dim`` prepended to the original
            dimensions, whose data is ``[left, right]`` stacked on axis 0.
        """
        dtype = variable.dtype
        if not isinstance(dtype, pd.IntervalDtype):
            return variable

        dims, data, attrs, encoding = unpack_for_encoding(variable)

        # Row 0 holds the left edges, row 1 the right edges.
        new_data = np.stack([data.left, data.right], axis=0)
        dims = (self.encoded_bounds_dim, *dims)

        # Record everything ``decode`` needs to rebuild the interval array.
        safe_setitem(attrs, "closed", dtype.closed, name=name)
        safe_setitem(attrs, "dtype", self.encoded_dtype, name=name)
        safe_setitem(attrs, "bounds_dim", self.encoded_bounds_dim, name=name)
        return Variable(dims, new_data, attrs, encoding, fastpath=True)

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Rebuild a ``pd.arrays.IntervalArray`` variable from encoded edges.

        Parameters
        ----------
        variable : Variable
            Variable to decode. Returned unchanged unless its ``dtype``
            attribute equals ``encoded_dtype`` and ``encoded_bounds_dim`` is
            one of its dimensions.
        name : optional
            Variable name, used in error messages and forwarded to ``pop_to``.

        Returns
        -------
        Variable
            A one-dimensional variable over the non-bounds dimension whose
            data is the reconstructed interval array.

        Raises
        ------
        ValueError
            If the encoded variable does not have exactly two dimensions.
        """
        bounds_dim = self.encoded_bounds_dim
        if (
            variable.attrs.get("dtype", None) != self.encoded_dtype
            or bounds_dim not in variable.dims
        ):
            return variable

        if variable.ndim > 2:
            raise ValueError(
                f"Cannot decode intervals for variable named {name!r} with more than two dimensions."
            )
        if variable.ndim < 2:
            # Only the bounds dimension is present, so there are no intervals.
            raise ValueError(
                f"Cannot decode intervals for variable named {name!r} with fewer than two dimensions."
            )

        _, _, attrs, encoding = unpack_for_decoding(variable)
        pop_to(attrs, encoding, "dtype", name=name)
        pop_to(attrs, encoding, "bounds_dim", name=name)
        closed = pop_to(attrs, encoding, "closed", name=name)

        # Choose the remaining dimension by name, not position: the guard
        # above only checks membership, so the bounds dimension may be
        # either the first or the second dimension.
        first_dim, second_dim = variable.dims
        new_dims = second_dim if first_dim == bounds_dim else first_dim

        # NOTE(review): presumably materializes lazily-loaded data before the
        # edges are handed to pandas -- confirm against Variable.load.
        variable = variable.load()
        left = variable.isel({bounds_dim: 0}).data
        right = variable.isel({bounds_dim: 1}).data
        new_data = pd.arrays.IntervalArray.from_arrays(left, right, closed=closed)
        return Variable(dims=new_dims, data=new_data, attrs=attrs, encoding=encoding)
0 commit comments