Skip to content

DRAFT: remove all OpenGL<x> mobs and the ConvertToOpenGL metaclass. #2454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example_scenes/opengl.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,13 @@ def construct(self):
night_texture = script_location / "assets" / "1280px-The_earth_at_night.jpg"

surfaces = [
OpenGLTexturedSurface(surface, day_texture, night_texture)
TexturedSurface(surface, day_texture, night_texture)
for surface in [sphere, torus1, torus2]
]

for mob in surfaces:
mob.shift(IN)
mob.mesh = OpenGLSurfaceMesh(mob)
mob.mesh = SurfaceMesh(mob)
mob.mesh.set_stroke(BLUE, 1, opacity=0.5)

# Set perspective
Expand Down
1 change: 0 additions & 1 deletion manim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
from .mobject.three_dimensions import *
from .mobject.types.dot_cloud import *
from .mobject.types.image_mobject import *
from .mobject.types.opengl_point_cloud_mobject import *
from .mobject.types.point_cloud_mobject import *
from .mobject.types.vectorized_mobject import *
from .mobject.value_tracker import *
Expand Down
30 changes: 0 additions & 30 deletions manim/_config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,36 +1172,6 @@ def renderer(self):
@renderer.setter
def renderer(self, val: str) -> None:
"""Renderer for animations."""
try:
from ..mobject.mobject import Mobject
from ..mobject.opengl_compatibility import ConvertToOpenGL
from ..mobject.opengl_mobject import OpenGLMobject
from ..mobject.types.opengl_vectorized_mobject import OpenGLVMobject
from ..mobject.types.vectorized_mobject import VMobject

for cls in ConvertToOpenGL._converted_classes:
if val == "opengl":
conversion_dict = {
Mobject: OpenGLMobject,
VMobject: OpenGLVMobject,
}
else:
conversion_dict = {
OpenGLMobject: Mobject,
OpenGLVMobject: VMobject,
}

cls.__bases__ = tuple(
conversion_dict.get(base, base) for base in cls.__bases__
)
except ImportError:
# The renderer is set during the initial import of the
# library for the first time. The imports above cause an
# ImportError due to circular imports. However, the
# metaclass sets stuff up correctly in this case, so we
# can just do nothing.
pass

