Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deepspeed/compile/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,10 @@ def compiler_fn(gm, sample_inputs):
raise ValueError(f"Unsupported backend {backend}")

return backend_fn


def make_autosp_backend(backend, compile_kwargs=None, free_activation=False, debug_log=False, sp_size=2, dp_size=1):
    """Build a ``torch.compile`` backend that applies the AutoSP graph pass.

    The returned callable rewrites the captured FX graph for sequence/data
    parallelism and then lowers it with Inductor.

    Args:
        backend: NOTE(review): currently unused in this factory — confirm
            whether it should select the downstream compiler.
        compile_kwargs: optional dict of compile options. NOTE(review):
            currently unused; confirm whether it should be forwarded to
            ``torch._inductor.compile``.
        free_activation: NOTE(review): currently unused here.
        debug_log: forwarded to ``apply_autosp`` to enable verbose logging.
        sp_size: sequence-parallel group size passed to the pass.
        dp_size: data-parallel group size passed to the pass.

    Returns:
        A backend function ``(GraphModule, inputs) -> compiled callable``.
    """
    # Fix: avoid a mutable default argument ({} is shared across calls).
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs

    def backend_fn(gm: GraphModule, real_inputs):
        # Rewrite the graph in place for SP/DP, then lower with Inductor.
        apply_autosp(gm, real_inputs, debug_log, sp_size=sp_size, dp_size=dp_size)
        return torch._inductor.compile(gm, real_inputs)

    return backend_fn
11 changes: 11 additions & 0 deletions deepspeed/compile/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

# DeepSpeed Team

from typing import List, Optional, Literal
from deepspeed.runtime.config_utils import DeepSpeedConfigModel

# Names of the optional compile-time graph passes that can be composed
# via CompileConfig.passes.
PassName = Literal["z1", "z3", "autosp"]

class CompileConfig(DeepSpeedConfigModel):
""" Configure compile settings """
Expand Down Expand Up @@ -53,3 +55,12 @@ class CompileConfig(DeepSpeedConfigModel):

keep_all_input_tensors: bool = False
""" Keep real values for all input tensors in InputStorage instead of using dummy values """

passes: Optional[List[PassName]] = None
""" Composes different optimizations. """

sp_size: int = 1
""" SP group-size """

dp_size: int = 1
""" DP group-size """
74 changes: 74 additions & 0 deletions deepspeed/compile/custom_ops/all_to_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Not about this file) Don't we need __init__.py in custom_ops?

import torch.distributed as dist
from .sp_dp_registry import get_group, is_setup, sp_size, dp_size

