32 changes: 31 additions & 1 deletion torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -88,18 +88,48 @@ def input_fn():
"dp_shard": Shard(0),
"tp": Replicate(),
}
# only used if loss parallel is enabled
possible_output_shardings = {
# maps relative to mesh dim names used in torchtitan
"dp_shard": Shard(0),
"tp": Shard(2),
}
assert all(
name in possible_input_shardings for name in world_mesh.mesh_dim_names
), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
x_sharding = tuple(
possible_input_shardings[name] for name in world_mesh.mesh_dim_names
)
out_sharding = x_sharding
if parallel_dims.loss_parallel_enabled:
out_sharding = tuple(
possible_output_shardings[name]
for name in world_mesh.mesh_dim_names
if name != "dp_replicate"
)
autop.add_input_constraints([x_sharding])
autop.add_output_constraints([x_sharding])
autop.add_output_constraints([out_sharding])
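    # For illustration (an assumption, not part of the diff): on a
    # 2-d ("dp_shard", "tp") mesh, the tuples above work out to
    #   x_sharding   == (Shard(0), Replicate())
    #   out_sharding == (Shard(0), Shard(2))  # with loss parallel enabled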
    t0 = time.time()
    sharding_placement = autop.optimize_placement()
    t1 = time.time()
    logger.info(f"AutoParallel took {t1 - t0} seconds")
    parallel_mod = autop.apply_placement(sharding_placement)

    if parallel_dims.loss_parallel_enabled:
        # PyTorch's current implementation of loss parallel assumes
        # that the DTensor has a 1-d device mesh. This is not true
        # in our case, but we can work around it by casting the
        # output to a DTensor on a 1-d device mesh. We should just
        # use AutoParallel to do this for us, but that would require
        # putting the loss inside the model as well.
Contributor:
I agree that overall we should just put the loss in the model, but I like the approach here for now because it's useful to be as structurally similar to torchtitan as possible for drop-in purposes
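A minimal sketch of that alternative (ModelWithLoss is a hypothetical wrapper, not part of this PR): computing the loss inside forward() would let a placement solver like AutoParallel see and shard the loss computation itself, making the output hook below unnecessary.

import torch
import torch.nn.functional as F

class ModelWithLoss(torch.nn.Module):
    # Hypothetical wrapper: runs the model and computes the loss in
    # forward(), so the loss computation is part of the traced graph.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, tokens, labels):
        logits = self.model(tokens)
        # (batch, seq, vocab) -> (batch * seq, vocab) for cross entropy
        return F.cross_entropy(logits.flatten(0, 1), labels.flatten(0, 1))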

        def _return_as_dtensor_for_loss_parallel(module, args, output):
            return torch.distributed.tensor.DTensor.from_local(
                output, world_mesh["tp"], (Shard(2),)
            )

        # not keeping a reference to the hook; we don't plan on
        # removing it at any point
        parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)

    return parallel_mod
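For context, a sketch of how the hooked module's output might be consumed in a torchtitan-style training step (tokens and labels are stand-in names, assumed to be the usual (batch, seq) integer tensors): the forward hook returns a DTensor sharded as Shard(2) on the 1-d "tp" mesh, which is the layout loss_parallel() expects.

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

# Hedged sketch of a training step consuming the hooked output.
with loss_parallel():
    pred = parallel_mod(tokens)  # DTensor with the vocab dim sharded on tp
    loss = F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))
    loss.backward()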