|
36 | 36 | except ImportError:
|
37 | 37 | pass
|
38 | 38 |
|
| 39 | +try: |
| 40 | + import cftime |
| 41 | +except ImportError: |
| 42 | + pass |
| 43 | + |
39 | 44 |
|
40 | 45 | @pytest.mark.flaky
|
41 | 46 | @pytest.mark.skip(reason="maybe flaky")
|
@@ -2145,3 +2150,52 @@ def test_yticks_kwarg(self, da):
|
2145 | 2150 | da.plot(yticks=np.arange(5))
|
2146 | 2151 | expected = np.arange(5)
|
2147 | 2152 | assert np.all(plt.gca().get_yticks() == expected)
|
| 2153 | + |
| 2154 | + |
| 2155 | +@requires_matplotlib |
| 2156 | +class TestDataArrayGroupByPlot: |
| 2157 | + @requires_cftime |
| 2158 | + @requires_nc_time_axis |
| 2159 | + @pytest.mark.parametrize( |
| 2160 | + "time", |
| 2161 | + [ |
| 2162 | + cftime.num2date( |
| 2163 | + np.arange(0, 730), "days since 0001-01-01 00:00:00", calendar="noleap" |
| 2164 | + ), |
| 2165 | + pd.date_range("2001-01-01", freq="12H", periods=730), |
| 2166 | + ], |
| 2167 | + ) |
| 2168 | + def test_time_grouping(self, time): |
| 2169 | + # create spatial coordinate |
| 2170 | + lev = np.arange(100) |
| 2171 | + |
| 2172 | + # Create sample Dataset |
| 2173 | + ds = Dataset( |
| 2174 | + { |
| 2175 | + "sample_data": (["time", "lev"], np.random.rand(time.size, lev.size)), |
| 2176 | + "independent_data": (["lev"], np.random.rand(lev.size)), |
| 2177 | + }, |
| 2178 | + coords={"time": (["time"], time), "lev": (["lev"], lev)}, |
| 2179 | + ) |
| 2180 | + |
| 2181 | + ds.sample_data.groupby("time.month").plot( |
| 2182 | + col="month", col_wrap=4, x="time", sharey=True |
| 2183 | + ) |
| 2184 | + |
| 2185 | + def test_stacked_groupby(self): |
| 2186 | + ds = Dataset( |
| 2187 | + { |
| 2188 | + "variable": ( |
| 2189 | + ("lat", "lon", "time"), |
| 2190 | + np.arange(60.0).reshape((4, 3, 5)), |
| 2191 | + ), |
| 2192 | + "id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))), |
| 2193 | + }, |
| 2194 | + coords={ |
| 2195 | + "lat": np.arange(4), |
| 2196 | + "lon": np.arange(3), |
| 2197 | + "time": pd.date_range(start="2001-01-01", freq="D", periods=5), |
| 2198 | + }, |
| 2199 | + ) |
| 2200 | + |
| 2201 | + ds.variable.groupby(ds.id).plot.line(col="id", col_wrap=2) |
0 commit comments