
col-wise ads config #6


Closed · wants to merge 1 commit
46 changes: 33 additions & 13 deletions distributed/embedding_lookup.py
@@ -492,7 +492,7 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
-        self._embedding_dims: List[int] = []
+        self._local_cols: List[int] = []
         self._feature_table_map: List[int] = []
         self._emb_names: List[str] = []
         self._lengths_per_emb: List[int] = []
@@ -503,7 +503,7 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
             self._weight_init_mins.append(config.get_weight_init_min())
             self._weight_init_maxs.append(config.get_weight_init_max())
             self._num_embeddings.append(config.num_embeddings)
-            self._embedding_dims.append(config.local_cols)
+            self._local_cols.append(config.local_cols)
             self._feature_table_map.extend([idx] * config.num_features())
             for feature_name in config.feature_names:
                 if feature_name not in shared_feature:
@@ -526,7 +526,7 @@ def init_parameters(self) -> None:
         )
         for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip(
             self._local_rows,
-            self._embedding_dims,
+            self._local_cols,
             self._weight_init_mins,
             self._weight_init_maxs,
             self.emb_module.split_embedding_weights(),
@@ -616,13 +616,28 @@ def __init__(
         def to_rowwise_sharded_metadata(
             local_metadata: ShardMetadata,
             global_metadata: ShardedTensorMetadata,
+            sharding_dim: int,
         ) -> Tuple[ShardMetadata, ShardedTensorMetadata]:
             rw_shards: List[ShardMetadata] = []
             rw_local_shard: ShardMetadata = local_metadata
-            for shard in global_metadata.shards_metadata:
+            shards_metadata = global_metadata.shards_metadata
+            # column-wise sharding:
+            # sort the metadata by column offset, since the momentum
+            # tensor is constructed in a row-wise sharded way
+            if sharding_dim == 1:
+                shards_metadata = sorted(
+                    shards_metadata, key=lambda shard: shard.shard_offsets[1]
+                )
+
+            for idx, shard in enumerate(shards_metadata):
+                offset = shard.shard_offsets[0]
+                # for column-wise sharding, we still create row-wise sharded
+                # metadata for the optimizer, so manually create a row-wise offset
+                if sharding_dim == 1:
+                    offset = idx * shard.shard_lengths[0]
                 rw_shard = ShardMetadata(
                     shard_lengths=[shard.shard_lengths[0]],
-                    shard_offsets=[shard.shard_offsets[0]],
+                    shard_offsets=[offset],
                     placement=shard.placement,
                 )
 
@@ -638,10 +653,10 @@ def to_rowwise_sharded_metadata(
                 memory_format=global_metadata.tensor_properties.memory_format,
                 pin_memory=global_metadata.tensor_properties.pin_memory,
             )
-
+            len_rw_shards = len(shards_metadata) if sharding_dim == 1 else 1
             rw_metadata = ShardedTensorMetadata(
                 shards_metadata=rw_shards,
-                size=torch.Size([global_metadata.size[0]]),
+                size=torch.Size([global_metadata.size[0] * len_rw_shards]),
                 tensor_properties=tensor_properties,
             )
             return rw_local_shard, rw_metadata
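To make the intent of the two hunks above concrete, here is a minimal, self-contained sketch of the offset remapping: column-wise shards are sorted by column offset and their 1-D momentum slices are stacked row-wise, so shard i starts at row offset i * local_rows. The `Shard` dataclass is an illustrative stand-in, not the real `ShardMetadata` type.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Shard:
    # Illustrative stand-in for ShardMetadata: [row, col] offsets and lengths.
    shard_offsets: List[int]
    shard_lengths: List[int]


def cw_to_rw_momentum_shards(shards: List[Shard]) -> List[Shard]:
    # Sort column-wise shards by column offset, then lay their 1-D momentum
    # slices out row-wise: shard i starts at row offset i * local_rows.
    shards = sorted(shards, key=lambda s: s.shard_offsets[1])
    return [
        Shard(shard_offsets=[i * s.shard_lengths[0]], shard_lengths=[s.shard_lengths[0]])
        for i, s in enumerate(shards)
    ]


# A (100 x 64) table split into two 32-column shards:
cw = [Shard([0, 32], [100, 32]), Shard([0, 0], [100, 32])]
print(cw_to_rw_momentum_shards(cw))
# -> row offsets [0] and [100]; the global momentum size becomes 100 * 2 = 200,
#    matching the `global_metadata.size[0] * len_rw_shards` change above.
```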
@@ -673,10 +688,15 @@ def to_rowwise_sharded_metadata(
             state[weight] = {}
             # momentum1
             assert table_config.local_rows == optimizer_states[0].size(0)
+            sharding_dim = (
+                1 if table_config.local_cols != table_config.embedding_dim else 0
+            )
             momentum1_key = f"{table_config.name}.momentum1"
             if optimizer_states[0].dim() == 1:
                 (local_metadata, sharded_tensor_metadata) = to_rowwise_sharded_metadata(
-                    table_config.local_metadata, table_config.global_metadata
+                    table_config.local_metadata,
+                    table_config.global_metadata,
+                    sharding_dim,
                 )
             else:
                 (local_metadata, sharded_tensor_metadata) = (
@@ -699,7 +719,9 @@ def to_rowwise_sharded_metadata(
                         local_metadata,
                         sharded_tensor_metadata,
                     ) = to_rowwise_sharded_metadata(
-                        table_config.local_metadata, table_config.global_metadata
+                        table_config.local_metadata,
+                        table_config.global_metadata,
+                        sharding_dim,
                     )
                 else:
                     (local_metadata, sharded_tensor_metadata) = (
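The `sharding_dim` detection used in these two hunks is a simple width check; a hedged restatement with made-up numbers:

```python
def infer_sharding_dim(local_cols: int, embedding_dim: int) -> int:
    # A shard that holds fewer columns than the full embedding dimension
    # must have been produced by column-wise sharding (dim 1).
    return 1 if local_cols != embedding_dim else 0


assert infer_sharding_dim(local_cols=32, embedding_dim=64) == 1  # column-wise shard
assert infer_sharding_dim(local_cols=64, embedding_dim=64) == 0  # row/table-wise shard
```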
@@ -769,9 +791,7 @@ def to_embedding_location(
         self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
             SplitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=list(
-                    zip(
-                        self._local_rows, self._embedding_dims, managed, compute_devices
-                    )
+                    zip(self._local_rows, self._local_cols, managed, compute_devices)
                 ),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
@@ -822,7 +842,7 @@ def __init__(
 
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
-                list(zip(self._local_rows, self._embedding_dims)),
+                list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
                 use_cpu=device is None or device.type == "cpu",
6 changes: 1 addition & 5 deletions distributed/planner/embedding_planner.py
@@ -31,7 +31,6 @@
     deallocate_param,
     param_sort_key,
     to_plan,
-    MIN_DIM,
 )
 from torchrec.distributed.types import (
     ShardingPlan,
@@ -398,13 +397,10 @@ def _get_num_col_wise_shards(
         col_wise_shard_dim = (
             col_wise_shard_dim_hint
             if col_wise_shard_dim_hint is not None
-            else MIN_DIM
+            else param.shape[1]
         )
         # column-wise shard the weights
         num_col_wise_shards, residual = divmod(param.shape[1], col_wise_shard_dim)
-        assert (
-            num_col_wise_shards > 0
-        ), f"the table {name} cannot be column-wise sharded into shards of {col_wise_shard_dim} dimensions"
         if residual > 0:
             num_col_wise_shards += 1
     elif sharding_type == ShardingType.TABLE_WISE.value:
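With `MIN_DIM` gone, an unhinted table now defaults to a single column-wise shard spanning the full embedding dimension, and a residual block of columns still gets its own shard. A standalone sketch of that arithmetic (the function name and values are illustrative, not taken from the planner):

```python
from typing import Optional


def num_col_wise_shards(embedding_dim: int, hint: Optional[int]) -> int:
    # No hint -> shard dimension equals the full embedding dim -> one shard.
    shard_dim = hint if hint is not None else embedding_dim
    shards, residual = divmod(embedding_dim, shard_dim)
    return shards + 1 if residual > 0 else shards


print(num_col_wise_shards(embedding_dim=128, hint=32))    # 4
print(num_col_wise_shards(embedding_dim=80, hint=32))     # 3 (32 + 32 + 16-column residual)
print(num_col_wise_shards(embedding_dim=128, hint=None))  # 1
```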
3 changes: 3 additions & 0 deletions distributed/planner/tests/test_embedding_planner.py
@@ -551,6 +551,7 @@ def test_allocation_planner_cw_balanced(self) -> None:
             hints={
                 "table_0": ParameterHints(
                     sharding_types=[ShardingType.COLUMN_WISE.value],
+                    col_wise_shard_dim=32,
                 ),
             },
         )
@@ -653,9 +654,11 @@ def test_allocation_planner_cw_two_big_rest_small_with_residual(self) -> None:
             hints={
                 "table_0": ParameterHints(
                     sharding_types=[ShardingType.COLUMN_WISE.value],
+                    col_wise_shard_dim=32,
                 ),
                 "table_1": ParameterHints(
                     sharding_types=[ShardingType.COLUMN_WISE.value],
+                    col_wise_shard_dim=32,
                 ),
             },
         )
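For a quick sanity check of what `col_wise_shard_dim=32` implies for a hinted table, here is a tiny illustration; `embedding_dim=80` is an assumed value chosen to show the residual case, not taken from the test:

```python
embedding_dim, shard_dim = 80, 32  # assumed sizes, for illustration only
offsets = list(range(0, embedding_dim, shard_dim))
lengths = [min(shard_dim, embedding_dim - o) for o in offsets]
print(list(zip(offsets, lengths)))  # [(0, 32), (32, 32), (64, 16)] -> two full shards plus a residual
```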