@torch.library.custom_op("autosp::all_to_all", mutates_args=())
def all_to_all(
    input: torch.Tensor,
    scatter_idx: int,
    gather_idx: int,
    name: str,
) -> torch.Tensor:
    """
    All-to-all collective for SDPA tensors [B, N, S, H].

    For QKV (scatter_idx=1, gather_idx=2):
        [B, N, S/P, H] -> [B, N/P, S, H]
    For O (scatter_idx=2, gather_idx=1):
        [B, N/P, S, H] -> [B, N, S/P, H]

    Args:
        input: 4-D tensor [batch, heads, sequence, head_dim]; one of the two
            middle dims is sharded across the SP group of size P.
        scatter_idx: dim to shard after the exchange (1 = heads, anything
            else takes the sequence branch).
        gather_idx: dim to un-shard. NOTE(review): never read in this body —
            the branch is driven solely by scatter_idx; it appears to exist
            only so the backward setup can swap the two. Confirm.
        name: label for this exchange; not used for computation here, only
            propagated (with a "_grad" suffix) to the backward op.

    Returns:
        The re-sharded tensor with the shape documented above.
    """
    # populate_registry() must have built the SP/DP mesh before the first call.
    assert is_setup(), 'Incorrect initialization of SP/DP mesh.'
    B, dim1, dim2, H = input.shape
    # NOTE(review): assumes SP groups own contiguous rank blocks
    # ([0..P-1], [P..2P-1], ...) exactly as populate_registry lays them out.
    gid = dist.get_rank() // sp_size()
    group = get_group(gid)

    if scatter_idx == 1:
        # QKV path: shard heads, gather sequence.
        N, local_S = dim1, dim2
        # Split heads into (P, N/P) and move the P chunk axis to dim 0,
        # since all_to_all_single exchanges along the first dimension.
        input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H)
        input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()

        output = torch.empty_like(input_t)
        dist.all_to_all_single(output, input_t, group=group)

        # After the exchange, dim 0 indexes sequence chunks; fold them back
        # into a full sequence dim next to the now-local head slice.
        output = output.permute(1, 2, 0, 3, 4).contiguous()
        output = output.reshape(B, N // sp_size(), sp_size() * local_S, H)
    else:  # scatter_idx == 2, O: scatter sequence, gather heads
        local_N, S = dim1, dim2
        # Split the sequence into (P, S/P) chunks and expose the chunk axis
        # as dim 0 for the exchange.
        input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H)
        input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()

        output = torch.empty_like(input_t)
        dist.all_to_all_single(output, input_t, group=group)

        # Dim 0 now indexes head chunks; fold them back into a full head dim.
        output = output.permute(1, 0, 2, 3, 4).contiguous()
        output = output.reshape(B, sp_size() * local_N, S // sp_size(), H)

    return output


@torch.library.register_fake("autosp::all_to_all")
def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str):
    """Shape-only fake kernel so torch.compile can trace through the collective."""
    batch, d1, d2, heads = input.shape
    world = sp_size()
    if scatter_idx == 1:
        # Heads get sharded, sequence gets gathered.
        out_d1, out_d2 = d1 // world, d2 * world
    else:
        # Sequence gets sharded, heads get gathered.
        out_d1, out_d2 = d1 * world, d2 // world
    return input.new_empty(batch, out_d1, out_d2, heads)


def _all_to_all_backward_setup(ctx, inputs, output):
    """Record the backward routing on ctx: the gradient exchange swaps the
    forward scatter/gather axes and tags the op name with a "_grad" suffix."""
    _, fwd_scatter, fwd_gather, fwd_name = inputs
    ctx.scatter_idx = fwd_gather
    ctx.gather_idx = fwd_scatter
    ctx.name = f"{fwd_name}_grad"


def _all_to_all_backward(ctx, grad):
    """Backward of the all-to-all: route gradients with the inverse exchange.

    Returns one gradient per forward input (input, scatter_idx, gather_idx,
    name); the three non-tensor arguments receive None.
    """
    # Fix: the op has exactly 4 inputs, so register_autograd expects exactly
    # 4 return values — the original returned a spurious 5th None.
    return (
        all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name),
        None, None, None
    )


# Wire the backward formula into the custom op so autograd can differentiate
# through autosp::all_to_all (setup stores the swapped axes on ctx).
torch.library.register_autograd(
    "autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup
)
48 changes: 48 additions & 0 deletions deepspeed/compile/custom_ops/sp_dp_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
import torch.distributed as dist

# Process-group registry. Int keys map SP-group index -> dist.ProcessGroup;
# additionally holds bookkeeping under string keys 'SP_SIZE', 'DP_SIZE' and
# 'is_reg' (written by populate_registry).
GROUP_REGISTRY = {} # int -> dist.ProcessGroup

def register_groups(groups):
    """Create and cache one process group per rank list.

    groups: List[List[int]], e.g. [[0,1],[2,3]]; the position in the outer
    list becomes the registry key.
    """
    for group_id, rank_list in enumerate(groups):
        if group_id in GROUP_REGISTRY:
            continue  # already created; dist.new_group must not run twice
        GROUP_REGISTRY[group_id] = dist.new_group(rank_list)

def get_group(gid: int):
    """Look up a registered process group; None falls back to WORLD."""
    if gid is None:
        return dist.group.WORLD
    return GROUP_REGISTRY[gid]

def get_registry():
    """Expose the raw registry dict (process groups plus SP/DP metadata)."""
    return GROUP_REGISTRY

def is_setup():
    """Return True once populate_registry() has stamped the mesh as built."""
    # Idiom fix: dict.get replaces the two-lookup `x['k'] if 'k' in x` form.
    return GROUP_REGISTRY.get('is_reg', False)

