@@ -1653,6 +1653,74 @@ def test_set_dims_object_dtype(self):
1653
1653
expected = Variable (["x" ], exp_values )
1654
1654
assert_identical (actual , expected )
1655
1655
1656
+ def test_set_dims_without_broadcast (self ):
1657
+ class ArrayWithoutBroadcastTo (NDArrayMixin , indexing .ExplicitlyIndexed ):
1658
+ def __init__ (self , array ):
1659
+ self .array = array
1660
+
1661
+ # Broadcasting with __getitem__ is "easier" to implement
1662
+ # especially for dims of 1
1663
+ def __getitem__ (self , key ):
1664
+ return self .array [key ]
1665
+
1666
+ def __array_function__ (self , * args , ** kwargs ):
1667
+ raise NotImplementedError (
1668
+ "Not we don't want to use broadcast_to here "
1669
+ "https://github.com/pydata/xarray/issues/9462"
1670
+ )
1671
+
1672
+ arr = ArrayWithoutBroadcastTo (np .zeros ((3 , 4 )))
1673
+ # We should be able to add a new axis without broadcasting
1674
+ assert arr [np .newaxis , :, :].shape == (1 , 3 , 4 )
1675
+ with pytest .raises (NotImplementedError ):
1676
+ np .broadcast_to (arr , (1 , 3 , 4 ))
1677
+
1678
+ v = Variable (["x" , "y" ], arr )
1679
+ v_expanded = v .set_dims (["z" , "x" , "y" ])
1680
+ assert v_expanded .dims == ("z" , "x" , "y" )
1681
+ assert v_expanded .shape == (1 , 3 , 4 )
1682
+
1683
+ v_expanded = v .set_dims (["x" , "z" , "y" ])
1684
+ assert v_expanded .dims == ("x" , "z" , "y" )
1685
+ assert v_expanded .shape == (3 , 1 , 4 )
1686
+
1687
+ v_expanded = v .set_dims (["x" , "y" , "z" ])
1688
+ assert v_expanded .dims == ("x" , "y" , "z" )
1689
+ assert v_expanded .shape == (3 , 4 , 1 )
1690
+
1691
+ # Explicitly asking for a shape of 1 triggers a different
1692
+ # codepath in set_dims
1693
+ # https://github.com/pydata/xarray/issues/9462
1694
+ v_expanded = v .set_dims (["z" , "x" , "y" ], shape = (1 , 3 , 4 ))
1695
+ assert v_expanded .dims == ("z" , "x" , "y" )
1696
+ assert v_expanded .shape == (1 , 3 , 4 )
1697
+
1698
+ v_expanded = v .set_dims (["x" , "z" , "y" ], shape = (3 , 1 , 4 ))
1699
+ assert v_expanded .dims == ("x" , "z" , "y" )
1700
+ assert v_expanded .shape == (3 , 1 , 4 )
1701
+
1702
+ v_expanded = v .set_dims (["x" , "y" , "z" ], shape = (3 , 4 , 1 ))
1703
+ assert v_expanded .dims == ("x" , "y" , "z" )
1704
+ assert v_expanded .shape == (3 , 4 , 1 )
1705
+
1706
+ v_expanded = v .set_dims ({"z" : 1 , "x" : 3 , "y" : 4 })
1707
+ assert v_expanded .dims == ("z" , "x" , "y" )
1708
+ assert v_expanded .shape == (1 , 3 , 4 )
1709
+
1710
+ v_expanded = v .set_dims ({"x" : 3 , "z" : 1 , "y" : 4 })
1711
+ assert v_expanded .dims == ("x" , "z" , "y" )
1712
+ assert v_expanded .shape == (3 , 1 , 4 )
1713
+
1714
+ v_expanded = v .set_dims ({"x" : 3 , "y" : 4 , "z" : 1 })
1715
+ assert v_expanded .dims == ("x" , "y" , "z" )
1716
+ assert v_expanded .shape == (3 , 4 , 1 )
1717
+
1718
+ with pytest .raises (NotImplementedError ):
1719
+ v .set_dims ({"z" : 2 , "x" : 3 , "y" : 4 })
1720
+
1721
+ with pytest .raises (NotImplementedError ):
1722
+ v .set_dims (["z" , "x" , "y" ], shape = (2 , 3 , 4 ))
1723
+
1656
1724
def test_stack (self ):
1657
1725
v = Variable (["x" , "y" ], [[0 , 1 ], [2 , 3 ]], {"foo" : "bar" })
1658
1726
actual = v .stack (z = ("x" , "y" ))
0 commit comments