multiprocessing metatensor efficiency? #6472

Open
@wyli

Description

Describe the bug

Follow-up of #6468; may require benchmarking.

```python
import contextlib

with contextlib.suppress(BaseException):
    from multiprocessing.reduction import ForkingPickler

    def _rebuild_meta(cls, storage, dtype, metadata):
        ...
```
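For context, `ForkingPickler` is how a library teaches multiprocessing to serialize its custom objects: you register a reduce function that takes the object apart, and a rebuild function that the receiving process calls to put it back together. A minimal sketch of the mechanism, using a hypothetical `Meta` container (this is not MONAI's actual MetaTensor reducer):

```python
import pickle
from multiprocessing.reduction import ForkingPickler

class Meta:
    """Hypothetical tensor-like container carrying a metadata dict."""
    def __init__(self, data, meta=None):
        self.data = data
        self.meta = meta or {}

def _rebuild_meta(cls, data, meta):
    # Runs on the receiving side to reconstruct the object.
    obj = cls(data)
    obj.meta = meta
    return obj

def _reduce_meta(obj):
    # Tells ForkingPickler how to take the object apart:
    # (rebuild function, args passed to it on the other side).
    return _rebuild_meta, (type(obj), obj.data, obj.meta)

ForkingPickler.register(Meta, _reduce_meta)

m = Meta([1, 2, 3], {"affine": "identity"})
payload = ForkingPickler.dumps(m)     # what multiprocessing would send
restored = pickle.loads(payload)      # what the other process would rebuild
assert restored.data == [1, 2, 3]
assert restored.meta == {"affine": "identity"}
```

Every element sent through a `ListProxy` or queue goes through this machinery, so a heavyweight reduce path multiplies across long lists of objects.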

The map_classes_to_indices utility (and the ClassesToIndices transform)
currently returns a list of coordinates for each class, but the type is
a list of MetaTensors, where each class's coordinate sub-list is its own
MetaTensor. This PR changes it to return a list of torch.Tensors (or
ndarrays), since we don't need a MetaTensor here.
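The change boils down to computing each class's index list as a plain array with no metadata attached. A rough sketch of that computation (a hypothetical helper, not the actual `map_classes_to_indices` signature), using ndarrays:

```python
import numpy as np

def class_indices(label, num_classes):
    # Flat indices of each class's occurrences, one plain ndarray
    # per class -- no MetaTensor wrapper, nothing extra to pickle.
    flat = np.asarray(label).ravel()
    return [np.flatnonzero(flat == c) for c in range(num_classes)]

label = np.array([[0, 1],
                  [2, 1]])
idx = class_indices(label, 3)
# idx[1] -> array([1, 3]): class 1 occurs at flat positions 1 and 3
```

With 105 output classes this produces 105 small plain arrays, which the cache can serialize cheaply.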

I ran into an issue with the current MetaTensor list, where the process
would simply freeze without any errors when trying to save the cached
indices (returned from ClassesToIndices) to the cache, which is a
ListProxy in shared memory. It happens randomly, but much more
frequently when the number of classes is large (e.g. 105 output classes,
so ClassesToIndices returns a list of 105 MetaTensors). I'm not sure
what the cause of the freeze is, but my guess is that ListProxy tries to
pickle each element of this list (and struggles with MetaTensors).
Disabling the MetaTensor return type here solves the issue.
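Pending a fix, a simple workaround in user code is to strip the subclass before handing the list to the shared-memory proxy, so the proxy only has to pickle plain arrays. A sketch with a hypothetical ndarray subclass standing in for MetaTensor:

```python
import numpy as np

class FakeMeta(np.ndarray):
    """Hypothetical ndarray subclass standing in for MetaTensor."""

def to_plain(indices):
    # np.asarray (subok=False by default) drops the subclass,
    # returning base-class ndarrays that pickle trivially.
    return [np.asarray(i) for i in indices]

raw = [np.arange(3).view(FakeMeta), np.arange(2).view(FakeMeta)]
cached = to_plain(raw)
assert all(type(a) is np.ndarray for a in cached)
```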

cc @myron
