File tree Expand file tree Collapse file tree 1 file changed +7
-10
lines changed Expand file tree Collapse file tree 1 file changed +7
-10
lines changed Original file line number Diff line number Diff line change @@ -1172,16 +1172,13 @@ def _create_output_dist(self) -> None:
11721172 for i , name in enumerate (self ._uncombined_embedding_names ):
11731173 embedding_name_order .setdefault (name , i )
11741174
1175- def sort_key (input : Tuple [int , str ]) -> Tuple [int , int ]:
1176- index , name = input
1177- return (embedding_name_order [name ], embedding_shard_offsets [index ])
1178-
1179- permute_indices = [
1180- i
1181- for i , _ in sorted (
1182- enumerate (self ._uncombined_embedding_names ), key = sort_key
1183- )
1184- ]
1175+ permute_indices = sorted (
1176+ range (len (self ._uncombined_embedding_names )),
1177+ key = lambda i : (
1178+ embedding_name_order [self ._uncombined_embedding_names [i ]],
1179+ embedding_shard_offsets [i ],
1180+ ),
1181+ )
11851182 self ._permute_op : PermutePooledEmbeddings = PermutePooledEmbeddings (
11861183 self ._uncombined_embedding_dims , permute_indices , self ._device
11871184 )
You can’t perform that action at this time.
0 commit comments