def sp_size():
    """Return the sequence-parallel group size recorded by populate_registry().

    Raises:
        RuntimeError: if populate_registry() has not run yet.
    """
    # Explicit raise instead of assert so the check survives `python -O`.
    if 'SP_SIZE' not in GROUP_REGISTRY:
        raise RuntimeError('SP_SIZE not init properly.')
    return GROUP_REGISTRY['SP_SIZE']

def dp_size():
    """Return the data-parallel group size recorded by populate_registry().

    Raises:
        RuntimeError: if populate_registry() has not run yet.
    """
    # Explicit raise instead of assert so the check survives `python -O`.
    if 'DP_SIZE' not in GROUP_REGISTRY:
        raise RuntimeError('DP_SIZE not init properly')
    return GROUP_REGISTRY['DP_SIZE']

def populate_registry(SP_SIZE, DP_SIZE):
    """Build the SP/DP process-group mesh and record its sizes.

    Ranks are laid out contiguously: DP replica d owns ranks
    [d*SP_SIZE, (d+1)*SP_SIZE). Idempotent — a second call is a no-op.
    """
    if GROUP_REGISTRY.get('is_reg', False):
        return

    mesh = [
        list(range(replica * SP_SIZE, (replica + 1) * SP_SIZE))
        for replica in range(DP_SIZE)
    ]
    register_groups(mesh)

    # Metadata consumed by sp_size()/dp_size()/is_setup().
    GROUP_REGISTRY['SP_SIZE'] = SP_SIZE
    GROUP_REGISTRY['DP_SIZE'] = DP_SIZE
    GROUP_REGISTRY['is_reg'] = True
29 changes: 27 additions & 2 deletions deepspeed/compile/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

# DeepSpeed Team

from typing import Callable, Any, List, Dict
from typing import Callable, Any, List, Dict, Optional
from collections import defaultdict

import torch
from torch.fx import Node, Graph
from torch.fx import Node, Graph, GraphModule

from .util import get_last_uses

Expand Down Expand Up @@ -138,3 +138,28 @@ def free_tensors(tensors: List[torch.Tensor]):

# Python version for debugging
# graph.create_node('call_function', free_tensors, args, {}, name=node_name)

def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]:
    """Return the first node in gm's graph whose name matches, else None."""
    return next((candidate for candidate in gm.graph.nodes if candidate.name == name), None)

def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]:
    """Return the node's tensor metadata: meta['val'] if present, otherwise
    meta['example_value'], otherwise None.

    Bug fix: the original used ``get("val") or get("example_value")``, which
    truth-tests the value — bool() on a (fake) tensor with more than one
    element raises "Boolean value of Tensor ... is ambiguous".
    """
    val = node.meta.get("val")
    if val is not None:
        return val
    return node.meta.get("example_value")

def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]:
    """Return the first node whose dynamo 'tensor_dict' meta carries `tag`, else None."""
    for candidate in gm.graph.nodes:
        # 'tensor_dict' is attached by dynamo, see:
        # https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048
        tensor_meta = candidate.meta.get('tensor_dict')
        if tensor_meta and tensor_meta.get('tag') == tag:
            return candidate
    return None

def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None):
    """Rewire every user of `node` to read from `replacement` instead,
    skipping any users listed in `exclude`."""
    skip = exclude if exclude is not None else []
    # Snapshot the users first: replace_input_with mutates node.users while
    # we iterate over it.
    for user in list(node.users):
        if user in skip:
            continue
        user.replace_input_with(node, replacement)
14 changes: 14 additions & 0 deletions deepspeed/compile/init_sp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from torch.fx import GraphModule
from .passes.sp_compile import apply_autosp

def init_autosp(compile_config):
    """Return a torch.compile backend closure driven by `compile_config`.

    The backend applies the AutoSP graph pass (SP/DP sizes taken from the
    config) and then lowers the rewritten graph with Inductor.
    """

    def backend_fn(gm: GraphModule, real_inputs):
        # Rewrite the captured graph for SP/DP, then hand off to Inductor.
        apply_autosp(gm,
                     real_inputs,
                     debug=False,
                     sp_size=compile_config.sp_size,
                     dp_size=compile_config.dp_size)
        return torch._inductor.compile(gm, real_inputs)

    return backend_fn
Loading