Draft
Changes from all commits (30 commits)
52ea0c1
[WIP] Add basic DeepSeekV3
fmassa Jul 4, 2025
0d3ae2d
Lint
fmassa Jul 4, 2025
98d9dfd
Workarounds to make graph capture pass
fmassa Jul 4, 2025
61a63c4
Add dummy propagation rules just to see what we need to implement
fmassa Jul 4, 2025
67eb264
Cleanup
fmassa Jul 4, 2025
86d53ff
prims.fma comes from softmax_backward
fmassa Jul 4, 2025
7864f4d
Make _geenrate_dummy_strategy more generic
fmassa Jul 5, 2025
60ccf1a
Add proper redistribute_cost to dummy strategies
fmassa Jul 5, 2025
dbbc205
Hack around missing dtypes in compute estimation and handle grouped_m…
fmassa Jul 5, 2025
d92f8c6
Add representative batch size
fmassa Jul 5, 2025
e25ff7b
Fix grouped_mm stride issue
wconstab Jul 18, 2025
3b7e7fa
get DS3 running forward, OOM at backward
wconstab Jul 18, 2025
3833a06
WIP factory_strategy
wconstab Jul 18, 2025
3740b45
Start rebasing on top of main
fmassa Jul 25, 2025
39fedfd
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 25, 2025
6bec5f5
Fixes so that it runs
fmassa Jul 25, 2025
ce1c0a5
[WIP] Plumb fake_mode to avoid materializing memory
fmassa Jul 26, 2025
5d79bec
Use more representative values for DS3 example
fmassa Jul 26, 2025
daea5a2
Add approximate flop formula to grouped_mm
fmassa Jul 26, 2025
6d350e0
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 27, 2025
418ad55
Glimpses of having DeepSeekV3 returning a reasonable solution
fmassa Jul 27, 2025
fce321f
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 30, 2025
6d5747a
Use with_implicit_strategies instead of my generate_dummy_strategy
fmassa Jul 30, 2025
e0ae8a2
[WIP] Convert view->mm->view into matmul
fmassa Jul 30, 2025
1b83581
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 31, 2025
cf1229d
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Aug 4, 2025
4fe5a40
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Aug 9, 2025
67542ad
Remove sharding rules that have been since moved to PyTorch
fmassa Aug 9, 2025
779e808
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Sep 4, 2025
124034e
Fixes after rebase
fmassa Sep 4, 2025
6 changes: 4 additions & 2 deletions autoparallel/api.py
@@ -39,7 +39,7 @@
from .optimize_sharding import ShardingOptimizer
from .utils import _get_device_from_mesh

_APPLY_VIEW_MM_VIEW_PATTERN = False
_APPLY_VIEW_MM_VIEW_PATTERN = True


def try_convert_fake_to_real(tensors):
@@ -60,6 +60,8 @@ def _get_decomp_table():
decomp_table.pop(torch.ops.aten.native_layer_norm.default)
decomp_table.pop(torch.ops.aten.embedding_dense_backward.default)
decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)
decomp_table.pop(torch.ops.aten._softmax_backward_data.default)
decomp_table.pop(torch.ops.aten._softmax.default)

# decompose addmm to allow for TP on mm
decomp_table.pop(torch.ops.aten.addmm.default)
@@ -277,7 +279,7 @@ def build_model_graph(self):
_replace_view_mm_view_with_einsum(gm)
# now add aliases nodes to the graph to
# give more room for optimizations
_add_alias(gm, version="v1")
_add_alias(gm, version="v2")
trace_structured(
"artifact",
metadata_fn=lambda: {
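Note on the decomposition-table change above: popping an op keeps it as a single node in the captured graph instead of letting it decompose into prims (the prims.fma mentioned in the commit log comes from the decomposed softmax backward). A minimal sketch of the pattern, assuming the standard core ATen decomposition table:

import torch
from torch._decomp import core_aten_decompositions

decomp_table = core_aten_decompositions()
# Keep softmax and its backward as single ops so their sharding
# strategies can be matched directly rather than on decomposed prims.
decomp_table.pop(torch.ops.aten._softmax.default, None)
decomp_table.pop(torch.ops.aten._softmax_backward_data.default, None)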
42 changes: 37 additions & 5 deletions autoparallel/compute_estimation.py
@@ -13,6 +13,37 @@
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula


@register_flop_formula(torch.ops.aten._grouped_mm)
def gmm_flop(
a_shape, b_shape, offs_shape=None, bias_shape=None, out_shape=None, **kwargs
) -> int:
"""Count flops for the gmm operation."""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two tensor
if len(a_shape) == 2:
assert offs_shape is not None
(b,) = offs_shape
m0, k = a_shape
        # assumption: groups are roughly balanced, so this falls back to bmm
m = m0 // b
else:
assert offs_shape is None
b, m, k = a_shape
if len(b_shape) == 2:
assert offs_shape is not None
(b2,) = offs_shape
k2, n0 = b_shape
        # assumption: groups are roughly balanced, so this falls back to bmm
n = n0 // b2
else:
b2, k2, n = b_shape
assert b == b2
assert k == k2
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
flop = b * m * n * 2 * k
return flop


@register_flop_formula(torch.ops.aten.einsum, get_raw=True)
def einsum_flop(equation, tensors, out=None, **kwargs) -> int:
# from torch.distributed.tensor._ops._einsum_strategy import EinsumDims
@@ -180,12 +211,13 @@ def _get_device_tflops(dtype):
# from torch._inductor.utils import get_device_tflops

device_limit = _get_device_limit()
if dtype not in device_limit.gemm_tflops:
raise ValueError(
f"Dtype {dtype} not supported on {device_limit.name}. Supported dtypes: {list(device_limit.gemm_tflops.keys())}"
)
# TODO: add proper support for int64 etc
# if dtype not in device_limit.gemm_tflops:
# raise ValueError(
# f"Dtype {dtype} not supported on {device_limit.name}. Supported dtypes: {list(device_limit.gemm_tflops.keys())}"
# )

return device_limit.gemm_tflops[dtype]
return device_limit.gemm_tflops.get(dtype, 1)


def _get_device_gmem_bandwidth():
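To make the balanced-groups assumption in gmm_flop concrete, here is a hypothetical call with illustrative shapes (not taken from the PR):

# 2D lhs grouped by an offsets tensor of length 8, 3D rhs with one
# (k, n) matrix per group.
a_shape = (1024, 512)      # (m0, k); m0 assumed split evenly across 8 groups
b_shape = (8, 512, 256)    # (b, k, n)
offs_shape = (8,)

# m is approximated as m0 // b = 128, so the estimate reduces to a bmm:
# flops = b * m * n * 2 * k = 8 * 128 * 256 * 2 * 512 = 268,435,456
flops = gmm_flop(a_shape, b_shape, offs_shape=offs_shape)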
27 changes: 17 additions & 10 deletions autoparallel/propagation_rules.py
@@ -54,11 +54,15 @@
_op_rules = {}


def register_rule(op):
def register_rule(ops):
global _op_rules

def wrapper(impl):
_op_rules[op] = impl
if isinstance(ops, list):
for op in ops:
_op_rules[op] = impl
else:
_op_rules[ops] = impl
return impl

return wrapper
@@ -626,23 +630,26 @@ def _unsafe_index_rule(mesh, op_schema):
raise NotImplementedError()


@register_opschema_rule(torch.ops.aten.index.Tensor)
# Disable this rule as its implementation is inferior to the baseline
# @register_opschema_rule(torch.ops.aten.index.Tensor)
def index_rule(mesh, op_schema):
raise NotImplementedError("Needs hardening, only tested on a few cases")
print(f"Ops that need to be implemented {torch.ops.aten.index.Tensor}")
# raise NotImplementedError("Needs hardening, only tested on a few cases")
strat = op_schema.args_schema
specs = strat # TODO: clean this up
res = []
idxs_placements = [(Replicate(), Replicate()), (Shard(0), Replicate())]
if strat[1].childs[0] is None:
idxs_placements = idxs_placements[:1]
else:
idxs_placements = idxs_placements[1:]
idxs_placements = [(Replicate(),) * mesh.ndim]
# if strat[1].childs[0] is None:
# idxs_placements = idxs_placements[:1]
# else:
# idxs_placements = idxs_placements[1:]
    # TODO: this is a nasty hack and won't work in most cases
for i, ss in enumerate(strat[0].strategies):
for i, ss in enumerate(strat[0].strategies[:1]):
for plt in idxs_placements:
ispec = ss.input_specs[0]
ospec = DTensorSpec(mesh=mesh, placements=ispec.placements)
assert ss.output_spec == ispec
# assert ss.output_spec == ispec, f"{ss.output_spec}, {ispec}"
idxs_strats = [
DTensorSpec(mesh, placements=plt)
for x in strat[1].childs
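With register_rule now accepting either a single op or a list of ops, one implementation can be shared across several operators. A hypothetical usage sketch (placeholder op names and signature, for illustration only):

@register_rule([torch.ops.aten.sin.default, torch.ops.aten.cos.default])
def _shared_pointwise_rule(mesh, op_schema):
    # The same propagation logic is registered once per op in the list.
    ...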
23 changes: 16 additions & 7 deletions autoparallel/utils.py
@@ -27,21 +27,30 @@
)


def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
def _get_meta_tensors_for_op(op, user_args, user_kwargs):
out_t = op(*user_args, **user_kwargs)

if isinstance(out_t, torch.Tensor):
new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype)
out_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype)
else:
new_tensor_meta = tree_map_only(
out_tensor_meta = tree_map_only(
torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), out_t
)

tensor_metas = tree_flatten(user_args)[0]
tensor_metas = tree_map_only(
torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), tensor_metas
input_tensor_metas = tree_flatten(user_args)[0]
input_tensor_metas = tree_map_only(
torch.Tensor,
lambda x: TensorMeta(x.shape, x.stride(), x.dtype),
input_tensor_metas,
)
input_tensor_metas = tuple(
x for x in input_tensor_metas if isinstance(x, TensorMeta)
)
tensor_metas = tuple(x for x in tensor_metas if isinstance(x, TensorMeta))
return out_tensor_meta, input_tensor_metas


def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
new_tensor_meta, tensor_metas = _get_meta_tensors_for_op(op, user_args, user_kwargs)

for strat in out_strat.strategies:
if isinstance(new_tensor_meta, TensorMeta):
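The refactor above splits the TensorMeta extraction out of propagate_tensor_meta so it can be reused. A rough sketch of what the new helper returns, assuming meta-device inputs:

a = torch.empty(8, 16, device="meta")
b = torch.empty(16, 32, device="meta")
out_meta, input_metas = _get_meta_tensors_for_op(torch.ops.aten.mm.default, (a, b), {})
# out_meta    -> TensorMeta(shape=torch.Size([8, 32]), stride=(32, 1), dtype=torch.float32)
# input_metas -> a tuple with one TensorMeta per tensor input (a and b)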