Skip to content

Optimized interpolate and bezier in manim.utils.bezier #3960

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 153 additions & 39 deletions manim/utils/bezier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import numpy as np

from manim.typing import PointDType
from manim.utils.simple_functions import choose

if TYPE_CHECKING:
Expand All @@ -35,6 +34,7 @@
from manim.typing import (
BezierPoints,
BezierPoints_Array,
ColVector,
MatrixMN,
Point3D,
Point3D_Array,
Expand All @@ -45,47 +45,127 @@
# ruff: noqa: E741


@overload
def bezier(
points: Sequence[Point3D] | Point3D_Array,
) -> Callable[[float], Point3D]:
"""Classic implementation of a bezier curve.
points: BezierPoints,
) -> Callable[[float | ColVector], Point3D | Point3D_Array]: ...


@overload
def bezier(
points: Sequence[Point3D_Array],
) -> Callable[[float | ColVector], Point3D_Array]: ...


def bezier(points):
"""Classic implementation of a Bézier curve.

Parameters
----------
points
points defining the desired bezier curve.
:math:`(d+1, 3)`-shaped array of :math:`d+1` control points defining a single Bézier
curve of degree :math:`d`. Alternatively, for vectorization purposes, ``points`` can
also be a :math:`(d+1, M, 3)`-shaped sequence of :math:`d+1` arrays of :math:`M`
control points each, which define `M` Bézier curves instead.

Returns
-------
function describing the bezier curve.
You can pass a t value between 0 and 1 to get the corresponding point on the curve.
bezier_func : :class:`typing.Callable` [[:class:`float` | :class:`~.ColVector`], :class:`~.Point3D` | :class:`~.Point3D_Array`]
Function describing the Bézier curve. The behaviour of this function depends on
the shape of ``points``:

* If ``points`` was a :math:`(d+1, 3)` array representing a single Bézier curve,
then ``bezier_func`` can receive either:

* a :class:`float` ``t``, in which case it returns a
single :math:`(1, 3)`-shaped :class:`~.Point3D` representing the evaluation
of the Bézier at ``t``, or
* an :math:`(n, 1)`-shaped :class:`~.ColVector`
containing :math:`n` values to evaluate the Bézier curve at, returning instead
an :math:`(n, 3)`-shaped :class:`~.Point3D_Array` containing the points
resulting from evaluating the Bézier at each of the :math:`n` values.
.. warning::
If passing a vector of :math:`t`-values to ``bezier_func``, it **must**
be a column vector/matrix of shape :math:`(n, 1)`. Passing an 1D array of
shape :math:`(n,)` is not supported and **will result in undefined behaviour**.

* If ``points`` was a :math:`(d+1, M, 3)` array describing :math:`M` Bézier curves,
then ``bezier_func`` can receive either:

* a :class:`float` ``t``, in which case it returns an
:math:`(M, 3)`-shaped :class:`~.Point3D_Array` representing the evaluation
of the :math:`M` Bézier curves at the same value ``t``, or
* an :math:`(M, 1)`-shaped
:class:`~.ColVector` containing :math:`M` values, such that the :math:`i`-th
Bézier curve defined by ``points`` is evaluated at the corresponding :math:`i`-th
value in ``t``, returning again an :math:`(M, 3)`-shaped :class:`~.Point3D_Array`
containing those :math:`M` evaluations.
.. warning::
Unlike the previous case, if you pass a :class:`~.ColVector` to ``bezier_func``,
it **must** contain exactly :math:`M` values, each value for each of the :math:`M`
Bézier curves defined by ``points``. Any array of shape other than :math:`(M, 1)`
**will result in undefined behaviour**.
"""
n = len(points) - 1
# Cubic Bezier curve
if n == 3:
return lambda t: np.asarray(
(1 - t) ** 3 * points[0]
+ 3 * t * (1 - t) ** 2 * points[1]
+ 3 * (1 - t) * t**2 * points[2]
+ t**3 * points[3],
dtype=PointDType,
)
# Quadratic Bezier curve
if n == 2:
return lambda t: np.asarray(
(1 - t) ** 2 * points[0] + 2 * t * (1 - t) * points[1] + t**2 * points[2],
dtype=PointDType,
)
P = np.asarray(points)
degree = P.shape[0] - 1

return lambda t: np.asarray(
np.asarray(
[
(((1 - t) ** (n - k)) * (t**k) * choose(n, k) * point)
for k, point in enumerate(points)
],
dtype=PointDType,
).sum(axis=0)
)
if degree == 0:

def zero_bezier(t):
return np.ones_like(t) * P[0]

return zero_bezier

if degree == 1:

def linear_bezier(t):
return P[0] + t * (P[1] - P[0])

return linear_bezier

if degree == 2:

def quadratic_bezier(t):
t2 = t * t
mt = 1 - t
mt2 = mt * mt
return mt2 * P[0] + 2 * t * mt * P[1] + t2 * P[2]

return quadratic_bezier

if degree == 3:

def cubic_bezier(t):
t2 = t * t
t3 = t2 * t
mt = 1 - t
mt2 = mt * mt
mt3 = mt2 * mt
return mt3 * P[0] + 3 * t * mt2 * P[1] + 3 * t2 * mt * P[2] + t3 * P[3]

return cubic_bezier

def nth_grade_bezier(t):
is_scalar = not isinstance(t, np.ndarray)
if is_scalar:
B = np.empty((1, *P.shape))
else:
t = t.reshape(-1, *[1 for dim in P.shape])
B = np.empty((t.shape[0], *P.shape))
B[:] = P

for i in range(degree):
# After the i-th iteration (i in [0, ..., d-1]) there are evaluations at t
# of (d-i) Bezier curves of grade (i+1), stored in the first d-i slots of B
B[:, : degree - i] += t * (B[:, 1 : degree - i + 1] - B[:, : degree - i])

# In the end, there shall be the evaluation at t of a single Bezier curve of
# grade d, stored in the first slot of B
if is_scalar:
return B[0, 0]
return B[:, 0]

return nth_grade_bezier


def partial_bezier_points(points: BezierPoints, a: float, b: float) -> BezierPoints:
Expand Down Expand Up @@ -874,9 +954,10 @@ def bezier_remap(
An array of multiple Bézier curves of degree :math:`d` to be remapped. The shape of this array
must be ``(current_number_of_curves, nppc, dim)``, where:

* ``current_number_of_curves`` is the current amount of curves in the array ``bezier_tuples``,
* ``nppc`` is the amount of points per curve, such that their degree is ``nppc-1``, and
* ``dim`` is the dimension of the points, usually :math:`3`.
* ``current_number_of_curves`` is the current amount of curves in the array ``bezier_tuples``,
* ``nppc`` is the amount of points per curve, such that their degree is ``nppc-1``, and
* ``dim`` is the dimension of the points, usually :math:`3`.

new_number_of_curves
The number of curves that the output will contain. This needs to be higher than the current number.

Expand Down Expand Up @@ -926,14 +1007,47 @@ def bezier_remap(
def interpolate(start: float, end: float, alpha: float) -> float: ...


@overload
def interpolate(start: float, end: float, alpha: ColVector) -> ColVector: ...


@overload
def interpolate(start: Point3D, end: Point3D, alpha: float) -> Point3D: ...


def interpolate(
start: int | float | Point3D, end: int | float | Point3D, alpha: float | Point3D
) -> float | Point3D:
return (1 - alpha) * start + alpha * end
@overload
def interpolate(start: Point3D, end: Point3D, alpha: ColVector) -> Point3D_Array: ...


def interpolate(start, end, alpha):
"""Linearly interpolates between two values ``start`` and ``end``.

Parameters
----------
start
The start of the range.
end
The end of the range.
alpha
A float between 0 and 1, or an :math:`(n, 1)` column vector containing
:math:`n` floats between 0 and 1 to interpolate in a vectorized fashion.

Returns
-------
:class:`float` | :class:`~.ColVector` | :class:`~.Point3D` | :class:`~.Point3D_Array`
The result of the linear interpolation.

* If ``start`` and ``end`` are of type :class:`float`, and:

* ``alpha`` is also a :class:`float`, the return is simply another :class:`float`.
* ``alpha`` is a :class:`~.ColVector`, the return is another :class:`~.ColVector`.

* If ``start`` and ``end`` are of type :class:`~.Point3D`, and:

* ``alpha`` is a :class:`float`, the return is another :class:`~.Point3D`.
* ``alpha`` is a :class:`~.ColVector`, the return is a :class:`~.Point3D_Array`.
"""
return start + alpha * (end - start)


def integer_interpolate(
Expand Down
Loading