54 changes: 45 additions & 9 deletions autoparallel/api.py
@@ -16,6 +16,7 @@
from torch._logging import trace_structured
from torch._subclasses import FakeTensorMode
from torch.distributed.tensor import DeviceMesh
from torch.nn.utils import stateless

from .apply_sharding import apply_sharding_to_model
from .export_module import aot_export_module, apply_node_renaming
@@ -380,21 +381,56 @@ def apply_placement(self, sharding_placement=None):
)
self.parallel_gm = parallel_gm

param_names = [k.replace(".", "/") for k, _ in self.model.named_parameters()]
buffer_names = [k.replace(".", "/") for k, _ in self.model.named_buffers()]
param_names = [k for k, _ in self.model.named_parameters()]
buffer_names = [k for k, _ in self.model.named_buffers()]
param_names_no_fqns = [k.replace(".", "/") for k in param_names]
buffer_names_no_fqns = [k.replace(".", "/") for k in buffer_names]
assert len(param_names) == len(sharded_weights)
assert len(buffer_names) == len(sharded_buffers)
sharded_weights = {k: v for k, v in zip(param_names, sharded_weights)}
sharded_buffers = {k: v for k, v in zip(buffer_names, sharded_buffers)}

self.sharded_weights = sharded_weights
self.sharded_buffers = sharded_buffers
sharded_weights_no_fqns = {
k: v for k, v in zip(param_names_no_fqns, sharded_weights)
}
sharded_buffers_no_fqns = {
k: v for k, v in zip(buffer_names_no_fqns, sharded_buffers)
}

# TODO: preserve state dict properly in the generated nn.module
self.sharded_weights = sharded_weights_no_fqns
self.sharded_buffers = sharded_buffers_no_fqns
self.parallel_model_fn, self.fwd_gm, self.bwd_gm = prepare_module(
parallel_gm, self.spec, self.metadata.num_outputs
)

sharded_weights = try_convert_fake_to_real(sharded_weights)
sharded_buffers = try_convert_fake_to_real(sharded_buffers)
self.parallel_model = self.parallel_model_fn(sharded_weights, sharded_buffers)
self.parallel_model = self.parallel_model_fn(
sharded_weights_no_fqns, sharded_buffers_no_fqns
)

# Right now we require a convention that the user model provides an init_weights method,
# although we could snoop for other methods too.
if hasattr(self.model, "init_weights"):

def init_weights(*args, **kwargs):
# TODO: once we have proper FQN support we should remove this
# Replace 'params.tok_embeddings/weight' -> 'tok_embeddings.weight'
# Replace 'buffers_.freqs_cis' -> 'freqs_cis'
sharded_params_buffers = {
k.replace("params.", "")
.replace("buffers_.", "")
.replace("/", "."): v
for k, v in self.parallel_model.state_dict().items()
}
with stateless._reparametrize_module(
self.model, sharded_params_buffers
):
self.model.init_weights(*args, **kwargs)

else:
init_weights = None

# assign an init_weights method onto the output mod.
# all it does is sneakily run the original user mod's init_weights method,
# but with our new DTensor sharded params attached to the user module.
self.parallel_model.init_weights = init_weights
Comment on lines +431 to +434
Contributor:

I'm ok with this for now, but then we would need to clearly specify the contract of what methods are propagated from the original model vs not.

Contributor Author:

Agreed. If we want to be extreme, we could make this a strict requirement and error out if the user's mod doesn't have an init_weights method. Or maybe once we get further along adoption-wise we should just write some docs that clearly spell out the restrictions / user requirements, this being one of them?


return self.parallel_model
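For reference, a minimal sketch of the reparametrize pattern used above, with a hypothetical `Toy` module standing in for the user model: `stateless._reparametrize_module` temporarily swaps the module's registered parameters and buffers for the supplied tensors, so the user's `init_weights` writes land in the tensors we hand in (here, the sharded DTensors) rather than in the module's original params.

```python
import torch
import torch.nn as nn
from torch.nn.utils import stateless


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def init_weights(self):
        nn.init.normal_(self.lin.weight)
        nn.init.zeros_(self.lin.bias)


mod = Toy()
# Stand-in for the sharded state dict; keys must match the module's FQNs.
replacement = {k: torch.empty_like(v) for k, v in mod.state_dict().items()}
with stateless._reparametrize_module(mod, replacement):
    # In-place init ops now hit the tensors in `replacement`,
    # not the module's own parameters.
    mod.init_weights()
```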
33 changes: 28 additions & 5 deletions autoparallel/apply_sharding.py
@@ -3,9 +3,11 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import operator

import torch
from torch._subclasses.fake_tensor import FakeTensor, unset_fake_temporarily
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._redistribute import redistribute_local_tensor
@@ -155,7 +157,7 @@ def call_function(self, target, args, kwargs):
return out


def shard_nodes_given_placements(gm, sharding_placement, node_prefix):
def shard_nodes_given_placements(gm, sharding_placement, node_prefix, *, meta=False):
# NOTE: this relies on my customized export_module
nodes = [
x for x in gm.graph.find_nodes(op="placeholder") if node_prefix in x.target
@@ -167,9 +169,21 @@ def shard_nodes_given_placements(gm, sharding_placement, node_prefix):
# all tensors start as replicated
curr_placement = (Replicate(),) * mesh.ndim
tensor = node.meta["val"]
sharded_tensor = DTensor.from_local(tensor, mesh, curr_placement).redistribute(
mesh, tgt_spec.placements
)

if meta:
assert isinstance(
tensor, FakeTensor
), f"only FakeTensor params supported for now, got {type(tensor)}"
ctx = unset_fake_temporarily
with ctx():
tensor = torch.randn(tensor.shape, dtype=tensor.dtype, device="meta")
else:
ctx = contextlib.nullcontext

with ctx():
sharded_tensor = DTensor.from_local(
tensor, mesh, curr_placement
).redistribute(mesh, tgt_spec.placements)
sharded_tensors.append(sharded_tensor)
return sharded_tensors
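For context, a standalone sketch (not part of this PR) of the replicate-then-redistribute step each placeholder node goes through; it assumes a process group has already been initialized, e.g. via torchrun, with one GPU per rank.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
local = torch.ones(8, 8, device="cuda")  # identical on every rank, i.e. already replicated

# Start as Replicate() on the full mesh, then move to the target placement.
replicated = DTensor.from_local(local, mesh, (Replicate(),) * mesh.ndim)
sharded = replicated.redistribute(mesh, (Shard(0),))
```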

Expand All @@ -189,4 +203,13 @@ def apply_sharding_to_model(gm, sharding_placement):
args = [x.to_local() for x in args]
parallel_gm = make_fx(interp.run)(*args)

return parallel_gm, sharded_params, sharded_buffers
# We put DTensor(meta_tensor) tensors in the state dict, as the user expects to be
# able to call parallel_mod.to_empty(device='cuda'). This does not work with FakeTensors.
sharded_meta_params = shard_nodes_given_placements(
gm, sharding_placement, "param", meta=True
)
sharded_meta_buffers = shard_nodes_given_placements(
gm, sharding_placement, "buffer", meta=True
)

return parallel_gm, sharded_meta_params, sharded_meta_buffers
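A hedged usage sketch of why meta (rather than fake) shards matter here; `autop` and `sharding_placement` are assumed to be set up as in the examples further down.

```python
parallel_mod = autop.apply_placement(sharding_placement)

# The sharded params/buffers wrap meta tensors, so the standard nn.Module
# materialization flow should apply:
parallel_mod.to_empty(device="cuda")  # allocate real storage for the meta shards
parallel_mod.init_weights()           # fill it via the propagated init_weights hook
```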
1 change: 0 additions & 1 deletion autoparallel/utils.py
@@ -67,7 +67,6 @@ def fill_missing_redistribute_cost(op, specs, out_strat):
for strat in out_strat.strategies:
# TODO: check me
if strat.redistribute_cost is None:

# TODO: the torch.ops.aten.slice.Tensor is wrong here and in the input_spec!!!!!
handled_ops = {
torch.ops.aten.ones_like.default,
9 changes: 9 additions & 0 deletions examples/example_autoparallel.py
@@ -23,6 +23,12 @@ def __init__(self, nheads, dim1, dim2):
self.w1 = nn.Linear(dim1, dim2, bias=bias)
self.w2 = nn.Linear(dim2, dim1, bias=bias)

def init_weights(self):
for lin in [self.wq, self.wk, self.wv, self.wo, self.w1, self.w2]:
torch.nn.init.normal_(lin.weight)
if lin.bias is not None:
torch.nn.init.normal_(lin.bias)

def forward(self, x):
q = self.wq(x)
k = self.wk(x)
@@ -94,6 +100,9 @@ def input_fn():
sharding_placement = autop.optimize_placement()
parallel_mod = autop.apply_placement(sharding_placement)

# run weight init on our sharded DTensor params
parallel_mod.init_weights()

# now let's run it
x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),)
out = parallel_mod(*x)
4 changes: 3 additions & 1 deletion examples/example_llama3.py
@@ -475,7 +475,6 @@ def __init__(self, model_args: TransformerModelArgs):
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()

def init_weights(
self,
@@ -628,6 +627,9 @@ def input_fn():
print(f"Took {time.time() - t:.2f} s")
parallel_mod = autop.apply_placement(sharding_placement)

# run weight init on our sharded DTensor params
parallel_mod.init_weights()

# now let's run it
x = (
torch.randint(