
Improved Tensor Dimension Handling in predict_masks Method #581

Open
@sushmanthreddy

Description


Issue:
This issue proposes a change to how the batch dimension of image_embeddings is handled in the predict_masks method of the MaskDecoder class. Here is a detailed breakdown:

  1. Conditional Check:

    • A new check, if image_embeddings.shape[0] != tokens.shape[0]:, compares the batch dimensions of the two tensors before torch.repeat_interleave is applied.
  2. Usage of torch.repeat_interleave:

    • When the batch sizes differ, torch.repeat_interleave expands image_embeddings along the batch dimension (dim 0) so that its batch size matches that of tokens.
  3. Ensuring Consistency:

    • With this check, torch.repeat_interleave is applied only when necessary, which keeps tensor handling within predict_masks consistent. The original implementation applies torch.repeat_interleave unconditionally.
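A minimal sketch of the proposed guard, assuming PyTorch and small toy shapes (image embeddings as (B, C, H, W), tokens as (N, T, C)). The helper name expand_to_tokens is hypothetical and is not part of the MaskDecoder API; it only isolates the conditional described above and contrasts it with the unconditional call:

```python
import torch


def expand_to_tokens(image_embeddings: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper illustrating the proposed check.

    torch.repeat_interleave is applied only when the batch dimensions differ;
    otherwise image_embeddings is returned unchanged.
    """
    if image_embeddings.shape[0] != tokens.shape[0]:
        # Repeat each image embedding tokens.shape[0] times along dim 0.
        # This aligns the batches when image_embeddings has batch size 1.
        return torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    # Batch sizes already match, so no expansion is needed.
    return image_embeddings


# Toy shapes, chosen only for illustration.
tokens = torch.zeros(3, 5, 8)

single = torch.zeros(1, 8, 4, 4)
print(expand_to_tokens(single, tokens).shape)   # torch.Size([3, 8, 4, 4])

batched = torch.zeros(3, 8, 4, 4)
print(expand_to_tokens(batched, tokens).shape)  # torch.Size([3, 8, 4, 4])

# The unconditional call inflates an already-matching batch to 3 * 3 = 9 entries.
print(torch.repeat_interleave(batched, tokens.shape[0], dim=0).shape)  # torch.Size([9, 8, 4, 4])
```

Note that the check only distinguishes "equal" from "unequal": if image_embeddings has a batch size other than 1 or tokens.shape[0], torch.repeat_interleave still produces B × N entries.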
