3737from xarray .core .types import Dims , QuantileMethods , T_DataArray , T_Xarray
3838from xarray .core .utils import (
3939 either_dict_or_kwargs ,
40+ emit_user_level_warning ,
4041 hashable ,
4142 is_scalar ,
4243 maybe_wrap_array ,
@@ -73,6 +74,21 @@ def check_reduce_dims(reduce_dims, dimensions):
7374 )
7475
7576
def _maybe_squeeze_indices(
    indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool
):
    """Collapse a one-element slice into a scalar index when squeezing applies.

    Squeezing applies when ``squeeze`` is ``None`` or ``True`` and
    ``grouper.can_squeeze`` is true. In every other case ``indices`` is
    returned unchanged.

    Parameters
    ----------
    indices
        Indexer for a single group. Only ``slice`` objects are converted;
        any other value is passed through as-is.
    squeeze : bool or None
        The ``squeeze`` setting. ``None`` squeezes like ``True`` and, when
        ``warn`` is true, also emits the removal warning.
    grouper : ResolvedGrouper
        Only its ``can_squeeze`` attribute is read.
    warn : bool
        Whether the removal warning may be emitted by this call.

    Returns
    -------
    ``indices.start`` when a slice was squeezed, otherwise ``indices``.
    """
    if squeeze in (None, True) and grouper.can_squeeze:
        if squeeze is None and warn:
            # The trailing space matters: adjacent string literals are joined
            # verbatim, and without it the two sentences run together.
            emit_user_level_warning(
                "The `squeeze` kwarg to GroupBy is being removed. "
                "Pass .groupby(..., squeeze=False) to silence this warning."
            )
        if isinstance(indices, slice):
            # Internal invariant: a squeezable group's slice spans one element.
            assert indices.stop - indices.start == 1
            indices = indices.start
    return indices
90+
91+
7692def unique_value_groups (
7793 ar , sort : bool = True
7894) -> tuple [np .ndarray | pd .Index , T_GroupIndices , np .ndarray ]:
@@ -366,10 +382,10 @@ def dims(self):
366382 return self .group1d .dims
367383
@abstractmethod
def factorize(self) -> T_FactorizeOut:
    """Compute the factorization for this grouper; subclasses must override.

    The result is unpacked by ``_factorize`` into ``codes``,
    ``group_indices``, ``unique_coord`` and ``full_index``.
    """
    raise NotImplementedError()
371387
372- def factorize (self , squeeze : bool ) -> None :
388+ def _factorize (self ) -> None :
373389 # This design makes it clear to mypy that
374390 # codes, group_indices, unique_coord, and full_index
375391 # are set by the factorize method on the derived class.
@@ -378,7 +394,7 @@ def factorize(self, squeeze: bool) -> None:
378394 self .group_indices ,
379395 self .unique_coord ,
380396 self .full_index ,
381- ) = self ._factorize ( squeeze )
397+ ) = self .factorize ( )
382398
383399 @property
384400 def is_unique_and_monotonic (self ) -> bool :
@@ -393,15 +409,19 @@ def group_as_index(self) -> pd.Index:
393409 self ._group_as_index = self .group1d .to_index ()
394410 return self ._group_as_index
395411
@property
def can_squeeze(self):
    """Whether the group is its own dimension and is unique and monotonic.

    The group counts as a dimension when its ``dims`` is exactly the
    one-element tuple holding its own ``name``.
    """
    names_own_dim = self.group.dims == (self.group.name,)
    if not names_own_dim:
        # Short-circuit: uniqueness/monotonicity is not consulted here.
        return names_own_dim
    return self.is_unique_and_monotonic
416+
396417
397418@dataclass
398419class ResolvedUniqueGrouper (ResolvedGrouper ):
399420 grouper : UniqueGrouper
400421
def factorize(self) -> T_FactorizeOut:
    """Factorize the group, choosing the strategy from ``can_squeeze``.

    Delegates to ``_factorize_dummy`` when ``can_squeeze`` is true and to
    ``_factorize_unique`` otherwise.
    """
    strategy = self._factorize_dummy if self.can_squeeze else self._factorize_unique
    return strategy()
