Skip to content

Commit bfe322a

Browse files
aporialiaofacebook-github-bot
authored andcommitted
Simplify permute indices for Sharded EBC output_dist creation (#2856)
Summary: Simplify create_output_dist call Differential Revision: D72079015
1 parent f0ae23d commit bfe322a

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,16 +1172,13 @@ def _create_output_dist(self) -> None:
11721172
for i, name in enumerate(self._uncombined_embedding_names):
11731173
embedding_name_order.setdefault(name, i)
11741174

1175-
def sort_key(input: Tuple[int, str]) -> Tuple[int, int]:
1176-
index, name = input
1177-
return (embedding_name_order[name], embedding_shard_offsets[index])
1178-
1179-
permute_indices = [
1180-
i
1181-
for i, _ in sorted(
1182-
enumerate(self._uncombined_embedding_names), key=sort_key
1183-
)
1184-
]
1175+
permute_indices = sorted(
1176+
range(len(self._uncombined_embedding_names)),
1177+
key=lambda i: (
1178+
embedding_name_order[self._uncombined_embedding_names[i]],
1179+
embedding_shard_offsets[i],
1180+
),
1181+
)
11851182
self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
11861183
self._uncombined_embedding_dims, permute_indices, self._device
11871184
)

0 commit comments

Comments
 (0)