Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions torchrec/distributed/train_pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,14 +744,17 @@ def _get_node_args_helper(
pipelined_preprocs: Set[PipelinedPreproc],
context: TrainPipelineContext,
pipeline_preproc: bool,
# Add `None` constants to arg info only for preproc modules
# Defaults to False for backward compatibility
for_preproc_module: bool = False,
) -> Tuple[List[ArgInfo], int]:
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
"""
arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))]
for arg, arg_info in zip(arguments, arg_info_list):
if arg is None:
if not for_preproc_module and arg is None:
num_found += 1
continue
while True:
Expand Down Expand Up @@ -911,7 +914,12 @@ def _get_node_args_helper(
# is either made of preproc module or non-modifying train batch input
# transformations
preproc_args, num_found_safe_preproc_args = _get_node_args(
model, child_node, pipelined_preprocs, context, pipeline_preproc
model,
child_node,
pipelined_preprocs,
context,
pipeline_preproc,
True,
)
if num_found_safe_preproc_args == total_num_args:
logger.info(
Expand Down Expand Up @@ -956,6 +964,7 @@ def _get_node_args(
pipelined_preprocs: Set[PipelinedPreproc],
context: TrainPipelineContext,
pipeline_preproc: bool,
for_preproc_module: bool = False,
) -> Tuple[List[ArgInfo], int]:
num_found = 0

Expand All @@ -966,6 +975,7 @@ def _get_node_args(
pipelined_preprocs,
context,
pipeline_preproc,
for_preproc_module,
)
kwargs_arg_info_list, num_found = _get_node_args_helper(
model,
Expand All @@ -974,6 +984,7 @@ def _get_node_args(
pipelined_preprocs,
context,
pipeline_preproc,
for_preproc_module,
)

# Replace with proper names for kwargs
Expand Down