Skip to content

Commit

Permalink
[Checkpoint][2D] Minor update for dedup_tensors.py (pytorch#89542)
Browse files Browse the repository at this point in the history
Rename variables for better readability.

Pull Request resolved: pytorch#89542
Approved by: https://github.com/H-Huang
  • Loading branch information
wz337 authored and pytorchmergebot committed Nov 23, 2022
1 parent 74703eb commit f03e667
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions torch/distributed/checkpoint/dedup_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
all_plans = list(all_plans)
key_to_plan: Dict[MetadataIndex, List[int]] = {}
for plan_idx, plan in enumerate(all_plans):
for wi in plan.items:
key_to_plan.setdefault(wi.index, []).append(plan_idx)
for write_item in plan.items:
key_to_plan.setdefault(write_item.index, []).append(plan_idx)

replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}

Expand All @@ -47,7 +47,9 @@ def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
key_set = set(keys)
# rewrite items and remove elements
new_items = [
wi for wi in all_plans[plan_idx].items if wi.index not in key_set
write_item
for write_item in all_plans[plan_idx].items
if write_item.index not in key_set
]
all_plans[plan_idx] = dataclasses.replace(
all_plans[plan_idx], items=new_items
Expand Down

0 comments on commit f03e667

Please sign in to comment.