Seed Luus Jaakola with left cost since we already know it. (#1692)
Summary:
Pull Request resolved: #1692

In scaleup, we already know the cost of the left-hand boundary (the plan
using the minimum working set). The clamped random search might probe
that left-hand margin again, which wastes a (cached) call to the
partitioner, using up one of our precious iterations without learning
anything new. Instead, we tell the Luus-Jaakola search that the
left-margin cost is already known, so in these cases it can choose a new
probe point that may discover a better minimum.

Reviewed By: henrylhtsang

Differential Revision: D53297896

fbshipit-source-id: e5676c5695d285204e923b976ce3fa4e14eeb85a
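
For readers unfamiliar with the technique, here is a minimal, self-contained sketch of a clamped Luus-Jaakola search seeded with a known left-boundary cost. This is not the torchrec implementation; the function name, the 0.95 contraction factor, and the redirect-on-duplicate-probe detail are illustrative assumptions. It only shows why passing left_cost avoids re-spending an iteration on a value we already have.

import random
from typing import Callable, Optional


def lj_minimize(
    f: Callable[[float], float],
    left: float,
    right: float,
    max_iterations: int = 16,
    left_cost: Optional[float] = None,
    seed: int = 0,
) -> float:
    """Clamped 1-D Luus-Jaakola random search over [left, right].

    If left_cost is given, f(left) is treated as already known: the known
    cost seeds the running best, and a probe that lands exactly on the
    left boundary is redirected to a fresh random point instead of
    re-spending an iteration on a value we already have.
    """
    rng = random.Random(seed)
    d = right - left                       # search radius, contracts on failure
    x = rng.uniform(left, right)           # current point
    fx = f(x)
    best_x, best_fx = x, fx
    if left_cost is not None and left_cost < best_fx:
        best_x, best_fx = left, left_cost  # seed with the known boundary cost
    for _ in range(max_iterations - 1):
        y = min(max(x + rng.uniform(-d, d), left), right)  # clamp to interval
        if y == left and left_cost is not None:
            # f(left) is already known; spend this probe somewhere new.
            y = rng.uniform(left, right)
        fy = f(y)
        if fy < fx:
            x, fx = y, fy                  # accept the improvement
        else:
            d *= 0.95                      # shrink the search radius
        if fx < best_fx:
            best_x, best_fx = x, fx
    return best_x

In this commit the analogous seeding happens inside the LuusJaakolaSearch constructor (self.fleft = left_cost), with perf_rating, the cost of the min-working-set plan, passed in from the proposer.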
Damian Reeves authored and facebook-github-bot committed Feb 8, 2024
1 parent 8d93bc2 commit 5bfbeab
Showing 3 changed files with 35 additions and 6 deletions.
4 changes: 3 additions & 1 deletion torchrec/distributed/planner/proposers.py
@@ -338,7 +338,9 @@ def feedback(
         logger.info(
             f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB"
         )
-        self.search = LuusJaakolaSearch(0, hbm_available, max_iterations=16)
+        self.search = LuusJaakolaSearch(
+            0, hbm_available, max_iterations=16, left_cost=perf_rating
+        )
 
         logger.info(
             f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}"
26 changes: 23 additions & 3 deletions torchrec/distributed/planner/tests/test_utils.py
@@ -7,7 +7,7 @@
 
 import math
 import unittest
-from typing import Callable, List
+from typing import Callable, List, Optional
 from unittest.mock import MagicMock
 
 import torch
@@ -124,12 +124,17 @@ class TestLuusJaakolaSearch(unittest.TestCase):
     # just getting lucky.
     # Returns a Nx2 tensor of [xs, ys] of discovered minimums.
     @staticmethod
-    def evaluate(x0: float, x1: float, f: Callable[[float], float]) -> torch.Tensor:
+    def evaluate(
+        x0: float,
+        x1: float,
+        f: Callable[[float], float],
+        left_cost: Optional[float] = None,
+    ) -> torch.Tensor:
         xs = []
         ys = []
         iterations = 16
         for i in range(5):
-            search = LuusJaakolaSearch(x0, x1, iterations, seed=i)
+            search = LuusJaakolaSearch(x0, x1, iterations, seed=i, left_cost=left_cost)
             y = search.next(0.0)
             while y is not None:
                 fy = f(y)
@@ -360,3 +365,18 @@ def f(x: float) -> float:
             ],
         )
         torch.testing.assert_close(results, want)
+
+        results = TestLuusJaakolaSearch.evaluate(
+            mem.min().item(), mem.max().item(), f, left_cost=cost[0].item()
+        )
+        want = torch.tensor(
+            [
+                [5.370294e11, 2.314406e02],
+                # 2nd search finds better result given left_cost
+                [5.918126e11, 2.308140e02],
+                [5.908549e11, 2.308194e02],
+                [5.755533e11, 2.309337e02],
+                [6.184178e11, 2.308121e02],
+            ],
+        )
+        torch.testing.assert_close(results, want)
11 changes: 9 additions & 2 deletions torchrec/distributed/planner/utils.py
@@ -172,7 +172,14 @@ class LuusJaakolaSearch:
     See https://en.wikipedia.org/wiki/Luus-Jaakola.
     """
 
-    def __init__(self, A: float, B: float, max_iterations: int, seed: int = 42) -> None:
+    def __init__(
+        self,
+        A: float,
+        B: float,
+        max_iterations: int,
+        seed: int = 42,
+        left_cost: Optional[float] = None,
+    ) -> None:
         self.left = A
         self.right = B
         self.iteration = -1
@@ -184,7 +191,7 @@ def __init__(self, A: float, B: float, max_iterations: int, seed: int = 42) -> N
         self.x: float = self.uniform(self.left, self.right)
         self.fx: float = 0.0
         self.y: float = math.nan
-        self.fleft: Optional[float] = None
+        self.fleft: Optional[float] = left_cost
         self.fright: Optional[float] = None
         self.d: float = self.right - self.left