self._set_from_list(
"renderer",
val,
Expand Down
20 changes: 4 additions & 16 deletions manim/animation/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@


from .. import config, logger
from ..mobject import mobject, opengl_mobject
from ..mobject.mobject import Mobject
from ..mobject.opengl_mobject import OpenGLMobject
from ..utils.rate_functions import smooth

__all__ = ["Animation", "Wait", "override_animation"]
Expand Down Expand Up @@ -141,14 +139,8 @@ def __init__(
self.remover: bool = remover
self.suspend_mobject_updating: bool = suspend_mobject_updating
self.lag_ratio: float = lag_ratio
if config["renderer"] == "opengl":
self.starting_mobject: OpenGLMobject = OpenGLMobject()
self.mobject: OpenGLMobject = (
mobject if mobject is not None else OpenGLMobject()
)
else:
self.starting_mobject: Mobject = Mobject()
self.mobject: Mobject = mobject if mobject is not None else Mobject()
self.starting_mobject: Mobject = Mobject()
self.mobject: Mobject = mobject if mobject is not None else Mobject()
if kwargs:
logger.debug("Animation received extra kwargs: %s", kwargs)

Expand All @@ -163,7 +155,7 @@ def __init__(
def _typecheck_input(self, mobject: Union[Mobject, None]) -> None:
if mobject is None:
logger.debug("Animation with empty mobject")
elif not isinstance(mobject, (Mobject, OpenGLMobject)):
elif not isinstance(mobject, Mobject):
raise TypeError("Animation only works on Mobjects")

def __str__(self) -> str:
Expand Down Expand Up @@ -237,11 +229,7 @@ def get_all_mobjects(self) -> Sequence[Mobject]:
return self.mobject, self.starting_mobject

def get_all_families_zipped(self) -> Iterable[Tuple]:
if config["renderer"] == "opengl":
return zip(*(mob.get_family() for mob in self.get_all_mobjects()))
return zip(
*(mob.family_members_with_points() for mob in self.get_all_mobjects())
)
return zip(*(mob.get_family() for mob in self.get_all_mobjects()))

def update_mobjects(self, dt: float) -> None:
"""
Expand Down
9 changes: 2 additions & 7 deletions manim/animation/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from .._config import config
from ..animation.animation import Animation, prepare_animation
from ..mobject.mobject import Group, Mobject
from ..mobject.opengl_mobject import OpenGLGroup
from ..scene.scene import Scene
from ..utils.iterables import remove_list_redundancies
from ..utils.rate_functions import linear

if TYPE_CHECKING:
from ..mobject.types.opengl_vectorized_mobject import OpenGLVGroup
from ..mobject.types.vectorized_mobject import VGroup

__all__ = ["AnimationGroup", "Succession", "LaggedStart", "LaggedStartMap"]
Expand All @@ -27,7 +25,7 @@ class AnimationGroup(Animation):
def __init__(
self,
*animations: Animation,
group: Union[Group, "VGroup", OpenGLGroup, "OpenGLVGroup"] = None,
group: Union[Group, "VGroup"] = None,
run_time: Optional[float] = None,
rate_func: Callable[[float], float] = linear,
lag_ratio: float = 0,
Expand All @@ -39,10 +37,7 @@ def __init__(
mobjects = remove_list_redundancies(
[anim.mobject for anim in self.animations],
)
if config["renderer"] == "opengl":
self.group = OpenGLGroup(*mobjects)
else:
self.group = Group(*mobjects)
self.group = Group(*mobjects)
super().__init__(self.group, rate_func=rate_func, lag_ratio=lag_ratio, **kwargs)
self.run_time: float = self.init_run_time(run_time)

Expand Down
23 changes: 10 additions & 13 deletions manim/animation/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def construct(self):
from ..animation.animation import Animation
from ..animation.composition import Succession
from ..mobject.mobject import Group, Mobject
from ..mobject.types.opengl_surface import OpenGLSurface
from ..mobject.types.opengl_vectorized_mobject import OpenGLVMobject
from ..mobject.types.surface import Surface
from ..mobject.types.vectorized_mobject import VMobject
from ..utils.bezier import integer_interpolate
from ..utils.rate_functions import double_smooth, linear, smooth
Expand All @@ -112,9 +111,7 @@ class ShowPartial(Animation):

"""

def __init__(
self, mobject: Union[VMobject, OpenGLVMobject, OpenGLSurface, None], **kwargs
):
def __init__(self, mobject: Optional[Union[VMobject, Surface]], **kwargs):
pointwise = getattr(mobject, "pointwise_become_partial", None)
if not callable(pointwise):
raise NotImplementedError("This animation is not defined for this Mobject.")
Expand Down Expand Up @@ -163,7 +160,7 @@ def construct(self):

def __init__(
self,
mobject: Union[VMobject, OpenGLVMobject, OpenGLSurface],
mobject: Union[VMobject, Surface],
lag_ratio: float = 1.0,
**kwargs,
) -> None:
Expand Down Expand Up @@ -192,7 +189,7 @@ def construct(self):

def __init__(
self,
mobject: Union[VMobject, OpenGLVMobject],
mobject: VMobject,
rate_func: Callable[[float], float] = lambda t: smooth(1 - t),
remover: bool = True,
**kwargs,
Expand All @@ -214,7 +211,7 @@ def construct(self):

def __init__(
self,
vmobject: Union[VMobject, OpenGLVMobject],
vmobject: VMobject,
run_time: float = 2,
rate_func: Callable[[float], float] = double_smooth,
stroke_width: float = 2,
Expand All @@ -231,8 +228,8 @@ def __init__(
self.fill_animation_config = fill_animation_config
self.outline = self.get_outline()

def _typecheck_input(self, vmobject: Union[VMobject, OpenGLVMobject]) -> None:
if not isinstance(vmobject, (VMobject, OpenGLVMobject)):
def _typecheck_input(self, vmobject: VMobject) -> None:
if not isinstance(vmobject, VMobject):
raise TypeError("DrawBorderThenFill only works for vectorized Mobjects")

def begin(self) -> None:
Expand All @@ -246,7 +243,7 @@ def get_outline(self) -> Mobject:
sm.set_stroke(color=self.get_stroke_color(sm), width=self.stroke_width)
return outline

def get_stroke_color(self, vmobject: Union[VMobject, OpenGLVMobject]) -> Color:
def get_stroke_color(self, vmobject: VMobject) -> Color:
if self.stroke_color:
return self.stroke_color
elif vmobject.get_stroke_width() > 0:
Expand Down Expand Up @@ -293,7 +290,7 @@ def construct(self):

def __init__(
self,
vmobject: Union[VMobject, OpenGLVMobject],
vmobject: VMobject,
rate_func: Callable[[float], float] = linear,
reverse: bool = False,
**kwargs,
Expand All @@ -316,7 +313,7 @@ def __init__(

def _set_default_config_from_length(
self,
vmobject: Union[VMobject, OpenGLVMobject],
vmobject: VMobject,
run_time: Optional[float],
lag_ratio: Optional[float],
) -> Tuple[float, float]:
Expand Down
6 changes: 2 additions & 4 deletions manim/animation/fading.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ def construct(self):

import numpy as np

from manim.mobject.opengl_mobject import OpenGLMobject

from ..animation.transform import Transform
from ..constants import DOWN, ORIGIN
from ..constants import ORIGIN
from ..mobject.mobject import Group, Mobject
from ..scene.scene import Scene

Expand Down Expand Up @@ -66,7 +64,7 @@ def __init__(
self.point_target = False
if shift is None:
if target_position is not None:
if isinstance(target_position, (Mobject, OpenGLMobject)):
if isinstance(target_position, Mobject):
target_position = target_position.get_center()
shift = target_position - mobject.get_center()
self.point_target = True
Expand Down
19 changes: 5 additions & 14 deletions manim/animation/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from ..animation.animation import Animation
from ..constants import DEFAULT_POINTWISE_FUNCTION_RUN_TIME, DEGREES, ORIGIN, OUT
from ..mobject.mobject import Group, Mobject
from ..mobject.opengl_mobject import OpenGLGroup, OpenGLMobject
from ..utils.paths import path_along_arc, path_along_circles
from ..utils.rate_functions import smooth, squish_rate_func

Expand Down Expand Up @@ -114,10 +113,7 @@ def begin(self) -> None:
self.target_copy = self.target_mobject.copy()
# Note, this potentially changes the structure
# of both mobject and target_mobject
if config["renderer"] == "opengl":
self.mobject.align_data_and_family(self.target_copy)
else:
self.mobject.align_data(self.target_copy)
self.mobject.align_data_and_family(self.target_copy)
super().begin()

def create_target(self) -> Mobject:
Expand Down Expand Up @@ -145,9 +141,7 @@ def get_all_families_zipped(self) -> Iterable[tuple]: # more precise typing?
self.starting_mobject,
self.target_copy,
]
if config["renderer"] == "opengl":
return zip(*(mob.get_family() for mob in mobs))
return zip(*(mob.family_members_with_points() for mob in mobs))
return zip(*(mob.get_family() for mob in mobs))

def interpolate_submobject(
self,
Expand Down Expand Up @@ -299,7 +293,7 @@ def check_validity_of_input(self, method: Callable) -> None:
"Whoops, looks like you accidentally invoked "
"the method you want to animate",
)
assert isinstance(method.__self__, (Mobject, OpenGLMobject))
assert isinstance(method.__self__, Mobject)

def create_target(self) -> Mobject:
method = self.method
Expand Down Expand Up @@ -383,7 +377,7 @@ def __init__(self, function: types.MethodType, mobject: Mobject, **kwargs) -> No

def create_target(self) -> Any:
target = self.function(self.mobject.copy())
if not isinstance(target, (Mobject, OpenGLMobject)):
if not isinstance(target, Mobject):
raise TypeError(
"Functions passed to ApplyFunction must return object of type Mobject",
)
Expand Down Expand Up @@ -554,10 +548,7 @@ def __init__(self, mobject, target_mobject, stretch=True, dim_to_match=1, **kwar
self.stretch = stretch
self.dim_to_match = dim_to_match
mobject.save_state()
if config["renderer"] == "opengl":
group = OpenGLGroup(mobject, target_mobject.copy())
else:
group = Group(mobject, target_mobject.copy())
group = Group(mobject, target_mobject.copy())
super().__init__(group, **kwargs)

def begin(self):
Expand Down
13 changes: 2 additions & 11 deletions manim/animation/transform_matching_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from .._config import config
from ..mobject.mobject import Group, Mobject
from ..mobject.opengl_mobject import OpenGLGroup, OpenGLMobject
from ..mobject.types.opengl_vectorized_mobject import OpenGLVGroup, OpenGLVMobject
from ..mobject.types.vectorized_mobject import VGroup, VMobject
from .composition import AnimationGroup
from .fading import FadeIn, FadeOut
Expand Down Expand Up @@ -73,11 +71,7 @@ def __init__(
**kwargs
):

if isinstance(mobject, OpenGLVMobject):
group_type = OpenGLVGroup
elif isinstance(mobject, OpenGLMobject):
group_type = OpenGLGroup
elif isinstance(mobject, VMobject):
if isinstance(mobject, VMobject):
group_type = VGroup
else:
group_type = Group
Expand Down Expand Up @@ -142,10 +136,7 @@ def get_shape_map(self, mobject: "Mobject") -> dict:
for sm in self.get_mobject_parts(mobject):
key = self.get_mobject_key(sm)
if key not in shape_map:
if config["renderer"] == "opengl":
shape_map[key] = OpenGLVGroup()
else:
shape_map[key] = VGroup()
shape_map[key] = VGroup()
shape_map[key].add(sm)
return shape_map

Expand Down
2 changes: 1 addition & 1 deletion manim/camera/three_d_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from .. import config
from ..camera.camera import Camera
from ..constants import *
from ..mobject.mobject import Point
from ..mobject.three_d_utils import (
get_3d_vmob_end_corner,
get_3d_vmob_end_corner_unit_normal,
get_3d_vmob_start_corner,
get_3d_vmob_start_corner_unit_normal,
)
from ..mobject.types.point_cloud_mobject import Point
from ..mobject.value_tracker import ValueTracker
from ..utils.color import get_shaded_rgb
from ..utils.family import extract_mobject_family_members
Expand Down
Loading