set fake stbe for ec #1490

Closed · wants to merge 2 commits

Changes from all commits

7 changes: 6 additions & 1 deletion torchrec/distributed/test_utils/infer_utils.py
@@ -85,7 +85,11 @@ class TestModelInfo:


 def prep_inputs(
-    model_info: TestModelInfo, world_size: int, batch_size: int = 1, count: int = 5
+    model_info: TestModelInfo,
+    world_size: int,
+    batch_size: int = 1,
+    count: int = 5,
+    long_indices: bool = True,
 ) -> List[ModelInput]:
     inputs = []
     for _ in range(count):
@@ -96,6 +100,7 @@ def prep_inputs(
                 num_float_features=model_info.num_float_features,
                 tables=model_info.tables,
                 weighted_tables=model_info.weighted_tables,
+                long_indices=long_indices,
             )[1][0],
         )
     return inputs
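
For reference, a minimal usage sketch of the new flag, driving ModelInput.generate directly. The table config is illustrative, and it assumes an EmbeddingBagConfig is accepted for `tables` here, as the call sites above suggest:

    import torch
    from torchrec.distributed.test_utils.test_model import ModelInput
    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    tables = [
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=8,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]
    # long_indices=False switches the generated KJT indices from int64 to int32
    _, local_batches = ModelInput.generate(
        batch_size=4,
        world_size=2,
        num_float_features=16,
        tables=tables,
        weighted_tables=[],
        long_indices=False,
    )
    assert local_batches[0].idlist_features.values().dtype == torch.int32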
15 changes: 13 additions & 2 deletions torchrec/distributed/test_utils/test_model.py
@@ -62,6 +62,7 @@ def generate(
         ]
     ] = None,
     variable_batch_size: bool = False,
+    long_indices: bool = True,
 ) -> Tuple["ModelInput", List["ModelInput"]]:
     """
     Returns a global (single-rank training) batch
@@ -109,7 +110,12 @@ def generate(
         else:
             lengths = lengths_
         num_indices = cast(int, torch.sum(lengths).item())
-        indices = torch.randint(0, ind_range, (num_indices,))
+        indices = torch.randint(
+            0,
+            ind_range,
+            (num_indices,),
+            dtype=torch.long if long_indices else torch.int32,
+        )
         global_idlist_lengths.append(lengths)
         global_idlist_indices.append(indices)
     global_idlist_kjt = KeyedJaggedTensor(
@@ -133,7 +139,12 @@ def generate(
         else:
             lengths = lengths_
         num_indices = cast(int, torch.sum(lengths).item())
-        indices = torch.randint(0, ind_range, (num_indices,))
+        indices = torch.randint(
+            0,
+            ind_range,
+            (num_indices,),
+            dtype=torch.long if long_indices else torch.int32,
+        )
         weights = torch.rand((num_indices,))
         global_idscore_lengths.append(lengths)
         global_idscore_indices.append(indices)
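
The dtype selection above is the whole mechanism; a standalone check of the pattern (range and shape are arbitrary):

    import torch

    long_indices = False
    indices = torch.randint(
        0, 100, (7,), dtype=torch.long if long_indices else torch.int32
    )
    assert indices.dtype == torch.int32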
@@ -111,6 +111,7 @@ def _test_sharded_forward(
             dense_device=cuda_device,
             sparse_device=cuda_device,
             generate=generate,
+            long_indices=False,
         )
         global_model = quantize_callable(global_model, **quantize_callable_kwargs)
         local_input = _inputs[0][1][default_rank].to(cuda_device)
5 changes: 5 additions & 0 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -108,6 +108,7 @@ def __call__(
             Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
         ] = None,
         variable_batch_size: bool = False,
+        long_indices: bool = True,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         ...

@@ -121,6 +122,7 @@ def generate_inputs(
     batch_size: int = 4,
     num_float_features: int = 16,
     variable_batch_size: bool = False,
+    long_indices: bool = True,
 ) -> Tuple[ModelInput, List[ModelInput]]:
     return generate(
         batch_size=batch_size,
@@ -130,6 +132,7 @@ def generate_inputs(
         dedup_tables=dedup_tables,
         weighted_tables=weighted_tables or [],
         variable_batch_size=variable_batch_size,
+        long_indices=long_indices,
     )


@@ -148,6 +151,7 @@ def gen_model_and_input(
     variable_batch_size: bool = False,
     batch_size: int = 4,
     feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
+    long_indices: bool = True,
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -188,6 +192,7 @@ def gen_model_and_input(
             num_float_features=num_float_features,
             variable_batch_size=variable_batch_size,
             batch_size=batch_size,
+            long_indices=long_indices,
         )
     ]
     return (model, inputs)
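
Each layer simply forwards the flag unchanged; a sketch of the chain (arguments abridged, names from the diffs above):

    # gen_model_and_input(..., long_indices=False)
    #     -> generate_inputs(..., long_indices=False)
    #         -> ModelInput.generate(..., long_indices=False)  # emits int32 indices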
8 changes: 4 additions & 4 deletions torchrec/distributed/tests/test_fx_jit.py
@@ -188,7 +188,7 @@ def shard_modules_QEBC(
             env=ShardingEnv.from_local(world_size=world_size, rank=0),
         )

-        inputs = prep_inputs(model_info, world_size)
+        inputs = prep_inputs(model_info, world_size, long_indices=False)

         return (
             model_info.quant_model,
@@ -214,7 +214,7 @@ def shard_modules_QEC(
             env=ShardingEnv.from_local(world_size=world_size, rank=0),
         )

-        inputs = prep_inputs(model_info, world_size)
+        inputs = prep_inputs(model_info, world_size, long_indices=False)

         return (
             model_info.quant_model,
@@ -259,7 +259,7 @@ def DMP_QEBC(

         dmp = dmp.copy(model_info.sparse_device)

-        inputs = prep_inputs(model_info, world_size)
+        inputs = prep_inputs(model_info, world_size, long_indices=False)

         m = dmp.module if unwrap_dmp else dmp
         return (
@@ -305,7 +305,7 @@ def DMP_QEC(
         )
         model_info.model = m.module

-        inputs = prep_inputs(model_info, world_size)
+        inputs = prep_inputs(model_info, world_size, long_indices=False)

         return (
             model_info.quant_model,
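
All four FX/JIT paths flip the same switch. The apparent motivation: the quantized forward no longer casts indices and offsets to int32 (see torchrec/quant/embedding_modules.py below), so these inference tests supply int32 inputs up front. A hedged sketch of the invariant being relied on, inside any of the helpers above:

    # Assumed invariant after this change: values reaching the quantized TBE
    # are already int32, so forward makes no .int() copy.
    inputs = prep_inputs(model_info, world_size, long_indices=False)
    assert inputs[0].idlist_features.values().dtype == torch.int32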
14 changes: 7 additions & 7 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -99,7 +99,7 @@ def test_rw(self, weight_dtype: torch.dtype) -> None:
         )
         inputs = [
             model_input_to_forward_args(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]

         sharded_model.load_state_dict(non_sharded_model.state_dict())
@@ -205,7 +205,7 @@ def test_cw(self, test_permute: bool, weight_dtype: torch.dtype) -> None:
         )
         inputs = [
             model_input_to_forward_args(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]
         sharded_model.load_state_dict(non_sharded_model.state_dict())
         # torchrec.distributed.test_utils.test_sharding.copy_state_dict(sharded_model.state_dict(), non_sharded_model.state_dict()) does not work for CW due to non-trivial qscaleshift copy which is handled in shardedQEBC load_state_dict
@@ -292,7 +292,7 @@ def test_cw_with_smaller_emb_dim(
         )
         inputs = [
             model_input_to_forward_args(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]
         sharded_model.load_state_dict(non_sharded_model.state_dict())
         # torchrec.distributed.test_utils.test_sharding.copy_state_dict(sharded_model.state_dict(), non_sharded_model.state_dict()) does not work for CW due to non-trivial qscaleshift copy which is handled in shardedQEBC load_state_dict
@@ -387,7 +387,7 @@ def test_cw_multiple_tables_with_permute(self, weight_dtype: torch.dtype) -> None:
         )
         inputs = [
             model_input_to_forward_args(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]
         sharded_model.load_state_dict(non_sharded_model.state_dict())
         # torchrec.distributed.test_utils.test_sharding.copy_state_dict(sharded_model.state_dict(), non_sharded_model.state_dict()) does not work for CW due to non-trivial qscaleshift copy which is handled in shardedQEBC load_state_dict
@@ -489,7 +489,7 @@ def test_cw_irregular_shard_placement(self, weight_dtype: torch.dtype) -> None:
         )
         inputs = [
             model_input_to_forward_args(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]
         sharded_model.load_state_dict(non_sharded_model.state_dict())
         # torchrec.distributed.test_utils.test_sharding.copy_state_dict(sharded_model.state_dict(), non_sharded_model.state_dict()) does not work for CW due to non-trivial qscaleshift copy which is handled in shardedQEBC load_state_dict
@@ -587,7 +587,7 @@ def test_cw_sequence(self, weight_dtype: torch.dtype) -> None:
         )
         inputs = [
             model_input_to_forward_args_kjt(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]

         sharded_model.load_state_dict(non_sharded_model.state_dict())
@@ -687,7 +687,7 @@ def test_rw_sequence(self, weight_dtype: torch.dtype) -> None:

         inputs = [
             model_input_to_forward_args_kjt(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]

         sharded_model.load_state_dict(non_sharded_model.state_dict())
12 changes: 6 additions & 6 deletions torchrec/distributed/tests/test_quant_pruning.py
@@ -131,7 +131,7 @@ def test_qebc_pruned_tw(self) -> None:

         inputs = [
             model_input_to_forward_args(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]
         sharded_model.load_state_dict(quant_model.state_dict())
         quant_output = quant_model(*inputs[0])
@@ -264,8 +264,8 @@ def test_qebc_pruned_tw_one_ebc(self) -> None:

         kjt = KeyedJaggedTensor.from_lengths_sync(
             keys=["feature_0"],
-            values=torch.LongTensor([0, 1, 2]).cuda(),
-            lengths=torch.LongTensor([1, 1, 1]).cuda(),
+            values=torch.tensor([0, 1, 2], dtype=torch.int32).cuda(),
+            lengths=torch.tensor([1, 1, 1], dtype=torch.int32).cuda(),
             weights=None,
         )

@@ -352,7 +352,7 @@ def test_qebc_pruned_cw(self) -> None:

         inputs = [
             model_input_to_forward_args(inp.to(local_device))
-            for inp in prep_inputs(mi, world_size, batch_size)
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
         ]
         sharded_model.load_state_dict(quant_model.state_dict())
         quant_output = quant_model(*inputs[0])
@@ -485,8 +485,8 @@ def test_qebc_pruned_cw_one_ebc(self) -> None:

         kjt = KeyedJaggedTensor.from_lengths_sync(
             keys=["feature_0"],
-            values=torch.LongTensor([0, 1, 2, 197, 198, 199]).cuda(),
-            lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]).cuda(),
+            values=torch.tensor([0, 1, 2, 197, 198, 199], dtype=torch.int32).cuda(),
+            lengths=torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.int32).cuda(),
             weights=None,
         )
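
The *_one_ebc tests build their KeyedJaggedTensor by hand, so the literal tensors move to int32 directly. The same construction on CPU (illustrative values):

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["feature_0"],
        values=torch.tensor([0, 1, 2], dtype=torch.int32),
        lengths=torch.tensor([1, 1, 1], dtype=torch.int32),
    )
    # three length-1 bags for feature_0; offsets derived from lengths are [0, 1, 2, 3]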
@@ -196,6 +196,7 @@ def test_quant_pred_shard(
             num_float_features=10,
             tables=self.tables,
             weighted_tables=[],
+            long_indices=False,
         )
         local_batch = local_batch.to(device)
         sharded_quant_model(local_batch.idlist_features)
26 changes: 18 additions & 8 deletions torchrec/quant/embedding_modules.py
@@ -69,6 +69,16 @@
 DEFAULT_ROW_ALIGNMENT = 16


+@torch.fx.wrap
+def set_fake_stbe_offsets(values: torch.Tensor) -> torch.Tensor:
+    return torch.arange(
+        0,
+        values.numel() + 1,
+        device=values.device,
+        dtype=values.dtype,
+    )
+
+
 def for_each_module_of_type_do(
     module: nn.Module,
     module_types: List[Type[torch.nn.Module]],
@@ -464,14 +474,14 @@ def forward(
             embeddings.append(
                 # Syntax for FX to generate call_module instead of call_function to keep TBE copied unchanged to fx.GraphModule, can be done only for registered module
                 emb_op(
-                    indices=indices.int(),
-                    offsets=offsets.int(),
+                    indices=indices,
+                    offsets=offsets,
                     per_sample_weights=weights if self._is_weighted else None,
                 )
                 if self.register_tbes
                 else emb_op.forward(
-                    indices=indices.int(),
-                    offsets=offsets.int(),
+                    indices=indices,
+                    offsets=offsets,
                     per_sample_weights=weights if self._is_weighted else None,
                 )
             )
@@ -659,7 +669,7 @@ def __init__( # noqa C901
                     else EmbeddingLocation.DEVICE,
                 )
             ],
-            pooling_mode=PoolingMode.NONE,
+            pooling_mode=PoolingMode.SUM,
             weight_lists=weight_lists,
             device=device,
             output_dtype=data_type_to_sparse_type(dtype_to_data_type(output_dtype)),
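
Switching pooling_mode from PoolingMode.NONE to PoolingMode.SUM reads oddly for an embedding collection, which returns per-index (sequence) embeddings rather than pooled bags. The fake offsets above are what make it work: set_fake_stbe_offsets(values) returns arange(0, N + 1), which assigns every index its own length-1 bag, and SUM-pooling a bag holding a single row is the identity. The pooled TBE therefore emits exactly the sequence embeddings, which appears to be what the PR title's "fake stbe" refers to; a runnable sketch follows this file's diff.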
@@ -733,12 +743,12 @@ def forward(
         ):
             f = jt_dict[feature_name]
             values = f.values()
-            offsets = f.offsets()
+            offsets = set_fake_stbe_offsets(values)
             # Syntax for FX to generate call_module instead of call_function to keep TBE copied unchanged to fx.GraphModule, can be done only for registered module
             lookup = (
-                emb_module(indices=values.int(), offsets=offsets.int())
+                emb_module(indices=values, offsets=offsets)
                 if self.register_tbes
-                else emb_module.forward(indices=values.int(), offsets=offsets.int())
+                else emb_module.forward(indices=values, offsets=offsets)
             )
             feature_embeddings[embedding_name] = JaggedTensor(
                 values=lookup,
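
A self-contained sketch of the fake-offsets trick, with a plain nn.EmbeddingBag standing in for the quantized TBE (table size and index values are illustrative):

    import torch

    def set_fake_stbe_offsets(values: torch.Tensor) -> torch.Tensor:
        # offsets [0, 1, ..., N] put every index in its own length-1 bag
        return torch.arange(
            0, values.numel() + 1, device=values.device, dtype=values.dtype
        )

    values = torch.tensor([5, 9, 2], dtype=torch.int32)
    offsets = set_fake_stbe_offsets(values)  # tensor([0, 1, 2, 3], dtype=torch.int32)

    table = torch.randn(16, 4)  # a 16-row, 4-dim embedding table
    bag = torch.nn.EmbeddingBag.from_pretrained(
        table, mode="sum", include_last_offset=True  # TBE-style offsets of length B+1
    )
    pooled = bag(values.long(), offsets.long())  # SUM over length-1 bags
    # SUM over a single row is the identity: pooled output == per-index lookup
    assert torch.allclose(pooled, table[values.long()])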