Add MTIA info to sharder #3032

Closed · wants to merge 1 commit

3 changes: 1 addition & 2 deletions torchrec/distributed/embedding_types.py

```diff
@@ -519,8 +519,7 @@ def storage_usage(
         storage_map = {
             "cuda": ParameterStorage.HBM,
             "cpu": ParameterStorage.DDR,
-            # TODO: Update it later. Setting for MTIA is same as CPU's for now.
-            "mtia": ParameterStorage.DDR,
+            "mtia": ParameterStorage.HBM,
         }
         return {
             storage_map[compute_device_type].value: get_tensor_size_bytes(tensor)
```
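
With this change, `storage_usage` reports MTIA embedding tables against HBM instead of DDR, matching CUDA. A minimal sketch of the resulting behavior (the standalone `storage_usage` below and its byte-count logic are illustrative stand-ins, not torchrec's actual `ParameterStorage` enum or `get_tensor_size_bytes` helper):

```python
import torch

# Illustrative stand-ins for torchrec's ParameterStorage enum values.
HBM, DDR = "hbm", "ddr"

def storage_usage(tensor: torch.Tensor, compute_device_type: str) -> dict:
    # After this PR, MTIA maps to HBM, same as CUDA.
    storage_map = {"cuda": HBM, "cpu": DDR, "mtia": HBM}
    return {storage_map[compute_device_type]: tensor.numel() * tensor.element_size()}

table = torch.empty(1024, 64, dtype=torch.float32)
assert storage_usage(table, "mtia") == {"hbm": 1024 * 64 * 4}  # charged to HBM now
```
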
4 changes: 3 additions & 1 deletion torchrec/distributed/planner/enumerators.py

```diff
@@ -80,7 +80,9 @@ def __init__(
         self._use_exact_enumerate_order: bool = (
             use_exact_enumerate_order if use_exact_enumerate_order else False
         )
-        memory_type = "hbm_cap" if topology.compute_device == "cuda" else "ddr_cap"
+        memory_type = (
+            "hbm_cap" if topology.compute_device in {"cuda", "mtia"} else "ddr_cap"
+        )
         self._device_memory_sizes: Optional[
             List[int]
         ] = (  # only used with custom topology where memory is different within a topology
```
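
The enumerator consequently sizes MTIA topologies by HBM capacity rather than DDR capacity. A hedged sketch of just the selection logic (a hypothetical free function, not the torchrec API):

```python
def memory_cap_attr(compute_device: str) -> str:
    # MTIA topologies are now sized by HBM capacity, like CUDA;
    # only CPU topologies fall back to DDR capacity.
    return "hbm_cap" if compute_device in {"cuda", "mtia"} else "ddr_cap"

assert memory_cap_attr("mtia") == "hbm_cap"
assert memory_cap_attr("cpu") == "ddr_cap"
```
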
4 changes: 2 additions & 2 deletions torchrec/distributed/planner/shard_estimators.py

```diff
@@ -1261,7 +1261,7 @@ def calculate_shard_storages(
                 count_ephemeral_storage_cost=count_ephemeral_storage_cost,
                 is_inference=is_inference,
             )
-            if compute_device == "cuda"
+            if compute_device in {"cuda", "mtia"}
             else 0
         )
         for input_size, output_size, hbm_specific_size in zip(
@@ -1273,7 +1273,7 @@
     ddr_sizes: List[int] = [
         (
             input_size + output_size + ddr_specific_size
-            if compute_device in {"cpu", "mtia"} and not is_inference
+            if compute_device == "cpu" and not is_inference
            else ddr_specific_size
         )
         for input_size, output_size, ddr_specific_size in zip(
```
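
Net effect in `calculate_shard_storages`: per-shard input/output costs for MTIA now land in the HBM bucket, and MTIA no longer inflates the DDR bucket. A deliberately simplified sketch (the real HBM term goes through an ephemeral-storage cost function; the flat sum below is an assumption for illustration):

```python
def shard_storage(
    compute_device: str,
    input_size: int,
    output_size: int,
    hbm_specific: int,
    ddr_specific: int,
    is_inference: bool = False,
) -> tuple:
    # HBM bucket: CUDA and (now) MTIA charge I/O costs to device memory.
    hbm = (
        input_size + output_size + hbm_specific
        if compute_device in {"cuda", "mtia"}
        else 0
    )
    # DDR bucket: only CPU training still charges I/O to host memory;
    # MTIA keeps just its table-specific DDR footprint.
    ddr = (
        input_size + output_size + ddr_specific
        if compute_device == "cpu" and not is_inference
        else ddr_specific
    )
    return hbm, ddr

# An MTIA shard's I/O cost now counts against HBM, not DDR.
assert shard_storage("mtia", input_size=100, output_size=50,
                     hbm_specific=1000, ddr_specific=0) == (1150, 0)
```
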
8 changes: 4 additions & 4 deletions torchrec/distributed/planner/storage_reservations.py

```diff
@@ -73,8 +73,8 @@ def _reserve_dense_storage(
         dense_tensor_size = dense_tensor_estimate
 
     dense_tensor_storage = Storage(
-        hbm=dense_tensor_size if topology.compute_device == "cuda" else 0,
-        ddr=dense_tensor_size if topology.compute_device in {"cpu", "mtia"} else 0,
+        hbm=dense_tensor_size if topology.compute_device in {"cuda", "mtia"} else 0,
+        ddr=dense_tensor_size if topology.compute_device == "cpu" else 0,
     )
 
     for device in topology.devices:
@@ -93,8 +93,8 @@ def _reserve_kjt_storage(
     kjt_size = math.ceil(sum(batch_inputs) * float(input_data_type_size)) * multiplier
 
     kjt_storage = Storage(
-        hbm=kjt_size if topology.compute_device == "cuda" else 0,
-        ddr=kjt_size if topology.compute_device in {"cpu", "mtia"} else 0,
+        hbm=kjt_size if topology.compute_device in {"cuda", "mtia"} else 0,
+        ddr=kjt_size if topology.compute_device == "cpu" else 0,
     )
 
     for device in topology.devices:
```
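
Dense-tensor and KJT reservations follow the same pattern: reserve against HBM for CUDA and MTIA, against DDR only for CPU. A sketch under the assumption that `Storage` behaves like a simple dataclass (a stand-in for the planner's `Storage` type):

```python
from dataclasses import dataclass

@dataclass
class Storage:  # stand-in for torchrec.distributed.planner.types.Storage
    hbm: int = 0
    ddr: int = 0

def reserve(size: int, compute_device: str) -> Storage:
    # CUDA and MTIA reservations count against HBM; only CPU against DDR.
    return Storage(
        hbm=size if compute_device in {"cuda", "mtia"} else 0,
        ddr=size if compute_device == "cpu" else 0,
    )

assert reserve(1 << 20, "mtia") == Storage(hbm=1 << 20, ddr=0)
assert reserve(1 << 20, "cpu") == Storage(hbm=0, ddr=1 << 20)
```
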
2 changes: 1 addition & 1 deletion torchrec/distributed/planner/types.py

```diff
@@ -284,7 +284,7 @@ def __init__(
         self._world_size = world_size
 
         hbm_per_device = [0] * world_size
-        if self._compute_device == "cuda":
+        if self._compute_device == "cuda" or self._compute_device == "mtia":
             hbm_per_device = [hbm_cap if hbm_cap else HBM_CAP] * world_size
         ddr_cap_per_rank = [ddr_cap if ddr_cap else DDR_CAP] * world_size
 
```
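
On the topology side, MTIA ranks now get a nonzero per-rank HBM budget. Sketch of the constructor logic, with `HBM_CAP` and `DDR_CAP` values assumed rather than taken from torchrec's constants:

```python
from typing import List, Optional, Tuple

HBM_CAP = 32 * 1024**3   # assumed default, not torchrec's actual constant
DDR_CAP = 128 * 1024**3  # assumed default

def per_rank_caps(
    compute_device: str,
    world_size: int,
    hbm_cap: Optional[int] = None,
    ddr_cap: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    hbm_per_device = [0] * world_size
    if compute_device == "cuda" or compute_device == "mtia":
        # MTIA ranks now receive an HBM budget, just like CUDA ranks.
        hbm_per_device = [hbm_cap if hbm_cap else HBM_CAP] * world_size
    ddr_cap_per_rank = [ddr_cap if ddr_cap else DDR_CAP] * world_size
    return hbm_per_device, ddr_cap_per_rank

hbm, ddr = per_rank_caps("mtia", world_size=4)
assert hbm == [HBM_CAP] * 4  # previously all zeros for MTIA
```
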
3 changes: 1 addition & 2 deletions torchrec/distributed/types.py

```diff
@@ -1197,8 +1197,7 @@ def storage_usage(
         storage_map = {
             "cuda": ParameterStorage.HBM,
             "cpu": ParameterStorage.DDR,
-            # TODO: Update it later. Setting for MTIA is same as CPU's for now.
-            "mtia": ParameterStorage.DDR,
+            "mtia": ParameterStorage.HBM,
         }
         return {storage_map[compute_device_type].value: get_tensor_size_bytes(tensor)}
 
```