Add logic to stream weights in EmbeddingKVDB #4058

Closed · wants to merge 2 commits
@@ -185,6 +185,19 @@ class UVMCacheStatsIndex(enum.IntEnum):
num_conflict_misses = 5


@dataclass
class RESParams:
    res_server_port: int = 0  # port of the raw embedding streaming (RES) server
    res_store_shards: int = 1  # number of shards used to store the raw embeddings
    table_names: List[str] = field(default_factory=list)  # names of the tables this TBE holds
    table_offsets: List[int] = field(
        default_factory=list
    )  # table offsets for the global rows this TBE holds
    table_sizes: List[int] = field(
        default_factory=list
    )  # table sizes for the global rows this TBE holds


def construct_split_state(
embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]],
rowwise: bool,
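For orientation, here is a rough sketch of how a caller might construct the RESParams dataclass added above before handing it to the TBE. The import path and every value below are assumptions for illustration; this diff does not show where RESParams is defined or how it is populated in production.

# Illustrative sketch only: the import path and all values are assumed,
# not taken from this PR.
from fbgemm_gpu.split_table_batched_embeddings_ops_common import RESParams

res_params = RESParams(
    res_store_shards=2,  # spread the raw embeddings across two RES shards
    table_names=["table_a", "table_b"],  # tables this TBE holds
    table_offsets=[0, 1_000_000],  # global row offset where each table starts
)
# res_server_port is left at its default (0): when streaming is enabled,
# SSDTableBatchedEmbeddingBags.__init__ overwrites it from the LOCAL_RES_PORT
# environment variable, and table_sizes is derived from the per-table row
# counts (see the training.py diff below).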
fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py: 24 additions & 1 deletion
@@ -42,6 +42,7 @@
apply_split_helper,
CounterBasedRegularizationDefinition,
CowClipDefinition,
RESParams,
UVMCacheStatsIndex,
WeightDecayMode,
)
@@ -89,6 +90,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
weights_placements: Tensor
weights_offsets: Tensor
_local_instance_index: int = -1
res_params: RESParams

def __init__(
self,
@@ -157,6 +159,8 @@ def __init__(
lazy_bulk_init_enabled: bool = False,
backend_type: BackendType = BackendType.SSD,
kv_zch_params: Optional[KVZCHParams] = None,
enable_raw_embedding_streaming: bool = False,  # whether to enable raw embedding streaming
res_params: Optional[RESParams] = None,  # raw embedding streaming (RES) sharding info
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -169,6 +173,19 @@ def __init__(
# pyre-fixme[8]: Attribute has type `device`; used as `int`.
self.current_device: torch.device = torch.cuda.current_device()

self.enable_raw_embedding_streaming = enable_raw_embedding_streaming
# initialize the raw embedding streaming related variables
self.res_params: RESParams = res_params or RESParams()
if self.enable_raw_embedding_streaming:
self.res_params.table_sizes = [0] + list(itertools.accumulate(rows))
res_port_from_env = os.getenv("LOCAL_RES_PORT")
self.res_params.res_server_port = (
int(res_port_from_env) if res_port_from_env else 0
)
logging.info(
f"read {self.res_params.res_server_port=} from env at rank {dist.get_rank()}, with {self.res_params=}"
)

self.feature_table_map: List[int] = (
feature_table_map if feature_table_map is not None else list(range(T_))
)
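To make the table_sizes bookkeeping in the hunk above concrete: prepending 0 to a running prefix sum of the per-table row counts means entry i is the global row at which table i starts, and the last entry is the total row count. A small worked example with made-up row counts:

import itertools

rows = [100, 200, 50]  # hypothetical per-table row counts
# Same computation as in __init__ above: prepend 0, then prefix-sum.
table_sizes = [0] + list(itertools.accumulate(rows))
print(table_sizes)  # [0, 100, 300, 350]
# Table i occupies global rows [table_sizes[i], table_sizes[i + 1]).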
@@ -464,7 +481,7 @@ def __init__(
f"write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size},max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num},"
f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
f"row_storage_bitwidth={weights_precision.bit_rate()},block_cache_size_per_tbe={ssd_block_cache_size_per_tbe},"
f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB"
f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB, enable_raw_embedding_streaming:{self.enable_raw_embedding_streaming}"
)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
Expand All @@ -489,6 +506,12 @@ def __init__(
tbe_unique_id,
l2_cache_size,
enable_async_update,
self.enable_raw_embedding_streaming,
self.res_params.res_store_shards,
self.res_params.res_server_port,
self.res_params.table_names,
self.res_params.table_offsets,
self.res_params.table_sizes,
)
if self.bulk_init_chunk_size > 0:
self.ssd_uniform_init_lower: float = ssd_uniform_init_lower
@@ -35,7 +35,13 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
bool use_passed_in_path = false,
int64_t tbe_unique_id = 0,
int64_t l2_cache_size_gb = 0,
bool enable_async_update = false)
bool enable_async_update = false,
bool enable_raw_embedding_streaming = false,
int64_t res_store_shards = 0,
int64_t res_server_port = 0,
std::vector<std::string> table_names = {},
std::vector<int64_t> table_offsets = {},
const std::vector<int64_t>& table_sizes = {})
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
path,
num_shards,
@@ -56,7 +62,13 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
use_passed_in_path,
tbe_unique_id,
l2_cache_size_gb,
enable_async_update)) {}
enable_async_update,
enable_raw_embedding_streaming,
res_store_shards,
res_server_port,
std::move(table_names),
std::move(table_offsets),
table_sizes)) {}

void set_cuda(
at::Tensor indices,
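One detail of the forwarding above: table_names and table_offsets are accepted by value and moved (std::move) into the underlying ssd::EmbeddingRocksDB, so the wrapper avoids an extra copy of each vector, while table_sizes is taken by const reference and copied into the implementation at this call.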