weight init for ads #5

Closed · wants to merge 2 commits

2 changes: 2 additions & 0 deletions distributed/dp_sharding.py
@@ -116,6 +116,8 @@ def _shard(
compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel),
local_metadata=None,
global_metadata=None,
weight_init_max=config[0].weight_init_max,
weight_init_min=config[0].weight_init_min,
)
)
return tables_per_rank
2 changes: 2 additions & 0 deletions distributed/embedding.py
@@ -131,6 +131,8 @@ def _create_embedding_configs_by_sharding(
pooling=config.pooling,
is_weighted=module.is_weighted,
embedding_names=embedding_names,
weight_init_max=config.weight_init_max,
weight_init_min=config.weight_init_min,
),
parameter_sharding,
)
28 changes: 18 additions & 10 deletions distributed/embedding_lookup.py
@@ -114,8 +114,8 @@ def __init__(
embedding_config.local_cols,
device=device,
).uniform_(
-sqrt(1 / embedding_config.num_embeddings),
sqrt(1 / embedding_config.num_embeddings),
embedding_config.get_weight_init_min(),
embedding_config.get_weight_init_max(),
),
)
)
@@ -242,7 +242,7 @@ def forward(
assert sparse_features.id_list_features is not None
embeddings: List[torch.Tensor] = []
id_list_features_by_group = sparse_features.id_list_features.split(
self._id_list_feature_splits
self._id_list_feature_splits,
)
for emb_op, features in zip(self._emb_modules, id_list_features_by_group):
embeddings.append(emb_op(features).view(-1))
@@ -362,8 +362,8 @@ def _to_mode(pooling: PoolingType) -> str:
embedding_config.local_cols,
device=device,
).uniform_(
-sqrt(1 / embedding_config.num_embeddings),
sqrt(1 / embedding_config.num_embeddings),
embedding_config.get_weight_init_min(),
embedding_config.get_weight_init_max(),
),
)
)
@@ -479,6 +479,8 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
self._pooling: PoolingMode = to_pooling_mode(config.pooling)

self._local_rows: List[int] = []
self._weight_init_mins: List[float] = []
self._weight_init_maxs: List[float] = []
self._num_embeddings: List[int] = []
self._embedding_dims: List[int] = []
self._feature_table_map: List[int] = []
@@ -488,6 +490,8 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
shared_feature: Dict[str, bool] = {}
for idx, config in enumerate(self._config.embedding_tables):
self._local_rows.append(config.local_rows)
self._weight_init_mins.append(config.get_weight_init_min())
self._weight_init_maxs.append(config.get_weight_init_max())
self._num_embeddings.append(config.num_embeddings)
self._embedding_dims.append(config.local_cols)
self._feature_table_map.extend([idx] * config.num_features())
@@ -510,14 +514,18 @@ def init_parameters(self) -> None:
assert len(self._num_embeddings) == len(
self.emb_module.split_embedding_weights()
)
for (rows, num_emb, emb_dim, param) in zip(
for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip(
self._local_rows,
self._num_embeddings,
self._embedding_dims,
self._weight_init_mins,
self._weight_init_maxs,
self.emb_module.split_embedding_weights(),
):
assert param.shape == (rows, emb_dim)
param.data.uniform_(-sqrt(1 / num_emb), sqrt(1 / num_emb))
param.data.uniform_(
weight_init_min,
weight_init_max,
)

def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
values = self.emb_module(
@@ -896,15 +904,15 @@ def forward(
if len(self._emb_modules) > 0:
assert sparse_features.id_list_features is not None
id_list_features_by_group = sparse_features.id_list_features.split(
self._id_list_feature_splits
self._id_list_feature_splits,
)
for emb_op, features in zip(self._emb_modules, id_list_features_by_group):
embeddings.append(emb_op(features).values())
if len(self._score_emb_modules) > 0:
assert sparse_features.id_score_list_features is not None
id_score_list_features_by_group = (
sparse_features.id_score_list_features.split(
self._id_score_list_feature_splits
self._id_score_list_feature_splits,
)
)
for emb_op, features in zip(
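
Taken together, the embedding_lookup.py hunks replace the hard-coded ±sqrt(1/num_embeddings) bounds with per-table values read from the config. A minimal sketch of the shared pattern, assuming duck-typed table configs that expose local_rows, local_cols and the new get_weight_init_min()/get_weight_init_max() accessors (stand-ins for the real grouped-config objects and the output of split_embedding_weights()):

import torch
from typing import List, Sequence

def init_table_weights(table_configs: Sequence, weights: List[torch.Tensor]) -> None:
    # One (local_rows x local_cols) parameter per table; the uniform bounds now
    # come from the table config, which falls back to +/- sqrt(1 / num_embeddings)
    # when weight_init_min / weight_init_max are left unset.
    for config, param in zip(table_configs, weights):
        assert param.shape == (config.local_rows, config.local_cols)
        param.data.uniform_(
            config.get_weight_init_min(),
            config.get_weight_init_max(),
        )
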
2 changes: 2 additions & 0 deletions distributed/rw_sharding.py
@@ -227,6 +227,8 @@ def _shard(
compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel),
local_metadata=shards[rank],
global_metadata=global_metadata,
weight_init_max=config[0].weight_init_max,
weight_init_min=config[0].weight_init_min,
)
)
return tables_per_rank
2 changes: 2 additions & 0 deletions distributed/tw_sharding.py
@@ -168,6 +168,8 @@ def _shard(
compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel),
local_metadata=shards[0],
global_metadata=global_metadata,
weight_init_max=config[0].weight_init_max,
weight_init_min=config[0].weight_init_min,
)
)
return tables_per_rank
2 changes: 2 additions & 0 deletions distributed/twrw_sharding.py
@@ -274,6 +274,8 @@ def _shard(
compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel),
local_metadata=shards[rank_idx],
global_metadata=global_metadata,
weight_init_max=config[0].weight_init_max,
weight_init_min=config[0].weight_init_min,
)
)
return tables_per_rank
17 changes: 16 additions & 1 deletion modules/embedding_configs.py
@@ -2,7 +2,8 @@

from dataclasses import dataclass, field
from enum import Enum, unique
from typing import List, Dict
from math import sqrt
from typing import Optional, List, Dict


@unique
@@ -39,6 +40,20 @@ class BaseEmbeddingConfig:
name: str = ""
data_type: DataType = DataType.FP32
feature_names: List[str] = field(default_factory=list)
weight_init_max: Optional[float] = None
weight_init_min: Optional[float] = None

def get_weight_init_max(self) -> float:
if self.weight_init_max is None:
return sqrt(1 / self.num_embeddings)
else:
return self.weight_init_max

def get_weight_init_min(self) -> float:
if self.weight_init_min is None:
return -sqrt(1 / self.num_embeddings)
else:
return self.weight_init_min

def num_features(self) -> int:
return len(self.feature_names)
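
With the two new optional fields, a table can pin its initialization range explicitly or keep the previous ±sqrt(1/num_embeddings) default. A hypothetical usage sketch; the EmbeddingBagConfig fields other than weight_init_min/weight_init_max are assumed from the existing torchrec configs:

from torchrec.modules.embedding_configs import EmbeddingBagConfig

# Explicit bounds: weights are drawn from uniform(-0.05, 0.05).
ads_table = EmbeddingBagConfig(
    num_embeddings=1_000_000,
    embedding_dim=64,
    name="ads_table",
    feature_names=["ad_id"],
    weight_init_min=-0.05,
    weight_init_max=0.05,
)

# No bounds given: the accessors fall back to the old default,
# +/- sqrt(1 / 10_000) = +/- 0.01.
user_table = EmbeddingBagConfig(
    num_embeddings=10_000,
    embedding_dim=64,
    name="user_table",
    feature_names=["user_id"],
)
print(user_table.get_weight_init_min(), user_table.get_weight_init_max())  # -0.01 0.01
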
7 changes: 7 additions & 0 deletions modules/embedding_modules.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
from torchrec.modules.embedding_configs import (
DataType,
EmbeddingConfig,
EmbeddingBagConfig,
PoolingType,
@@ -108,12 +109,18 @@ def __init__(
if embedding_config.name in table_names:
raise ValueError(f"Duplicate table name {embedding_config.name}")
table_names.add(embedding_config.name)
dtype = (
torch.float32
if embedding_config.data_type == DataType.FP32
else torch.float16
)
self.embedding_bags[embedding_config.name] = nn.EmbeddingBag(
num_embeddings=embedding_config.num_embeddings,
embedding_dim=embedding_config.embedding_dim,
mode=_to_mode(embedding_config.pooling),
device=device,
include_last_offset=True,
dtype=dtype,
)
if not embedding_config.feature_names:
embedding_config.feature_names = [embedding_config.name]
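
The embedding_modules.py hunk additionally threads the config's data_type through to nn.EmbeddingBag. A small isolated sketch of that mapping; the DataType.FP16 member used in the usage comment is assumed to exist alongside FP32:

import torch
import torch.nn as nn
from torchrec.modules.embedding_configs import DataType

def make_embedding_bag(num_embeddings: int, embedding_dim: int, data_type: DataType) -> nn.EmbeddingBag:
    # FP32 configs keep float32 weights; any other data type is treated as
    # float16, mirroring the two-way branch added in EmbeddingBagCollection.
    dtype = torch.float32 if data_type == DataType.FP32 else torch.float16
    return nn.EmbeddingBag(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        mode="sum",
        include_last_offset=True,
        dtype=dtype,
    )

# make_embedding_bag(100, 8, DataType.FP16).weight.dtype == torch.float16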