Skip to content

Commit fd49090

Browse files
suo authored and facebook-github-bot committed
improve granularity of PooledEmbeddingArchAwaitable (pytorch#1267)
Summary: Pull Request resolved: pytorch#1267 Today, if I do a `__getitem__` call on `PooledEmbeddingArchAwaitable`, it triggers the wait. We'd like to defer that further to when the result of `__getitem__` is actually used. So instead, have `__getitem__` return another `LazyAwaitable` which represents the pooled embedding. Usage of that value in the context of a torchfunction will trigger the wait as desired. This ends up being important for PT2 IR integration, which eagerly dumps a bunch of `__getitem__` calls right after the sparse arch because PT2 IR prefers to operate on "flat" values. With improved granularity, we still get the desired lazy behavior. For pure eager users, this should be a no-op (we generally only call `__getitem__` right before use, so this doesn't reorder anything). The laziness affects the ordering of comms/compute, which is important in two ways: 1. PEA design means that the per-rank feature processing behavior causes the specific order of execution to be load-bearing. Without the laziness, the execution order of ranks with vs. without feature processing will diverge, causing training hangs. 2. getting comms/compute overlapping for the all to all comms vs. dense compute is likely to be a performance improvement, although it is hard to make a direct comparison because of issue pytorch#1. Further details can be found in: https://fb.workplace.com/groups/319878845696681/posts/1017888535895705 Reviewed By: dstaay-fb Differential Revision: D47272219 fbshipit-source-id: e3250caf23d800783202c07ae669c2e00708ab6e
1 parent 2aed06b commit fd49090

File tree

3 files changed

+79
-3
lines changed

3 files changed

+79
-3
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
EmbeddingModuleShardingPlan,
5151
EnumerableShardingSpec,
5252
LazyAwaitable,
53+
LazyGetItemMixin,
5354
NullShardedModuleContext,
5455
ParameterSharding,
5556
QuantizedCommCodecs,
@@ -289,7 +290,9 @@ def construct_output_kt(
289290
)
290291

291292

292-
class EmbeddingBagCollectionAwaitable(LazyAwaitable[KeyedTensor]):
293+
class EmbeddingBagCollectionAwaitable(
294+
LazyGetItemMixin[str, Tensor], LazyAwaitable[KeyedTensor]
295+
):
293296
def __init__(
294297
self,
295298
awaitables: List[Awaitable[torch.Tensor]],

torchrec/distributed/tests/test_lazy_awaitable.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212
import torch.fx
13-
from torchrec.distributed.types import LazyAwaitable
13+
from torchrec.distributed.types import LazyAwaitable, LazyGetItemMixin
1414

1515

1616
class NeedWait(LazyAwaitable[torch.Tensor]):
@@ -252,3 +252,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
252252
self.assertTrue(torch.equal(ref_res, 17 * torch.ones(3, 4)))
253253

254254
tempFile.close()
255+
256+
def test_lazy_getitem_mixin(self) -> None:
    """Verify that LazyGetItemMixin defers wait(): __getitem__ must return
    another LazyAwaitable, and only the use of that value inside a torch
    function (torch.add below) triggers the parent's wait()."""

    class LazyGetItemAwaitable(
        # NOTE: LazyGetItemMixin precedes LazyAwaitable so the MRO selects
        # the mixin's lazy __getitem__ implementation.
        LazyGetItemMixin[str, torch.Tensor], LazyAwaitable[Dict[str, torch.Tensor]]
    ):
        def __init__(self, actual_value: Dict[str, torch.Tensor]):
            super().__init__()
            self.actual_value = actual_value

        def _wait_impl(self) -> Dict[str, torch.Tensor]:
            # Tripling the values is an observable side effect that marks
            # the wait as having run; the final assertion depends on it.
            for v in self.actual_value.values():
                v *= 3
            return self.actual_value

    actual_value = {"foo": torch.tensor(1), "bar": torch.tensor(2)}
    a = LazyGetItemAwaitable(actual_value)
    lazy_foo = a["foo"]
    lazy_bar = a["bar"]
    # The returned value should be lazy
    self.assertIsInstance(lazy_foo, LazyAwaitable)
    self.assertIsInstance(lazy_bar, LazyAwaitable)

    # Our lazy values should not have been waited yet
    # (NOTE(review): peeks at the private _result slot of LazyAwaitable —
    # assumes None means "not yet waited"; confirm against types.py.)
    self.assertIsNone(lazy_foo._result)
    self.assertIsNone(lazy_bar._result)
    self.assertIsNone(a._result)

    # The use of a torch op should trigger exactly one wait on the parent object.
    # The expected value reflects the *=3 side effect applied once per key.
    result = torch.add(lazy_foo, lazy_bar)
    self.assertEqual(result, torch.tensor(1 * 3 + 2 * 3))

torchrec/distributed/types.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,9 @@ def _wait_async(obj: Any) -> Any:
342342
else:
343343
return obj
344344

345+
@classmethod
345346
# pyre-ignore [2, 3]
346-
def __torch_function__(self, func, types, args=(), kwargs=None):
347+
def __torch_function__(cls, func, types, args=(), kwargs=None):
347348
"""
348349
The LazyAwaitable type has a `__torch_function__` implementation.
349350
This means when this type is seens as an argument to a PyTorch
@@ -391,6 +392,48 @@ def _wait_impl(self) -> W:
391392
return self._obj
392393

393394

395+
KT = TypeVar("KT")
396+
VT_co = TypeVar("VT_co")
397+
ParentW = TypeVar("ParentW")
398+
399+
400+
class LazyGetItemMixin(Generic[KT, VT_co]):
    """Augments the base LazyAwaitable with a lazy __getitem__ method.

    Instead of triggering a wait() on a __getitem__ call, classes using this
    mixin will return another awaitable. This can achieve better
    communication/computation overlap by deferring the wait() until the
    tensor data is actually needed.

    This is intended for Awaitables that model keyed collections, like
    dictionaries or EmbeddingBagCollectionAwaitable.

    NOTE: if using this mixin, please include it before LazyAwaitable in the
    inheritance list, so that Python MRO can properly select this __getitem__
    implementation.
    """

    def __getitem__(self, key: KT) -> LazyAwaitable[VT_co]:
        # Do not wait here: wrap (self, key) in a new awaitable that performs
        # the actual indexing only when its own value is demanded.
        return GetItemLazyAwaitable(self, key)
418+
419+
420+
class GetItemLazyAwaitable(LazyAwaitable[W], Generic[W, ParentW, KT]):
    """The LazyAwaitable produced by a __getitem__ call on `LazyGetItemMixin`.

    Holds a reference to the parent awaitable together with the requested
    key; only when this value is actually demanded does it wait on the
    parent and index into the result.
    """

    def __init__(self, parent: LazyAwaitable[ParentW], key: KT) -> None:
        super().__init__()
        self._parent = parent
        self._key = key

    def _wait_impl(self) -> W:
        # Resolve the parent first (triggering its wait if it has not run
        # yet), then extract just the entry this awaitable represents.
        parent_value = LazyAwaitable._wait_async(self._parent)
        return parent_value[self._key]
435+
436+
394437
# install magic methods
395438
for orig_method_name in torch.fx.graph.magic_methods:
396439
as_magic = f"__{orig_method_name}__"

0 commit comments

Comments
 (0)