Description
I’m using torch.compile with DistributedModelParallel. Running the code below results in AttributeError: 'NoneType' object has no attribute '_dynamo_weak_dynamic_indices'. Note that this seems to happen only when using row-wise sharding. I would expect the code to run without such errors.
Environment:
python=3.11.8, torch=2.2.2+cu121, torchrec=0.6.0+cu121.
Reproduction code:
import os
from typing import Callable, List, Union, Tuple
import multiprocessing
import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
EmbeddingShardingPlanner,
Topology,
)
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.test_utils.test_model import (
ModelInput,
)
from torchrec.distributed.types import (
ModuleSharder,
ShardingEnv,
ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.test_utils import get_free_port
class TestModel(nn.Module):
def __init__(self):
super().__init__()
# define model parameters
self.dense_in_feature = 820
self.dense_out_feature = 784
self.table_params = [
[311, 108],
[739, 408],
]
self.weighted_table_params = [
[159, 96],
[69, 24],
[412, 564],
[940, 300],
]
self.over_out_feature = 61
# sparse layer
self.tables = [
EmbeddingBagConfig(
num_embeddings=self.table_params[i][0],
embedding_dim=self.table_params[i][1],
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(len(self.table_params))
]
self.sparse = EmbeddingBagCollection(
tables=self.tables,
is_weighted=False,
)
# weighted sparse layer
self.weighted_tables = [
EmbeddingBagConfig(
num_embeddings=self.weighted_table_params[i][0],
embedding_dim=self.weighted_table_params[i][1],
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(len(self.weighted_table_params))
]
self.sparse_weighted = EmbeddingBagCollection(
tables=self.weighted_tables,
is_weighted=True,
)
# dense layer
self.dense = nn.Linear(in_features=self.dense_in_feature, out_features=self.dense_out_feature, bias=True)
# over layer
in_features_concat = (
self.dense_out_feature
+ sum([table.embedding_dim * len(table.feature_names) for table in self.tables])
+ sum([table.embedding_dim * len(table.feature_names) for table in self.weighted_tables])
)
self.over = nn.Linear(in_features=in_features_concat, out_features=self.over_out_feature, bias=True)
def forward(
self,
input: ModelInput,
print_intermediate_layer: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# dense, sparse, weighted sparse layer output
dense_r = self.dense(input.float_features)
sparse_r = self.sparse(input.idlist_features)
sparse_weighted_r = self.sparse_weighted(input.idscore_features)
# concat dense, sparse, weighted sparse layer output
result = KeyedTensor(
keys=sparse_r.keys() + sparse_weighted_r.keys(),
length_per_key=sparse_r.length_per_key()
+ sparse_weighted_r.length_per_key(),
values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
)
_features = [feature for table in self.tables for feature in table.feature_names]
_weighted_features = [feature for table in self.weighted_tables for feature in table.feature_names]
ret_list = []
ret_list.append(dense_r)
for feature_name in _features:
ret_list.append(result[feature_name])
for feature_name in _weighted_features:
ret_list.append(result[feature_name])
ret_concat = torch.cat(ret_list, dim=1)
# over layer output
over_r = self.over(ret_concat)
# sigmoid output
pred = torch.sigmoid(torch.mean(over_r, dim=1))
return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)
def sharding_single_rank_test(
rank: int,
world_size: int,
model,
inputs,
sharders: List[ModuleSharder[nn.Module]],
backend: str,
compiled: bool = True,
) -> None:
with MultiProcessContext(rank, world_size, backend) as ctx:
if compiled:
model = torch.compile(model)
local_model = model.to(ctx.device)
planner = EmbeddingShardingPlanner(
topology=Topology(
world_size, ctx.device.type
),
)
plan: ShardingPlan = planner.collective_plan(local_model, sharders, ctx.pg)
local_model = DistributedModelParallel(
local_model,
env=ShardingEnv.from_process_group(ctx.pg),
plan=plan,
sharders=sharders,
device=ctx.device,
)
# Run a single forward pass of the sharded model.
local_input = inputs[0][1][rank].to(ctx.device)
local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)
# record the local prediction
all_local_pred = []
for _ in range(world_size):
all_local_pred.append(torch.empty_like(local_pred))
dist.all_gather(all_local_pred, local_pred, group=ctx.pg)
# record the local model's layer output
all_dense_r = []
for _ in range(world_size):
all_dense_r.append(torch.empty_like(dense_r))
dist.all_gather(all_dense_r, dense_r, group=ctx.pg)
# print(sparse_r.to_dict())
sparse_r_dict = sparse_r.to_dict()
all_sparse_r_dict = {}
for key in sparse_r_dict:
all_sparse_r_dict[key] = []
for _ in range(world_size):
all_sparse_r_dict[key].append(torch.empty_like(sparse_r_dict[key]))
dist.all_gather(all_sparse_r_dict[key], sparse_r_dict[key].contiguous(), group=ctx.pg)
sparse_weighted_r_dict = sparse_weighted_r.to_dict()
all_sparse_weighted_r_dict = {}
for key in sparse_weighted_r_dict:
all_sparse_weighted_r_dict[key] = []
for _ in range(world_size):
all_sparse_weighted_r_dict[key].append(torch.empty_like(sparse_weighted_r_dict[key]))
dist.all_gather(all_sparse_weighted_r_dict[key], sparse_weighted_r_dict[key].contiguous(), group=ctx.pg)
all_over_r = []
for _ in range(world_size):
all_over_r.append(torch.empty_like(over_r))
dist.all_gather(all_over_r, over_r, group=ctx.pg)
def setUp():
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
os.environ["NCCL_SOCKET_IFNAME"] = "lo"
torch.use_deterministic_algorithms(True)
if torch.cuda.is_available():
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
def run_multi_process_test(
callable: Callable[
...,
None,
],
world_size: int,
# pyre-ignore
**kwargs,
) -> None:
setUp()
ctx = multiprocessing.get_context("forkserver")
processes = []
for rank in range(world_size):
kwargs["rank"] = rank
kwargs["world_size"] = world_size
p = ctx.Process(
target=callable,
kwargs=kwargs,
)
p.start()
processes.append(p)
for p in processes:
p.join()
def main_test(
sharders: List[ModuleSharder[nn.Module]],
backend: str,
world_size: int,
compiled: bool,
) -> None:
model = TestModel()
inputs = [ModelInput.generate(
batch_size=1200,
world_size=world_size,
num_float_features=model.dense_in_feature,
tables=model.tables,
weighted_tables=model.weighted_tables,
)]
run_multi_process_test(
callable=sharding_single_rank_test,
world_size=world_size,
model=model,
inputs=inputs,
sharders=sharders,
backend=backend,
compiled=compiled,
)
if __name__ == "__main__":
sharders = [create_test_sharder("embedding_bag_collection", "row_wise", "dense")]
backend = "nccl"
world_size = 2
main_test(
sharders = sharders,
backend = backend,
world_size = world_size,
compiled = True,
)
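For reference, this is the variant of the __main__ block I used to narrow things down; it is a minimal sketch under the same setup as above. Per the observation above, the error only seems to surface with "row_wise"; using a non-row-wise sharding type (e.g. "table_wise") does not appear to hit it, and passing compiled=False sidesteps the failing code path entirely since no dynamo/AOTAutograd wrapper is involved then.

# Same as the reproduction above, but with a non-row-wise sharding type
# ("table_wise" is one example); the error only seems to occur with "row_wise".
if __name__ == "__main__":
    sharders = [create_test_sharder("embedding_bag_collection", "table_wise", "dense")]
    main_test(
        sharders=sharders,
        backend="nccl",
        world_size=2,
        compiled=True,  # compiled=False also avoids the dynamo/AOTAutograd path
    )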
Log:
The error message is copied below.
Traceback (most recent call last):
File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/mnt/tests/reproduce_nccl_row_wise.py", line 154, in sharding_single_rank_test
local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 288, in forward
return self._dmp_wrapped_module(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
else self._run_ddp_forward(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/tests/reproduce_nccl_row_wise.py", line 86, in forward
def forward(
File "/mnt/tests/reproduce_nccl_row_wise.py", line 93, in resume_in_forward
sparse_r = self.sparse(input.idlist_features)
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/types.py", line 747, in forward
dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/embeddingbag.py", line 756, in input_dist
def input_dist(
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/embeddingbag.py", line 782, in resume_in_input_dist
awaitables.append(input_dist(features_by_shard))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/sharding/rw_sharding.py", line 292, in forward
def forward(
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
return compiled_fn(full_args)
^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
return f(*args)
^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 259, in runtime_wrapper
t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute '_dynamo_weak_dynamic_indices'