This is the repository for the paper "Hierarchical Refinement: Optimal Transport to Infinity and Beyond," which scales optimal transport linearly in space and log-linearly in time by using a hierarchical strategy that constructs multiscale partitions from low-rank optimal transport.
Figure 1: Hierarchical Refinement algorithm: low-rank optimal transport is used to progressively refine partitions at the previous scale, with the coarsest scale partitions denoted
Hierarchical Refinement (HiRef) requires only two n×d point clouds, X and Y (torch tensors), as input.
Before running HiRef, call the rank-annealing scheduler to find a sequence of ranks that minimizes the number of calls to the low-rank optimal transport subroutine while remaining under a machine-specific maximal rank.
n: The size of the dataset
hierarchy_depth (κ): The depth of the hierarchy of levels used in the refinement strategy
max_Q: The maximal terminal rank at the base case
max_rank: The maximal rank of the intermediate sub-problems
Import the rank annealing module and compute the rank schedule:
import rank_annealing
rank_schedule = rank_annealing.optimal_rank_schedule(
    n=n, hierarchy_depth=hierarchy_depth, max_Q=max_Q, max_rank=max_rank
)
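To give a feel for the trade-off the scheduler navigates, here is a minimal sketch of one way such a search could work. It is not the repository's implementation. It assumes (our assumption, not stated above) that a schedule is a factorization n = r_1 · r_2 · … · r_κ, that every intermediate rank is at most max_rank, that the last rank is the terminal rank and is at most max_Q, and that the number of low-rank OT calls is 1 + r_1 + r_1·r_2 + … + r_1·…·r_{κ-1}. The function name `sketch_rank_schedule` is hypothetical.

```python
from functools import lru_cache
from typing import Optional, Tuple


def sketch_rank_schedule(
    n: int, hierarchy_depth: int, max_Q: int, max_rank: int
) -> Optional[Tuple[int, ...]]:
    """Illustrative exhaustive search (hypothetical, not the repo's scheduler).

    Returns a tuple of ranks whose product is n and which minimizes the
    assumed call count, or None if no schedule satisfies the rank bounds.
    """

    @lru_cache(maxsize=None)
    def best(m: int, depth: int):
        # Best (calls, schedule) for refining one block of size m over `depth` levels.
        if depth == 1:
            # Assumed base case: a single call whose rank equals the block size.
            return (1, (m,)) if m <= max_Q else None
        result = None
        for r in range(2, min(max_rank, m) + 1):
            if m % r:
                continue  # only exact factorizations are considered
            sub = best(m // r, depth - 1)
            if sub is None:
                continue
            # One call at this level, then one sub-problem per resulting block.
            calls = 1 + r * sub[0]
            if result is None or calls < result[0]:
                result = (calls, (r,) + sub[1])
        return result

    found = best(n, hierarchy_depth)
    return None if found is None else found[1]


schedule = sketch_rank_schedule(n=1024, hierarchy_depth=3, max_Q=16, max_rank=16)
print(schedule)
```

Under these assumptions, larger ranks late in the schedule are preferred because every block created at an early level multiplies the number of later calls. For real runs, use `rank_annealing.optimal_rank_schedule` as shown above.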
Import HR_OT and initialize the class using only the point clouds (you can additionally input the cost C if desired) along with any relevant parameters (e.g., sq_Euclidean) for your problem.
import HR_OT
hrot = HR_OT.HierarchicalRefinementOT.init_from_point_clouds(
    X, Y, rank_schedule, base_rank=1, device=device
)
Run and return paired tuples from X and Y (the bijective Monge map between the datasets):
Gamma_hrot = hrot.run(return_as_coupling=False)
To print the Optimal Transport (OT) cost, simply call:
cost_hrot = hrot.compute_OT_cost()
print(f"Refinement Cost: {cost_hrot.item()}")
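For intuition about what this number measures, the sketch below computes the cost of a bijection by hand. It assumes uniform weights, a squared Euclidean cost, and a matching given as a list of (i, j) index pairs. The actual structure returned by `hrot.run` may differ, and `monge_cost` is a hypothetical helper, not part of HR_OT.

```python
from typing import List, Sequence, Tuple


def monge_cost(
    X: Sequence[Sequence[float]],
    Y: Sequence[Sequence[float]],
    pairs: List[Tuple[int, int]],
) -> float:
    """Mean squared Euclidean distance over matched points (hypothetical helper)."""
    total = 0.0
    for i, j in pairs:
        # Squared Euclidean distance between X[i] and its matched point Y[j].
        total += sum((a - b) ** 2 for a, b in zip(X[i], Y[j]))
    return total / len(pairs)


# Toy data: two points per cloud, matched in order.
X_toy = [[0.0, 0.0], [1.0, 1.0]]
Y_toy = [[1.0, 0.0], [1.0, 3.0]]
print(monge_cost(X_toy, Y_toy, [(0, 0), (1, 1)]))
```

Plain Python lists are used here only to keep the example dependency-free. With torch tensors the same quantity can be computed in vectorized form.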
For questions, discussions, or collaboration inquiries, feel free to reach out at ph3641@princeton.edu or jg7090@princeton.edu.
While the default hyperparameter settings for HiRef have changed, the exact experiments and hyperparameter settings used are available on OpenReview. At the time of benchmarking, the default epsilon of Sinkhorn in ott-jax was 0.05, which has since been modified as well.