Skip to content

Commit

Permalink
Add more comments in sharded_tbes_weights_spec (pytorch#1596)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1596

As titled.

Reviewed By: henrylhtsang

Differential Revision: D52341591

fbshipit-source-id: 168c6345f2b236fe774d4c0c8d3a5a5a7bff17c4
  • Loading branch information
gnahzg authored and facebook-github-bot committed Feb 6, 2024
1 parent fed5202 commit 4f234a2
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchrec/distributed/quant_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ def sharded_tbes_weights_spec(
# "ebc.tbes.1.1.table_1.weight_qscale":WeightSpec("ebc.embedding_bags.table_1.weight_qscale", [500, 0], [500, 2])
# "ebc.tbes.1.1.table_1.weight_qbias":WeightSpec("ebc.embedding_bags.table_1.weight_qbias", [500, 0], [500, 2])
# }
# In the format of ebc.tbes.i.j.table_k.weight, where i is the index of the TBE, j is the index of the embedding bag within TBE i, k is the index of the original table set in the ebc embedding_configs
# e.g. ebc.tbes.1.1.table_1.weight, it represents second embedding bag within the second TBE. This part of weight is from a shard of table_1

ret: Dict[str, WeightSpec] = {}
for module_fqn, module in sharded_model.named_modules():
Expand Down

0 comments on commit 4f234a2

Please sign in to comment.