407427
@@ -424,15 +444,12 @@ def _factorize_unique(self) -> T_FactorizeOut:
424444
425445 return codes , group_indices , unique_coord , full_index
426446
427- def _factorize_dummy (self , squeeze ) -> T_FactorizeOut :
447+ def _factorize_dummy (self ) -> T_FactorizeOut :
428448 size = self .group .size
429449 # no need to factorize
430- if not squeeze :
431- # use slices to do views instead of fancy indexing
432- # equivalent to: group_indices = group_indices.reshape(-1, 1)
433- group_indices : T_GroupIndices = [slice (i , i + 1 ) for i in range (size )]
434- else :
435- group_indices = list (range (size ))
450+ # use slices to do views instead of fancy indexing
451+ # equivalent to: group_indices = group_indices.reshape(-1, 1)
452+ group_indices : T_GroupIndices = [slice (i , i + 1 ) for i in range (size )]
436453 size_range = np .arange (size )
437454 if isinstance (self .group , _DummyGroup ):
438455 codes = self .group .to_dataarray ().copy (data = size_range )
@@ -448,7 +465,7 @@ def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
448465class ResolvedBinGrouper (ResolvedGrouper ):
449466 grouper : BinGrouper
450467
451- def _factorize (self , squeeze : bool ) -> T_FactorizeOut :
468+ def factorize (self ) -> T_FactorizeOut :
452469 from xarray .core .dataarray import DataArray
453470
454471 data = self .group1d .values
@@ -546,7 +563,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
546563 _apply_loffset (self .grouper .loffset , first_items )
547564 return first_items , codes
548565
549- def _factorize (self , squeeze : bool ) -> T_FactorizeOut :
566+ def factorize (self ) -> T_FactorizeOut :
550567 full_index , first_items , codes_ = self ._get_index_and_items ()
551568 sbins = first_items .values .astype (np .int64 )
552569 group_indices : T_GroupIndices = [
@@ -591,14 +608,14 @@ class TimeResampleGrouper(Grouper):
591608 loffset : datetime .timedelta | str | None
592609
593610
def _validate_groupby_squeeze(squeeze: bool | None) -> None:
    """Raise ``TypeError`` unless ``squeeze`` is ``None`` or a ``bool``."""
    # Argument types are not checked in general, but this one is worth it:
    # passing several dimensions as separate arguments is a common mistake,
    # and its consequences are well hidden because strings evaluate as true.
    # A future version could make ``squeeze`` keyword-only, but that would
    # face backward-compatibility issues.
    if squeeze is None or isinstance(squeeze, bool):
        return
    raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied")
603620
604621
@@ -730,7 +747,7 @@ def __init__(
730747 self ._original_obj = obj
731748
732749 for grouper_ in self .groupers :
733- grouper_ .factorize ( squeeze )
750+ grouper_ ._factorize ( )
734751
735752 (grouper ,) = self .groupers
736753 self ._original_group = grouper .group
@@ -762,9 +779,14 @@ def sizes(self) -> Mapping[Hashable, int]:
762779 Dataset.sizes
763780 """
764781 if self ._sizes is None :
765- self ._sizes = self ._obj .isel (
766- {self ._group_dim : self ._group_indices [0 ]}
767- ).sizes
782+ (grouper ,) = self .groupers
783+ index = _maybe_squeeze_indices (
784+ self ._group_indices [0 ],
785+ self ._squeeze ,
786+ grouper ,
787+ warn = True ,
788+ )
789+ self ._sizes = self ._obj .isel ({self ._group_dim : index }).sizes
768790
769791 return self ._sizes
770792
@@ -798,14 +820,22 @@ def groups(self) -> dict[GroupKey, GroupIndex]:
798820 # provided to mimic pandas.groupby
799821 if self ._groups is None :
800822 (grouper ,) = self .groupers
801- self ._groups = dict (zip (grouper .unique_coord .values , self ._group_indices ))
823+ squeezed_indices = (
824+ _maybe_squeeze_indices (ind , self ._squeeze , grouper , warn = idx > 0 )
825+ for idx , ind in enumerate (self ._group_indices )
826+ )
827+ self ._groups = dict (zip (grouper .unique_coord .values , squeezed_indices ))
802828 return self ._groups
803829
def __getitem__(self, key: GroupKey) -> T_Xarray:
    """
    Return the DataArray or Dataset selected by the group label ``key``.

    The label is looked up in ``self.groups`` and the resulting indexer is
    passed through ``_maybe_squeeze_indices`` (with warnings enabled) before
    selecting along the grouped dimension.
    """
    (sole_grouper,) = self.groupers
    selected = self.groups[key]
    position = _maybe_squeeze_indices(
        selected, self._squeeze, sole_grouper, warn=True
    )
    return self._obj.isel({self._group_dim: position})
809839
810840 def __len__ (self ) -> int :
811841 (grouper ,) = self .groupers
@@ -826,7 +856,11 @@ def __repr__(self) -> str:
826856
def _iter_grouped(self) -> Iterator[T_Xarray]:
    """Yield the selection for each group, in ``_group_indices`` order."""
    (sole_grouper,) = self.groupers
    for position, raw_indices in enumerate(self._group_indices):
        # Warnings are enabled for every group except the first one.
        squeezed = _maybe_squeeze_indices(
            raw_indices, self._squeeze, sole_grouper, warn=position > 0
        )
        yield self._obj.isel({self._group_dim: squeezed})
831865
832866 def _infer_concat_args (self , applied_example ):
@@ -1309,7 +1343,11 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
@property
def dims(self) -> tuple[Hashable, ...]:
    """Dimensions of the first group's selection, computed once and cached."""
    if self._dims is not None:
        return self._dims

    (sole_grouper,) = self.groupers
    first = _maybe_squeeze_indices(
        self._group_indices[0], self._squeeze, sole_grouper, warn=True
    )
    self._dims = self._obj.isel({self._group_dim: first}).dims
    return self._dims
13151353
@@ -1318,7 +1356,11 @@ def _iter_grouped_shortcut(self):
13181356 metadata
13191357 """
13201358 var = self ._obj .variable
1321- for indices in self ._group_indices :
1359+ (grouper ,) = self .groupers
1360+ for idx , indices in enumerate (self ._group_indices ):
1361+ indices = _maybe_squeeze_indices (
1362+ indices , self ._squeeze , grouper , warn = idx > 0
1363+ )
13221364 yield var [{self ._group_dim : indices }]
13231365
13241366 def _concat_shortcut (self , applied , dim , positions = None ):
@@ -1517,7 +1559,14 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
@property
def dims(self) -> Frozen[Hashable, int]:
    """Dimensions of the first group's selection, computed once and cached."""
    if self._dims is not None:
        return self._dims

    (sole_grouper,) = self.groupers
    first = _maybe_squeeze_indices(
        self._group_indices[0], self._squeeze, sole_grouper, warn=True
    )
    self._dims = self._obj.isel({self._group_dim: first}).dims
    return self._dims
15231572
0 commit comments