procrustes alignment #2691

Closed
@heth27

🚀 Feature

Spatial Procrustes alignment, a similarity test for two data sets.

Motivation

Procrustes alignment is a staple when calculating metrics for 3D human pose estimation, but there seems to be no library that offers this function for PyTorch, so I guess everyone just maintains their own version.

Pitch

SciPy has a variant, scipy.spatial.procrustes:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.procrustes.html
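
For reference, a minimal sketch of what the SciPy variant does. It works on NumPy arrays rather than tensors, standardizes both inputs (centered, unit norm), and returns the two standardized matrices plus a disparity value instead of the transform itself. The point sets below are made up for illustration.

```python
import numpy as np
from scipy.spatial import procrustes

rng = np.random.default_rng(0)
target = rng.normal(size=(17, 3))  # hypothetical 17-joint pose

# Rotate about z, scale and shift the target to get a second point set.
angle = 0.5
rot = np.array([[np.cos(angle), -np.sin(angle), 0.0],
                [np.sin(angle), np.cos(angle), 0.0],
                [0.0, 0.0, 1.0]])
moved = 2.0 * target @ rot.T + np.array([1.0, -2.0, 0.5])

# mtx1 is the standardized target, mtx2 is the moved set fitted to it,
# disparity is the remaining sum of squared differences.
mtx1, mtx2, disparity = procrustes(target, moved)
```

Since the second set is an exact similarity transform of the first, the disparity should come out at roughly zero.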

Alternatives

Additional context

This is the implementation I'm using; I don't know if it is any good.

import torch


def procrustes(pts1: torch.Tensor, pts2: torch.Tensor):
    """Estimate the similarity (sim3) transform that aligns pts2 to pts1.

    Both inputs are (N, 3). Returns (s, R, mean_1, mean_2) such that
    s * (pts2 - mean_2) @ R.T + mean_1 approximates pts1.
    """
    assert pts1.shape == pts2.shape, f"{pts1.shape} != {pts2.shape}"
    assert pts1.ndim == 2 and pts1.shape[-1] == 3, f"{pts1.shape}"
    # center both point clouds on their centroids
    t1 = pts1.mean(dim=0)
    t2 = pts2.mean(dim=0)
    pts1 = pts1 - t1[None, :]
    pts2 = pts2 - t2[None, :]

    # normalize by the RMS distance to the centroid
    s1 = pts1.square().sum(dim=-1).mean().sqrt()
    s2 = pts2.square().sum(dim=-1).mean().sqrt()
    pts1 = pts1 / s1
    pts2 = pts2 / s2
    try:
        # SVD of the 3x3 cross-covariance, in float64 for stability
        U, _, Vh = torch.linalg.svd((pts1.T @ pts2).double())
    except RuntimeError:
        print("Procrustes failed: SVD did not converge!")
        # fall back to the identity transform, same return layout as below
        eye = torch.eye(3, device=pts1.device, dtype=pts1.dtype)
        return 1.0, eye, torch.zeros_like(t1), torch.zeros_like(t2)
    # build rotation matrix; if U @ Vh is a reflection, negate the last
    # column of U (smallest singular value) so that det(R) = +1
    if torch.linalg.det(U @ Vh) < 0:
        U = U * torch.tensor([1.0, 1.0, -1.0], dtype=U.dtype, device=U.device)
    R = (U @ Vh).to(pts1.dtype)
    # scale that matches the RMS size of pts2 to that of pts1
    s = s1 / s2

    # aligned = s * (pts2 - t2) @ R.T + t1
    # or per point p: s * R @ p + t, with t = t1 - s * t2 @ R.T
    return s, R, t1, t2

Example usage:

s, R, mean_1, mean_2 = procrustes(coords_3d_true, coords_3d_prediction)
procrustes_aligned = torch.einsum("jd, od -> jo", coords_3d_prediction - mean_2, s * R) + mean_1

The problem with this version is that it does not work on batches.
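
One way the batch limitation might be lifted, as a rough sketch rather than a tested implementation: torch.linalg.svd and torch.linalg.det both accept leading batch dimensions, so the same steps can run on (B, N, 3) tensors. The name procrustes_batched is hypothetical, the scale is the same RMS ratio used above, reflections are handled by negating the last column of U, and the function returns the aligned points directly instead of the transform parameters.

```python
import torch


def procrustes_batched(pts1: torch.Tensor, pts2: torch.Tensor) -> torch.Tensor:
    """Align pts2 to pts1 for every batch element; both are (B, N, 3)."""
    assert pts1.shape == pts2.shape, f"{pts1.shape} != {pts2.shape}"
    assert pts1.ndim == 3 and pts1.shape[-1] == 3, f"{pts1.shape}"
    # Center each point cloud on its own centroid.
    mean1 = pts1.mean(dim=1, keepdim=True)  # (B, 1, 3)
    mean2 = pts2.mean(dim=1, keepdim=True)
    c1 = pts1 - mean1
    c2 = pts2 - mean2
    # RMS distance to the centroid, one value per batch element.
    s1 = c1.square().sum(dim=-1).mean(dim=-1).sqrt()  # (B,)
    s2 = c2.square().sum(dim=-1).mean(dim=-1).sqrt()
    # Batched SVD of the 3x3 cross-covariance matrices (float64 for stability).
    # The rotation does not depend on the scale, so c1 and c2 are used as-is.
    cov = (c1.transpose(1, 2) @ c2).double()  # (B, 3, 3)
    U, _, Vh = torch.linalg.svd(cov)
    # Where det(U @ Vh) is -1, negate the last column of U to avoid a reflection.
    sign = torch.sign(torch.linalg.det(U @ Vh))  # (B,)
    flip = torch.ones_like(U[:, 0, :])  # (B, 3)
    flip[:, -1] = sign
    R = ((U * flip[:, None, :]) @ Vh).to(pts1.dtype)  # (B, 3, 3)
    s = (s1 / s2)[:, None, None]
    # Apply the similarity transform to the centered pts2.
    return s * (c2 @ R.transpose(1, 2)) + mean1
```

With a batch of ground-truth poses and predictions, procrustes_batched(coords_3d_true, coords_3d_prediction) would then replace the per-sample loop and the einsum from the example usage.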
