Skip to content

Commit 476ac4e

Browse files
leleogereLéo Géréhydrobeampre-commit-ci[bot]Léo Géré
authored
Fix :meth:~.input_to_graph_point when passing a line_graph (#1994)
* Fix the bounds of the binary search * Fix the binary_search when the point in on a bound (return the proportion instead of the function value) * Add an error when x is not in the graph range * Allow to pass a VMobject to input_to_graph_point (type of a line_graph?) * Modify the test to reflect fixed bug * Update according to pre-commit * Add the exception in the docstring * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Convert a string into f-string Co-authored-by: Léo Géré <leo.gere@inrae.fr> Co-authored-by: Laith Bahodi <70682032+hydrobeam@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Léo Géré <gere@etud.insa-toulouse.fr>
1 parent 58092a7 commit 476ac4e

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

manim/mobject/coordinate_systems.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,13 @@
3333
)
3434
from ..mobject.number_line import NumberLine
3535
from ..mobject.svg.tex_mobject import MathTex
36-
from ..mobject.types.vectorized_mobject import Mobject, VDict, VectorizedPoint, VGroup
36+
from ..mobject.types.vectorized_mobject import (
37+
Mobject,
38+
VDict,
39+
VectorizedPoint,
40+
VGroup,
41+
VMobject,
42+
)
3743
from ..utils.color import (
3844
BLACK,
3945
BLUE,
@@ -699,7 +705,11 @@ def get_parametric_curve(self, function, **kwargs):
699705
graph.underlying_function = function
700706
return graph
701707

702-
def input_to_graph_point(self, x: float, graph: "ParametricFunction") -> np.ndarray:
708+
def input_to_graph_point(
709+
self,
710+
x: float,
711+
graph: Union["ParametricFunction", VMobject],
712+
) -> np.ndarray:
703713
"""Returns the coordinates of the point on a ``graph`` corresponding to an ``x`` value.
704714
705715
Examples
@@ -730,6 +740,11 @@ def construct(self):
730740
-------
731741
:class:`np.ndarray`
732742
The coordinates of the point on the :attr:`graph` corresponding to the :attr:`x` value.
743+
744+
Raises
745+
------
746+
:exc:`ValueError`
747+
When the target x is not in the range of the line graph.
733748
"""
734749

735750
if hasattr(graph, "underlying_function"):
@@ -740,13 +755,15 @@ def construct(self):
740755
0
741756
],
742757
target=x,
743-
lower_bound=self.x_range[0],
744-
upper_bound=self.x_range[1],
758+
lower_bound=0,
759+
upper_bound=1,
745760
)
746761
if alpha is not None:
747762
return graph.point_from_proportion(alpha)
748763
else:
749-
return None
764+
raise ValueError(
765+
f"x={x} not located in the range of the graph ([{self.p2c(graph.get_start())[0]}, {self.p2c(graph.get_end())[0]}])",
766+
)
750767

751768
def i2gp(self, x: float, graph: "ParametricFunction") -> np.ndarray:
752769
"""

manim/utils/simple_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def binary_search(function, target, lower_bound, upper_bound, tolerance=1e-4):
9595
mh = np.mean([lh, rh])
9696
lx, mx, rx = (function(h) for h in (lh, mh, rh))
9797
if lx == target:
98-
return lx
98+
return lh
9999
if rx == target:
100-
return rx
100+
return rh
101101

102102
if lx <= target and rx >= target:
103103
if mx > target:

tests/test_coordinate_system.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,14 @@ def test_coords_to_point():
107107
def test_input_to_graph_point():
108108
ax = Axes()
109109
curve = ax.get_graph(lambda x: np.cos(x))
110+
line_graph = ax.get_line_graph([1, 3, 5], [-1, 2, -2], add_vertex_dots=False)[
111+
"line_graph"
112+
]
110113

111114
# move a square to PI on the cosine curve.
112115
position = np.around(ax.input_to_graph_point(x=PI, graph=curve), decimals=4)
113116
assert np.array_equal(position, (2.6928, -0.75, 0))
117+
118+
# test the line_graph implementation
119+
position = np.around(ax.input_to_graph_point(x=PI, graph=line_graph), decimals=4)
120+
assert np.array_equal(position, (2.6928, 1.2876, 0))

0 commit comments

Comments
 (0)