Skip to content

Commit bd78c43

Browse files
committed
Make the svg groups available
1 parent 1a9212b commit bd78c43

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

manim/mobject/svg/svg_mobject.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __init__(
127127
self.stroke_color = stroke_color
128128
self.stroke_opacity = stroke_opacity # type: ignore[assignment]
129129
self.stroke_width = stroke_width # type: ignore[assignment]
130+
self.id_to_vgroup_dict: dict[str, VGroup] = {}
130131
if self.stroke_width is None:
131132
self.stroke_width = 0
132133

@@ -203,8 +204,9 @@ def generate_mobject(self) -> None:
203204
svg = se.SVG.parse(modified_file_path)
204205
modified_file_path.unlink()
205206

206-
mobjects = self.get_mobjects_from(svg)
207+
mobjects, mobject_dict = self.get_mobjects_from(svg)
207208
self.add(*mobjects)
209+
self.id_to_vgroup_dict = mobject_dict
208210
self.flip(RIGHT) # Flip y
209211

210212
def get_file_path(self) -> Path:
@@ -258,7 +260,9 @@ def generate_config_style_dict(self) -> dict[str, str]:
258260
result[svg_key] = str(svg_default_dict[style_key])
259261
return result
260262

261-
def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
263+
def get_mobjects_from(
264+
self, svg: se.SVG
265+
) -> tuple[list[VMobject], dict[str, VGroup]]:
262266
"""Convert the elements of the SVG to a list of mobjects.
263267
264268
Parameters
@@ -282,11 +286,10 @@ def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
282286
except Exception:
283287
group_name = f"numbered_group_{group_id_number}"
284288
group_id_number += 1
285-
if isinstance(element, se.Group):
286-
vg = VGroup()
287-
vgroup_names.append(group_name)
288-
vgroup_stack.append(group_name)
289-
vgroups[group_name] = vg
289+
vg = VGroup()
290+
vgroup_names.append(group_name)
291+
vgroup_stack.append(group_name)
292+
vgroups[group_name] = vg
290293

291294
if isinstance(element, (se.Group, se.Use)):
292295
for subelement in element[::-1]:
@@ -310,11 +313,12 @@ def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
310313
if mob is not None:
311314
result.append(mob)
312315
parent_name = vgroup_stack[-2]
313-
vgroups[parent_name].add(vgroups[group_name])
316+
for parent_name in vgroup_stack[:-2]:
317+
vgroups[parent_name].add(mob)
314318
except Exception as e:
315319
print(e)
316320

317-
return result
321+
return result, vgroups
318322

319323
def get_mob_from_shape_element(self, shape: se.SVGElement) -> VMobject | None:
320324
if isinstance(shape, se.Path):

0 commit comments

Comments
 (0)