refactor planner to use device_type #8

Closed · wants to merge 1 commit
10 changes: 6 additions & 4 deletions torchrec/distributed/embedding.py
@@ -490,18 +490,20 @@ def shard(
     def sharding_types(self) -> List[str]:
         return [ShardingType.DATA_PARALLEL.value]
 
-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [
             EmbeddingComputeKernel.BATCHED_QUANT.value,
         ]
 
     def storage_usage(
-        self, tensor: torch.Tensor, device: torch.device, compute_kernel: str
+        self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str
     ) -> Dict[str, int]:
         tensor_bytes = tensor.numel() * tensor.element_size() + tensor.shape[0] * 4
-        assert device.type in {"cuda", "cpu"}
+        assert compute_device_type in {"cuda", "cpu"}
         storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR}
-        return {storage_map[device.type].value: tensor_bytes}
+        return {storage_map[compute_device_type].value: tensor_bytes}
 
     def shardable_parameters(
         self, module: QuantEmbeddingBagCollection
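To make the new contract concrete, here is a minimal standalone sketch of the storage accounting above. The function name, the ParameterStorage stand-in, and the toy tensor are all illustrative, not torchrec's actual API; only the arithmetic mirrors the diff (tensor bytes plus 4 extra bytes per row, presumably per-row quantization metadata).

from enum import Enum

import torch


class ParameterStorage(Enum):
    # Local stand-in for torchrec's ParameterStorage enum, for illustration only.
    HBM = "hbm"
    DDR = "ddr"


def storage_usage_sketch(tensor: torch.Tensor, compute_device_type: str) -> dict:
    # Same accounting as above: raw tensor bytes plus 4 extra bytes per row.
    tensor_bytes = tensor.numel() * tensor.element_size() + tensor.shape[0] * 4
    assert compute_device_type in {"cuda", "cpu"}
    storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR}
    return {storage_map[compute_device_type].value: tensor_bytes}


print(storage_usage_sketch(torch.empty(10, 4, dtype=torch.int8), "cpu"))
# {'ddr': 80}: 40 bytes of int8 data plus 10 rows * 4 bytes of overhead
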
14 changes: 8 additions & 6 deletions torchrec/distributed/embedding_types.py
@@ -174,7 +174,9 @@ def sharding_types(self) -> List[str]:
 
         return types
 
-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         ret = [
             EmbeddingComputeKernel.DENSE.value,
             EmbeddingComputeKernel.BATCHED_DENSE.value,
@@ -184,7 +186,7 @@ def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]
                 EmbeddingComputeKernel.BATCHED_FUSED.value,
                 EmbeddingComputeKernel.SPARSE.value,
             ]
-            if device.type in {"cuda"}:
+            if compute_device_type in {"cuda"}:
                 ret += [
                     EmbeddingComputeKernel.BATCHED_FUSED_UVM.value,
                     EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value,
@@ -196,7 +198,7 @@ def fused_params(self) -> Optional[Dict[str, Any]]:
         return self._fused_params
 
     def storage_usage(
-        self, tensor: torch.Tensor, device: torch.device, compute_kernel: str
+        self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str
     ) -> Dict[str, int]:
         """
         List of system resources and corresponding usage given a compute device and
@@ -207,12 +209,12 @@
             EmbeddingComputeKernel.BATCHED_FUSED_UVM.value,
             EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value,
         }:
-            assert device.type in {"cuda"}
+            assert compute_device_type in {"cuda"}
             return {ParameterStorage.DDR.value: tensor_bytes}
         else:
-            assert device.type in {"cuda", "cpu"}
+            assert compute_device_type in {"cuda", "cpu"}
             storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR}
             return {
-                storage_map[device.type].value: tensor.element_size()
+                storage_map[compute_device_type].value: tensor.element_size()
                 * tensor.nelement()
             }
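The only behavioral gate in this sharder now keys off the plain device-type string: UVM kernels are offered only when it is "cuda". A small sketch of that gate, with placeholder kernel names rather than torchrec's enum, and an illustrative function name:

from typing import List


def kernels_for_device_sketch(compute_device_type: str) -> List[str]:
    # Kernel names are placeholders; only the gating mirrors the code above.
    ret = ["dense", "batched_dense", "batched_fused", "sparse"]
    if compute_device_type in {"cuda"}:
        ret += ["batched_fused_uvm", "batched_fused_uvm_caching"]
    return ret


print(kernels_for_device_sketch("cpu"))   # UVM kernels filtered out
print(kernels_for_device_sketch("cuda"))  # UVM kernels included
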
2 changes: 1 addition & 1 deletion torchrec/distributed/model_parallel.py
@@ -95,7 +95,7 @@ def __init__(
 
         # 2. Call ShardingPlanner.collective_plan passing all found modules and corresponding sharders.
         if plan is None:
-            planner = EmbeddingShardingPlanner(self._env.world_size, self.device)
+            planner = EmbeddingShardingPlanner(self._env.world_size, self.device.type)
             pg = self._env.process_group
             if pg is not None:
                 plan = planner.collective_plan(module, sharders, pg)
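The call site simply swaps the device object for its type string; torch.device.type drops any index, which is all the planner needs:

import torch

print(torch.device("cuda:1").type)  # "cuda"
print(torch.device("cpu").type)     # "cpu"
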
19 changes: 12 additions & 7 deletions torchrec/distributed/planner/embedding_planner.py
@@ -50,7 +50,7 @@ class EmbeddingShardingPlanner(ShardingPlanner):
     def __init__(
         self,
         world_size: int,
-        device: torch.device,
+        compute_device_type: str = "cuda",
         hints: Optional[Dict[str, ParameterHints]] = None,
         input_stats: Optional[Dict[str, ParameterInputStats]] = None,
         storage: Optional[Dict[str, int]] = None,
@@ -62,7 +62,7 @@ def __init__(
         self._input_stats: Dict[str, ParameterInputStats] = (
             input_stats if input_stats else {}
         )
-        self._device = device
+        self._compute_device_type = compute_device_type
 
         if cost_functions is None:
             self._cost_functions: List[Callable[[CostInput], int]] = [
@@ -71,7 +71,9 @@ def __init__(
         else:
             self._cost_functions = cost_functions
 
-        self._topology: Topology = get_topology(world_size, device, storage)
+        self._topology: Topology = get_topology(
+            world_size, compute_device_type, storage
+        )
         self._counter: int = 1
 
     def collective_plan(
@@ -131,7 +133,7 @@ def plan(
 
         sharding_plan = to_plan(
             param_infos,
-            self._device,
+            self._compute_device_type,
             self._world_size,
             self._local_size,
         )
@@ -419,11 +421,14 @@ def _get_param_infos(
                        name, param, sharding_type
                    )
                    for compute_kernel in self._filter_compute_kernels(
-                        name, sharder.compute_kernels(sharding_type, self._device)
+                        name,
+                        sharder.compute_kernels(
+                            sharding_type, self._compute_device_type
+                        ),
                    ):
                        cost_input = CostInput(
                            param=param,
-                            device=self._device,
+                            compute_device_type=self._compute_device_type,
                            compute_kernel=compute_kernel,
                            sharding_type=sharding_type,
                            input_stats=self._input_stats.get(name, None),
@@ -440,7 +445,7 @@ def _get_param_infos(
                                sharding_type=sharding_type,
                                compute_kernel=compute_kernel,
                                storage_usage=sharder.storage_usage(
-                                    param, self._device, compute_kernel
+                                    param, self._compute_device_type, compute_kernel
                                ),
                                _num_col_wise_shards=num_col_wise_shards,
                                col_wise_shard_dim=shard_size,
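With the new signature, callers construct the planner from a device-type string that defaults to "cuda". A hedged construction sketch: the import path is inferred from the file path above, world_size=2 is an arbitrary example, and get_topology must be able to probe the corresponding resources at construction time.

from torchrec.distributed.planner.embedding_planner import EmbeddingShardingPlanner

# Two-rank GPU plan; "cuda" is also the default for compute_device_type.
gpu_planner = EmbeddingShardingPlanner(world_size=2, compute_device_type="cuda")

# A CPU-only topology just swaps the string; no torch.device is needed anymore.
cpu_planner = EmbeddingShardingPlanner(world_size=2, compute_device_type="cpu")
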
36 changes: 18 additions & 18 deletions torchrec/distributed/planner/parameter_sharding.py
@@ -64,12 +64,12 @@ def _rw_shard_table_rows(hash_size: int, world_size: int) -> Tuple[List[int], in
 
 
 def _device_placement(
-    device: torch.device,
+    compute_device_type: str,
     rank: int,
     local_size: int,
 ) -> str:
-    param_device = device
-    if device.type == "cuda":
+    param_device = torch.device("cpu")
+    if compute_device_type == "cuda":
         param_device = torch.device("cuda", rank % local_size)
     return f"rank:{rank}/{param_device}"
 
@@ -78,31 +78,31 @@ class ParameterShardingFactory(abc.ABC):
     @staticmethod
     def shard_parameters(
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
         sharding_option = param_info.sharding_options[0]
         sharding_type = sharding_option.sharding_type
         if sharding_type == ShardingType.TABLE_WISE.value:
             parameter_sharding = TwParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         elif sharding_type == ShardingType.ROW_WISE.value:
             parameter_sharding = RwParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         elif sharding_type == ShardingType.TABLE_ROW_WISE.value:
             parameter_sharding = TwRwParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         elif sharding_type == ShardingType.COLUMN_WISE.value:
             parameter_sharding = CwParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         elif sharding_type == ShardingType.DATA_PARALLEL.value:
             parameter_sharding = DpParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         else:
             raise ValueError(
@@ -116,7 +116,7 @@ class TwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -131,7 +131,7 @@ def shard_parameters(
                     tensor.shape[1],
                 ],
                 shard_offsets=[0, 0],
-                placement=_device_placement(device, rank, local_size),
+                placement=_device_placement(compute_device_type, rank, local_size),
             )
         ]
         return ParameterSharding(
@@ -147,7 +147,7 @@ class RwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -163,7 +163,7 @@ def shard_parameters(
                    tensor.shape[1],
                ],
                shard_offsets=[block_size * min(rank, last_rank), 0],
-                placement=_device_placement(device, rank, local_size),
+                placement=_device_placement(compute_device_type, rank, local_size),
            )
            for rank in range(world_size)
        ]
@@ -180,7 +180,7 @@ class TwRwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -203,7 +203,7 @@ def shard_parameters(
                    local_cols[rank],
                ],
                shard_offsets=[local_row_offsets[rank], 0],
-                placement=_device_placement(device, rank, local_size),
+                placement=_device_placement(compute_device_type, rank, local_size),
            )
            for rank in range(table_node * local_size, (table_node + 1) * local_size)
        ]
@@ -221,7 +221,7 @@ class CwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -250,7 +250,7 @@ def shard_parameters(
                    merged_sizes[i],
                ],
                shard_offsets=[0, offsets[i]],
-                placement=_device_placement(device, rank, local_size),
+                placement=_device_placement(compute_device_type, rank, local_size),
            )
            for i, rank in enumerate(merged_ranks)
        ]
@@ -267,7 +267,7 @@ class DpParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
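Every sharding class above now routes placement through the string-based _device_placement helper. A standalone copy of that logic (with an illustrative name and an assumed 8 devices per host) shows the placement strings that end up in the plan:

import torch


def device_placement_sketch(compute_device_type: str, rank: int, local_size: int) -> str:
    # Mirrors the new _device_placement: CPU unless the compute device is CUDA,
    # in which case the local CUDA index is derived from the global rank.
    param_device = torch.device("cpu")
    if compute_device_type == "cuda":
        param_device = torch.device("cuda", rank % local_size)
    return f"rank:{rank}/{param_device}"


print(device_placement_sketch("cpu", 3, 8))    # rank:3/cpu
print(device_placement_sketch("cuda", 11, 8))  # rank:11/cuda:3
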