#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
import random
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.embeddingbag import EmbeddingBagCollection

from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta

from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
    table_wise,
)

from torchrec.distributed.test_utils.test_sharding import generate_rank_placements
from torchrec.distributed.types import EmbeddingModuleShardingPlan

logger: logging.Logger = logging.getLogger(__name__)


class ReshardingHandler:
    """
    Handles the resharding of a training module by generating and applying different sharding plans.
    """

    def __init__(self, train_module: nn.Module, num_plans: int) -> None:
        """
        Initializes the ReshardingHandler with a training module and the number of sharding plans to generate.

        Args:
            train_module (nn.Module): The training module to be resharded.
            num_plans (int): The number of sharding plans to generate.
        """
        self._train_module = train_module
        if not hasattr(train_module, "_module"):
            raise RuntimeError("train_module must expose a `_module` attribute")

        if not hasattr(train_module._module, "plan"):
            raise RuntimeError("no sharding plan found on train_module._module")

        # Pyre-ignore
        plan = train_module._module.plan.plan
        key = "main_module.sparse_arch.embedding_bag_collection"
        module = (
            # Pyre-ignore
            train_module._module.module.main_module.sparse_arch.embedding_bag_collection
        )
        self._resharding_plans: List[EmbeddingModuleShardingPlan] = []
        world_size = dist.get_world_size()

        # TODO: change this logic once a proper planner is integrated
        if key in plan:
            ebc = plan[key]
            num_tables = len(ebc)
            ranks_per_tables = [1 for _ in range(num_tables)]
            ranks_per_tables_for_CW = []
            for index, table_config in enumerate(module._embedding_bag_configs):
                # CW sharding: only shard counts that evenly divide the embedding dim are valid
                valid_candidates = [
                    i
                    for i in range(1, world_size + 1)
                    if table_config.embedding_dim % i == 0
                ]
                rng = random.Random(index)
                ranks_per_tables_for_CW.append(rng.choice(valid_candidates))

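            # Pre-generate `num_plans` candidate plans. Each candidate keeps
            # column-wise tables column-wise (with a freshly chosen set of ranks)
            # and places every other table table-wise on a newly chosen rank.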
            for plan_idx in range(num_plans):
                new_ranks = generate_rank_placements(
                    world_size, num_tables, ranks_per_tables, plan_idx
                )
                new_ranks_cw = generate_rank_placements(
                    world_size, num_tables, ranks_per_tables_for_CW, plan_idx
                )
                new_per_param_sharding = {}
                for table_idx, (table_id, param) in enumerate(ebc.items()):
                    if param.sharding_type == "column_wise":
                        cw_gen = column_wise(
                            ranks=new_ranks_cw[table_idx],
                            compute_kernel=param.compute_kernel,
                        )
                        new_per_param_sharding[table_id] = cw_gen
                    else:
                        tw_gen = table_wise(
                            rank=new_ranks[table_idx][0],
                            compute_kernel=param.compute_kernel,
                        )
                        new_per_param_sharding[table_id] = tw_gen

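                # construct_module_sharding_plan needs a module instance to enumerate
                # shardable parameters, so a lightweight throwaway EBC is enough here;
                # real embedding weights are never materialized.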
                # Use the meta device to avoid actual memory allocation
                meta_device = torch.device("meta")
                lightweight_ebc = EmbeddingBagCollection(
                    tables=module._embedding_bag_configs,
                    device=meta_device,
                )

                new_plan = construct_module_sharding_plan(
                    lightweight_ebc,
                    per_param_sharding=new_per_param_sharding,  # Pyre-ignore
                    local_size=world_size,
                    world_size=world_size,
                    # Pyre-ignore
                    device_type=meta_device,
                )
                self._resharding_plans.append(new_plan)
        else:
            raise RuntimeError(f"Plan does not have key: {key}")

    def step(self, batch_no: int) -> float:
        """
        Executes a step in the training process by selecting and applying a sharding plan.

        Args:
            batch_no (int): The current batch number.

        Returns:
            float: The data volume of the sharding plan delta.
        """
        # Pyre-ignore
        plan = self._train_module._module.plan.plan
        key = "main_module.sparse_arch.embedding_bag_collection"

        # Use the current batch number as the seed so that every rank selects the same
        # plan, while the selection still changes on each call.
        rng = random.Random(batch_no)
        index = rng.randint(0, len(self._resharding_plans) - 1)
        logger.info(f"Selected resharding plan index {index} for step {batch_no}")
        # Get the selected plan
        selected_plan = self._resharding_plans[index]

        # Fix the device mismatch by updating the placement device in the sharding spec.
        # This is necessary because the plan was created on the meta device but needs to
        # be applied on CUDA.
        for _, param_sharding in selected_plan.items():
            sharding_spec = param_sharding.sharding_spec
            if sharding_spec is not None:
                # pyre-ignore
                for shard in sharding_spec.shards:
                    placement = shard.placement
                    rank: Optional[int] = placement.rank()
                    assert rank is not None
                    current_device = (
                        torch.cuda.current_device()
                        if rank == torch.distributed.get_rank()
                        else rank % torch.cuda.device_count()
                    )
                    shard.placement = torch.distributed._remote_device(
                        f"rank:{rank}/cuda:{current_device}"
                    )

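        # Compute just the delta between the currently applied plan and the selected
        # plan, so that reshard() only moves the shards whose placement actually changed.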
        data_volume, delta_plan = output_sharding_plan_delta(
            plan[key], selected_plan, True
        )
        # Pyre-ignore
        self._train_module.module.reshard(
            sharded_module_fqn="main_module.sparse_arch.embedding_bag_collection",
            changed_shard_to_params=delta_plan,
        )
        return data_volume
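

# A minimal usage sketch (illustrative only, not part of this module): it shows one
# plausible way to drive ReshardingHandler from a training loop. The names
# `train_pipeline`, `num_batches`, and `run_one_batch`, as well as the resharding
# cadence, are hypothetical placeholders rather than TorchRec APIs.
#
#     handler = ReshardingHandler(train_pipeline, num_plans=5)
#     for batch_no in range(num_batches):
#         run_one_batch()                    # hypothetical forward/backward/optimizer step
#         if batch_no % 10 == 0:             # arbitrary resharding cadence
#             data_volume = handler.step(batch_no)
#             logger.info(f"Reshard at batch {batch_no} moved a delta of {data_volume}")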