Skip to content

Commit 772756f

Browse files
Refactor, and fix function.
1 parent 112851a commit 772756f

File tree

1 file changed

+75
-31
lines changed

1 file changed

+75
-31
lines changed

pymc_marketing/mmm/plot.py

Lines changed: 75 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@
170170
"""
171171

172172
import itertools
173+
from collections import namedtuple
173174
from collections.abc import Iterable
174175

175176
import arviz as az
@@ -243,39 +244,82 @@ def _build_subplot_title(
243244
return ", ".join(title_parts)
244245
return fallback_title
245246

246-
def _align_y_axes(self, ax, ax2, include_zero=False):
247+
def _align_y_axes(self, ax_left, ax_right) -> None:
247248
"""Align y=0 of primary and secondary y-axis."""
248-
if ax.axes.get_ylim()[0] < 0 or ax2.axes.get_ylim()[0] < 0:
249-
ylims1 = ax.axes.get_ylim()
250-
ylims2 = ax2.axes.get_ylim()
251-
# Find the ratio of negative vs. positive part of the axes.
252-
if ylims1[1]:
253-
ax1_yratio = ylims1[0] / ylims1[1]
254-
else:
255-
# Fully negative axis.
256-
ax1_yratio = -1
249+
# Store limits of both axes in named tuples.
250+
YLimits = namedtuple("YLimits", ["bottom", "top"])
251+
ylims_left = YLimits(*ax_left.axes.get_ylim())
252+
ylims_right = YLimits(*ax_right.axes.get_ylim())
253+
254+
# Calculate the relative position of zero on both axes.
255+
# 0 means all values are positive (zero is at the bottom)
256+
# 1 means all values are negative (zero is at the top)
257+
zero_rel_pos_left = -ylims_left.bottom / (ylims_left.top - ylims_left.bottom)
258+
zero_rel_pos_right = -ylims_right.bottom / (
259+
ylims_right.top - ylims_right.bottom
260+
)
257261

258-
if ylims2[1]:
259-
ax2_yratio = ylims2[0] / ylims2[1]
262+
# If relative positions are equal, no action needed
263+
if zero_rel_pos_left == zero_rel_pos_right:
264+
return
265+
266+
# If both axes include mixed values, edit one to match the other by solving
267+
# rel_pos_other = (0 - new_bottom) / (top - new_bottom) for new_bottom.
268+
if zero_rel_pos_left not in [0, 1] and zero_rel_pos_right not in [0, 1]:
269+
if zero_rel_pos_left < zero_rel_pos_right:
270+
ax_left.set_ylim(
271+
bottom=-zero_rel_pos_right
272+
* ylims_left.top
273+
/ (1 - zero_rel_pos_right)
274+
)
260275
else:
261-
# Fully negative axis, may need to reflect the other
262-
ax2_yratio = -1
263-
264-
# Make axis adjustments. If both axes fully negative, no adjustment.
265-
if ax1_yratio < ax2_yratio:
266-
ax2.set_ylim(bottom=ylims2[1] * ax1_yratio)
267-
if ax1_yratio == -1:
268-
# if the axis is fully negative, center zero.
269-
ax.set_ylim(top=-ylims1[0])
270-
elif ax2_yratio < ax1_yratio:
271-
ax.set_ylim(bottom=ylims1[1] * ax2_yratio)
272-
if ax2_yratio == -1:
273-
# if the axis is fully negative, center zero.
274-
ax2.set_ylim(top=-ylims2[0])
275-
elif include_zero:
276-
# Ensure both axes start at zero
277-
ax.set_ylim(bottom=0)
278-
ax2.set_ylim(bottom=0)
276+
ax_right.set_ylim(
277+
bottom=-zero_rel_pos_left
278+
* ylims_right.top
279+
/ (1 - zero_rel_pos_left)
280+
)
281+
282+
# If one relative position is 1, edit the top by solving
283+
# rel_pos_other = (0 - bottom) / (new_top - bottom) for new_top.
284+
if zero_rel_pos_left == 1:
285+
# Left axis is all negative, right axis has positive values
286+
# if other axis is fully positive, place y=0 at the center of this axis.
287+
ax_left.set_ylim(
288+
top=ylims_left.bottom * (1 - 1 / zero_rel_pos_right)
289+
if zero_rel_pos_right
290+
else -ylims_left.bottom
291+
)
292+
# Update lims and zero_rel_pos in case we need to edit the other axis.
293+
ylims_left = YLimits(*ax_left.axes.get_ylim())
294+
zero_rel_pos_left = -ylims_left.bottom / (
295+
ylims_left.top - ylims_left.bottom
296+
)
297+
elif zero_rel_pos_right == 1:
298+
# Right axis is all negative, left axis has positive values
299+
# if other axis is fully positive, place y=0 at the center of this axis.
300+
ax_right.set_ylim(
301+
top=ylims_right.bottom * (1 - 1 / zero_rel_pos_left)
302+
if zero_rel_pos_left
303+
else -ylims_right.bottom
304+
)
305+
# Update lims and zero_rel_pos in case we need to edit the other axis.
306+
ylims_right = YLimits(*ax_right.axes.get_ylim())
307+
zero_rel_pos_right = -ylims_right.bottom / (
308+
ylims_right.top - ylims_right.bottom
309+
)
310+
311+
# If one relative position is 0, edit bottom by solving
312+
# rel_pos_other = (0 - new_bottom) / (top - new_bottom) for new_bottom.
313+
if zero_rel_pos_left == 0:
314+
# Left axis is all positive, right axis has negative values
315+
ax_left.set_ylim(
316+
bottom=-zero_rel_pos_right * ylims_left.top / (1 - zero_rel_pos_right)
317+
)
318+
elif zero_rel_pos_right == 0:
319+
# Right axis is all positive, left axis has negative values
320+
ax_right.set_ylim(
321+
bottom=-zero_rel_pos_left * ylims_right.top / (1 - zero_rel_pos_left)
322+
)
279323

280324
def _get_additional_dim_combinations(
281325
self,
@@ -1225,7 +1269,7 @@ def _plot_budget_allocation_bars(
12251269
ax.tick_params(axis="x", rotation=90)
12261270

12271271
# Ensure that y=0 are aligned between ax and ax2.
1228-
self._align_y_axes(ax, ax2, include_zero=True)
1272+
self._align_y_axes(ax, ax2)
12291273

12301274
# Turn off grid and add legend
12311275
ax.grid(False)

0 commit comments

Comments
 (0)