Only flatten a pytree once per container transform #6769

Open
pmeier opened this issue Oct 14, 2022 · 0 comments

pmeier commented Oct 14, 2022

This issue was split off from #6760. The idea in #6760 (comment) was deemed good at first, but unexpected issues (see #6760 (comment)) made this proposal more controversial.


When implementing augmentation pipelines, the individual transformations are usually wrapped in a container transform like transforms.Compose. Under the assumptions that the child transforms

  1. support arbitrarily structured inputs, i.e. pytrees, and
  2. keep the structure intact

it would suffice to flatten the input only once in the container transform, let all children operate on the flattened inputs without trying to flatten them again, and unflatten only at the end of the container. This would reduce the number of tree_{flatten, unflatten} calls to a fixed, low single-digit number per pipeline rather than one that grows with the number of transforms.
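To make the difference concrete, here is a minimal sketch. It assumes a toy stand-in for the pytree utilities (plain lists, tuples, and dicts only) instead of the real tree_flatten / tree_unflatten, and every name in it (CALLS, PerLeafTransform, transform_flat, NaiveCompose, FlattenOnceCompose) is hypothetical and not part of torchvision.

```python
from typing import Any, Callable, List, Tuple

# Counter used only to illustrate how often the input gets flattened.
CALLS = {"flatten": 0}


def _flatten(tree: Any) -> Tuple[List[Any], Any]:
    # A spec of None marks a leaf; containers record (type, keys, child specs).
    if isinstance(tree, dict):
        keys, children = list(tree.keys()), list(tree.values())
    elif isinstance(tree, (list, tuple)):
        keys, children = None, list(tree)
    else:
        return [tree], None
    leaves, child_specs = [], []
    for child in children:
        child_leaves, child_spec = _flatten(child)
        leaves.extend(child_leaves)
        child_specs.append(child_spec)
    return leaves, (type(tree), keys, child_specs)


def tree_flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Toy flatten (illustrative stand-in, not the real utility)."""
    CALLS["flatten"] += 1
    return _flatten(tree)


def tree_unflatten(leaves: List[Any], spec: Any) -> Any:
    """Rebuild the original structure from flat leaves and a spec."""
    leaf_iter = iter(leaves)

    def build(node_spec: Any) -> Any:
        if node_spec is None:
            return next(leaf_iter)
        kind, keys, child_specs = node_spec
        children = [build(s) for s in child_specs]
        return kind(zip(keys, children)) if keys is not None else kind(children)

    return build(spec)


class PerLeafTransform:
    """Hypothetical child transform: applies fn to each leaf, keeps the structure."""

    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn

    def transform_flat(self, leaves: List[Any]) -> List[Any]:
        return [self.fn(leaf) for leaf in leaves]

    def __call__(self, sample: Any) -> Any:
        # Standalone use: this transform flattens and unflattens on its own.
        leaves, spec = tree_flatten(sample)
        return tree_unflatten(self.transform_flat(leaves), spec)


class NaiveCompose:
    """Every child flattens again: one flatten call per transform."""

    def __init__(self, transforms: List[PerLeafTransform]) -> None:
        self.transforms = transforms

    def __call__(self, sample: Any) -> Any:
        for transform in self.transforms:
            sample = transform(sample)
        return sample


class FlattenOnceCompose:
    """Flatten once, let children work on the flat leaves, unflatten once."""

    def __init__(self, transforms: List[PerLeafTransform]) -> None:
        self.transforms = transforms

    def __call__(self, sample: Any) -> Any:
        leaves, spec = tree_flatten(sample)
        for transform in self.transforms:
            leaves = transform.transform_flat(leaves)
        return tree_unflatten(leaves, spec)
```

With three child transforms, NaiveCompose increments CALLS["flatten"] three times per sample, while FlattenOnceCompose increments it once regardless of pipeline length. The sketch only works because each child keeps the number and order of leaves unchanged, which is exactly assumption 2 above.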

All builtin transforms fulfill these assumptions. However, our container transforms also support custom transforms and we have no way of knowing if they also fulfill them or not.

There are two ways we could communicate this information:

  1. Annotate the transforms: Each transform could have a supports_pytree: bool attribute that the containers look for. For example, inside transforms.Compose, after

    self.transforms = transforms

    we could do

    flatten_once = all(getattr(transform, "supports_pytree", False) for transform in self.transforms)

    Since a transforms.Transform supports pytree inputs by default, I think it is reasonable to add a supports_pytree: bool = True parameter to its constructor. That means all of our builtin transformations would be supported out of the box. If we do this, we need to clearly document that users who subclass transforms.Transform, but opt out of the _check_inputs / _get_params / _transform protocol by overriding forward, need to set this flag to the appropriate value. Otherwise, wrapping their custom transformation in a transforms.Compose will likely fail.

  2. Annotate the container transform: Instead of relying on automagic detection of whether the children support pytree objects, we could simply add a flatten_once: bool = False flag to container transforms. Note that this would need to be turned off by default to avoid failures if not all children support pytrees. This means the user has to opt in to this feature.
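The two options can be sketched side by side. This is only an illustration under assumptions: the classes below are simplified, hypothetical stand-ins and not the actual transforms.Transform or transforms.Compose, and the name ComposeWithFlag exists only to keep the two variants apart.

```python
from typing import Any, Callable, Sequence


class Transform:
    """Hypothetical base class; supports_pytree defaults to True as proposed in option 1."""

    def __init__(self, supports_pytree: bool = True) -> None:
        self.supports_pytree = supports_pytree


class Compose:
    """Option 1: the container detects support from its children."""

    def __init__(self, transforms: Sequence[Callable[[Any], Any]]) -> None:
        self.transforms = transforms
        # Children without the attribute (e.g. plain callables) count as unsupported.
        self.flatten_once = all(
            getattr(transform, "supports_pytree", False)
            for transform in self.transforms
        )


class ComposeWithFlag:
    """Option 2: the user opts in explicitly; off by default."""

    def __init__(
        self,
        transforms: Sequence[Callable[[Any], Any]],
        flatten_once: bool = False,
    ) -> None:
        self.transforms = transforms
        self.flatten_once = flatten_once
```

Under option 1, a single plain callable or a subclass constructed with supports_pytree=False switches the whole container back to per-transform flattening. Under option 2, nothing changes for existing users until they pass flatten_once=True.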

Of the options above, I lean towards 1. Since 2. is opt-in, most users will probably never use the feature. I think that is worse than having users consciously set a flag if they subclass from our base class, but opt out of the features.

However, the benefits of adopting this proposal, regardless of which option we choose, are insignificant for a single call. They only manifest in large-scale training runs. Even there, we are looking at shaving double-digit minutes off runs that take single-digit days. Thus, we should also discuss whether this change is worth introducing new API surface at all. If we want / need this performance gain, but don't want to touch the API, maybe there is a way to implement this only in our references. It will probably be more complicated though, since some of the changes need to happen on transforms.Transform and not only on the container transforms.

cc @vfdev-5 @datumbox @bjuncek
