| 
170 | 170 | """  | 
171 | 171 | 
 
  | 
172 | 172 | import itertools  | 
 | 173 | +from collections import namedtuple  | 
173 | 174 | from collections.abc import Iterable  | 
174 | 175 | 
 
  | 
175 | 176 | import arviz as az  | 
@@ -243,39 +244,82 @@ def _build_subplot_title(  | 
243 | 244 |             return ", ".join(title_parts)  | 
244 | 245 |         return fallback_title  | 
245 | 246 | 
 
  | 
246 |  | -    def _align_y_axes(self, ax, ax2, include_zero=False):  | 
 | 247 | +    def _align_y_axes(self, ax_left, ax_right) -> None:  | 
247 | 248 |         """Align y=0 of primary and secondary y-axis."""  | 
248 |  | -        if ax.axes.get_ylim()[0] < 0 or ax2.axes.get_ylim()[0] < 0:  | 
249 |  | -            ylims1 = ax.axes.get_ylim()  | 
250 |  | -            ylims2 = ax2.axes.get_ylim()  | 
251 |  | -            # Find the ratio of negative vs. positive part of the axes.  | 
252 |  | -            if ylims1[1]:  | 
253 |  | -                ax1_yratio = ylims1[0] / ylims1[1]  | 
254 |  | -            else:  | 
255 |  | -                # Fully negative axis.  | 
256 |  | -                ax1_yratio = -1  | 
 | 249 | +        # Store limits of both axes in named tuples.  | 
 | 250 | +        YLimits = namedtuple("YLimits", ["bottom", "top"])  | 
 | 251 | +        ylims_left = YLimits(*ax_left.axes.get_ylim())  | 
 | 252 | +        ylims_right = YLimits(*ax_right.axes.get_ylim())  | 
 | 253 | + | 
 | 254 | +        # Calculate the relative position of zero on both axes.  | 
 | 255 | +        # 0 means all values are positive (zero is at the bottom)  | 
 | 256 | +        # 1 means all values are negative (zero is at the top)  | 
 | 257 | +        zero_rel_pos_left = -ylims_left.bottom / (ylims_left.top - ylims_left.bottom)  | 
 | 258 | +        zero_rel_pos_right = -ylims_right.bottom / (  | 
 | 259 | +            ylims_right.top - ylims_right.bottom  | 
 | 260 | +        )  | 
257 | 261 | 
 
  | 
258 |  | -            if ylims2[1]:  | 
259 |  | -                ax2_yratio = ylims2[0] / ylims2[1]  | 
 | 262 | +        # If relative positions are equal, no action needed  | 
 | 263 | +        if zero_rel_pos_left == zero_rel_pos_right:  | 
 | 264 | +            return  | 
 | 265 | + | 
 | 266 | +        # If both axes include mixed values, edit one to match the other by solving  | 
 | 267 | +        # rel_pos_other = (0 - new_bottom) / (top - new_bottom) for new_bottom.  | 
 | 268 | +        if zero_rel_pos_left not in [0, 1] and zero_rel_pos_right not in [0, 1]:  | 
 | 269 | +            if zero_rel_pos_left < zero_rel_pos_right:  | 
 | 270 | +                ax_left.set_ylim(  | 
 | 271 | +                    bottom=-zero_rel_pos_right  | 
 | 272 | +                    * ylims_left.top  | 
 | 273 | +                    / (1 - zero_rel_pos_right)  | 
 | 274 | +                )  | 
260 | 275 |             else:  | 
261 |  | -                # Fully negative axis, may need to reflect the other  | 
262 |  | -                ax2_yratio = -1  | 
263 |  | - | 
264 |  | -            # Make axis adjustments. If both axes fully negative, no adjustment.  | 
265 |  | -            if ax1_yratio < ax2_yratio:  | 
266 |  | -                ax2.set_ylim(bottom=ylims2[1] * ax1_yratio)  | 
267 |  | -                if ax1_yratio == -1:  | 
268 |  | -                    # if the axis is fully negative, center zero.  | 
269 |  | -                    ax.set_ylim(top=-ylims1[0])  | 
270 |  | -            elif ax2_yratio < ax1_yratio:  | 
271 |  | -                ax.set_ylim(bottom=ylims1[1] * ax2_yratio)  | 
272 |  | -                if ax2_yratio == -1:  | 
273 |  | -                    # if the axis is fully negative, center zero.  | 
274 |  | -                    ax2.set_ylim(top=-ylims2[0])  | 
275 |  | -        elif include_zero:  | 
276 |  | -            # Ensure both axes start at zero  | 
277 |  | -            ax.set_ylim(bottom=0)  | 
278 |  | -            ax2.set_ylim(bottom=0)  | 
 | 276 | +                ax_right.set_ylim(  | 
 | 277 | +                    bottom=-zero_rel_pos_left  | 
 | 278 | +                    * ylims_right.top  | 
 | 279 | +                    / (1 - zero_rel_pos_left)  | 
 | 280 | +                )  | 
 | 281 | + | 
 | 282 | +        # If one relative position is 1, edit the top by solving  | 
 | 283 | +        # rel_pos_other = (0 - bottom) / (new_top - bottom) for new_top.  | 
 | 284 | +        if zero_rel_pos_left == 1:  | 
 | 285 | +            # Left axis is all negative, right axis has positive values  | 
 | 286 | +            # if other axis is fully positive, place y=0 at the center of this axis.  | 
 | 287 | +            ax_left.set_ylim(  | 
 | 288 | +                top=ylims_left.bottom * (1 - 1 / zero_rel_pos_right)  | 
 | 289 | +                if zero_rel_pos_right  | 
 | 290 | +                else -ylims_left.bottom  | 
 | 291 | +            )  | 
 | 292 | +            # Update lims and zero_rel_pos in case we need to edit the other axis.  | 
 | 293 | +            ylims_left = YLimits(*ax_left.axes.get_ylim())  | 
 | 294 | +            zero_rel_pos_left = -ylims_left.bottom / (  | 
 | 295 | +                ylims_left.top - ylims_left.bottom  | 
 | 296 | +            )  | 
 | 297 | +        elif zero_rel_pos_right == 1:  | 
 | 298 | +            # Right axis is all negative, left axis has positive values  | 
 | 299 | +            # if other axis is fully positive, place y=0 at the center of this axis.  | 
 | 300 | +            ax_right.set_ylim(  | 
 | 301 | +                top=ylims_right.bottom * (1 - 1 / zero_rel_pos_left)  | 
 | 302 | +                if zero_rel_pos_left  | 
 | 303 | +                else -ylims_right.bottom  | 
 | 304 | +            )  | 
 | 305 | +            # Update lims and zero_rel_pos in case we need to edit the other axis.  | 
 | 306 | +            ylims_right = YLimits(*ax_right.axes.get_ylim())  | 
 | 307 | +            zero_rel_pos_right = -ylims_right.bottom / (  | 
 | 308 | +                ylims_right.top - ylims_right.bottom  | 
 | 309 | +            )  | 
 | 310 | + | 
 | 311 | +        # If one relative position is 0, edit bottom by solving  | 
 | 312 | +        # rel_pos_other = (0 - new_bottom) / (top - new_bottom) for new_bottom.  | 
 | 313 | +        if zero_rel_pos_left == 0:  | 
 | 314 | +            # Left axis is all positive, right axis has negative values  | 
 | 315 | +            ax_left.set_ylim(  | 
 | 316 | +                bottom=-zero_rel_pos_right * ylims_left.top / (1 - zero_rel_pos_right)  | 
 | 317 | +            )  | 
 | 318 | +        elif zero_rel_pos_right == 0:  | 
 | 319 | +            # Right axis is all positive, left axis has negative values  | 
 | 320 | +            ax_right.set_ylim(  | 
 | 321 | +                bottom=-zero_rel_pos_left * ylims_right.top / (1 - zero_rel_pos_left)  | 
 | 322 | +            )  | 
279 | 323 | 
 
  | 
280 | 324 |     def _get_additional_dim_combinations(  | 
281 | 325 |         self,  | 
@@ -1225,7 +1269,7 @@ def _plot_budget_allocation_bars(  | 
1225 | 1269 |         ax.tick_params(axis="x", rotation=90)  | 
1226 | 1270 | 
 
  | 
1227 | 1271 |         # Ensure that y=0 are aligned between ax and ax2.  | 
1228 |  | -        self._align_y_axes(ax, ax2, include_zero=True)  | 
 | 1272 | +        self._align_y_axes(ax, ax2)  | 
1229 | 1273 | 
 
  | 
1230 | 1274 |         # Turn off grid and add legend  | 
1231 | 1275 |         ax.grid(False)  | 
 | 
0 commit comments