Add CUDA graph-based all reduce launcher #26

Merged · 5 commits · Apr 5, 2023
5 changes: 3 additions & 2 deletions benchmark/benchmark_latency.py
@@ -35,7 +35,7 @@ def main(args: argparse.Namespace):
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
-        max_batch_size=args.max_batch_size,
+        max_num_batched_tokens=args.max_num_batched_tokens,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
@@ -94,6 +94,7 @@ def profile_step(profile=False):
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
     args = parser.parse_args()
-    args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
+    args.max_num_batched_tokens = max(
+        args.max_num_batched_tokens, args.batch_size * args.input_len)
     print(args)
     main(args)
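The two added lines above guarantee that the benchmark's prompt batch always fits inside the token budget that now also sizes the all-reduce buffer. A minimal sketch of that sizing rule, using hypothetical values (the default for --input-len is outside this diff):

```python
# Hypothetical illustration of the sizing rule; only --batch-size 8 and the
# 2560 default for --max-num-batched-tokens are visible in this diff.
batch_size = 8
input_len = 512                      # assumed value for illustration
max_num_batched_tokens = 2560

# Raise the budget so one full prompt batch (batch_size * input_len tokens)
# can always be scheduled in a single step.
max_num_batched_tokens = max(max_num_batched_tokens, batch_size * input_len)
print(max_num_batched_tokens)        # 4096, since 8 * 512 > 2560
```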
13 changes: 7 additions & 6 deletions cacheflow/master/server.py
@@ -22,7 +22,7 @@ def __init__(
         dtype: str,
         seed: int,
         swap_space: int,
-        max_batch_size: int,
+        max_num_batched_tokens: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
@@ -43,7 +43,7 @@ def __init__(
             tensor_parallel_size=tensor_parallel_size,
         )
         self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
-            max_num_batched_tokens=max_batch_size)
+            max_num_batched_tokens=max_num_batched_tokens)
         self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
             swap_space=swap_space)
         print(f'# GPU blocks: {self.num_gpu_blocks}, '
@@ -66,6 +66,7 @@ def __init__(
             dtype=dtype,
             seed=seed,
             model_path=model_path,
+            max_num_batched_tokens=max_num_batched_tokens,
         )
         self.controllers.append(controller)

@@ -75,7 +76,7 @@ def __init__(
             block_size=block_size,
             num_gpu_blocks=self.num_gpu_blocks,
             num_cpu_blocks=self.num_cpu_blocks,
-            max_num_batched_tokens=max_batch_size,
+            max_num_batched_tokens=max_num_batched_tokens,
         )
         # Connect the controllers.
         for i in range(len(self.controllers) - 1):
@@ -168,14 +169,14 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                         help='model path to download and load the weights')
     # Parallel arguments
-    parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
     parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
     parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
+    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
     return parser
71 changes: 71 additions & 0 deletions cacheflow/parallel_utils/parallel_state.py
@@ -47,6 +47,7 @@
 # Memory buffers to avoid dynamic memory allocation
 _GLOBAL_MEMORY_BUFFER = None

+_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None

 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
@@ -205,6 +206,20 @@ def initialize_model_parallel(
     _set_global_memory_buffer()


+def initialize_all_reduce_launcher(
+    max_num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    disable_graph: bool = False,
+) -> None:
+    global _ALL_REDUCE_LAUNCHER
+    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
+        max_num_tokens=max_num_tokens,
+        hidden_size=hidden_size,
+        dtype=dtype,
+        disable_graph=disable_graph,
+    )
+
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
     if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@@ -491,6 +506,9 @@ def get_global_memory_buffer():
     assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
     return _GLOBAL_MEMORY_BUFFER

+def get_all_reduce_launcher() -> 'GraphAllReduce':
+    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
+    return _ALL_REDUCE_LAUNCHER

 def destroy_model_parallel():
     """Set the groups to none."""
@@ -520,3 +538,56 @@ def destroy_model_parallel():
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
     global _GLOBAL_MEMORY_BUFFER
     _GLOBAL_MEMORY_BUFFER = None
+
+
+class GraphAllReduce:
+
+    def __init__(
+        self,
+        max_num_tokens: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        disable_graph: bool = False,
+    ) -> None:
+        self.max_num_tokens = max_num_tokens
+        self.hidden_size = hidden_size
+        self.disable_graph = disable_graph
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        if tp_world_size == 1:
+            return
+
+        self.group = get_tensor_model_parallel_group()
+        self.buffer = torch.empty(
+            size=(max_num_tokens, hidden_size),
+            dtype=dtype,
+            device='cuda',
+        )
+
+        # Build graphs for different number of tokens.
+        if not self.disable_graph:
+            self.graphs = {}
+            for num_tokens in range(8, max_num_tokens + 1, 8):
+                self.graphs[num_tokens] = self._build_graph(num_tokens)
+
+    def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
+        # Warm up.
+        torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
+        torch.cuda.synchronize()
+
+        # Build graph.
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph):
+            torch.distributed.all_reduce(
+                self.buffer[:num_tokens], group=self.group)
+        torch.cuda.synchronize()
+        return graph
+
+    def launch(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE: x must be a slice of self.buffer.
+        num_tokens = x.shape[0]
+        if self.disable_graph:
+            torch.distributed.all_reduce(x, group=self.group)
+        else:
+            self.graphs[num_tokens].replay()
+        return x
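GraphAllReduce pre-captures one CUDA graph per token count (every multiple of 8 up to max_num_tokens), each recording an all-reduce over a fixed slice of a persistent buffer; launch() then just replays the matching graph, avoiding per-call launch overhead. Below is a minimal standalone sketch of the same capture-and-replay pattern, not the PR's code — it assumes torch.distributed has already been initialized with the NCCL backend, one GPU per rank, and an NCCL version recent enough to allow collectives inside graph capture:

```python
import torch
import torch.distributed as dist

# Assumptions: dist.init_process_group('nccl') has run, one GPU per rank,
# and the NCCL version supports capturing collectives in CUDA graphs.
hidden_size, num_tokens = 4096, 8
buffer = torch.empty(num_tokens, hidden_size, dtype=torch.half, device='cuda')

# Warm up outside capture so NCCL finishes its lazy initialization.
dist.all_reduce(buffer)
torch.cuda.synchronize()

# Capture: the graph records an all-reduce on this exact memory region.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    dist.all_reduce(buffer)
torch.cuda.synchronize()

# Replay: write fresh activations into the captured buffer, then replay.
buffer.normal_()
graph.replay()           # reduces buffer in place across all ranks
reduced = buffer         # same storage the graph was captured on
```

The key constraint is visible in launch(): the input must already live in a slice of the captured buffer, because graph replay reads and writes the addresses recorded at capture time. That is why RowParallelLinear below writes its matmul output directly into all_reduce_launcher.buffer instead of allocating a fresh tensor.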
19 changes: 13 additions & 6 deletions cacheflow/parallel_utils/tensor_parallel/layers.py
@@ -12,6 +12,7 @@
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_all_reduce_launcher,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
@@ -407,8 +408,7 @@ def __init__(self, input_size, output_size, *,
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-
-
+        self.weight_t = self.weight.t()

     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -425,11 +425,18 @@ def forward(self, input_):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        if get_tensor_model_parallel_world_size() == 1:
+            # Matrix multiply.
+            output_ = F.linear(input_parallel, self.weight)
+        else:
+            # Matrix multiply.
+            all_reduce_launcher = get_all_reduce_launcher()
+            num_tokens = input_parallel.shape[0]
+            output_buffer = all_reduce_launcher.buffer[:num_tokens]
+            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
+            # All-reduce across all the partitions.
+            output_ = all_reduce_launcher.launch(output_buffer)

-        # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
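The forward change relies on a small identity: F.linear(x, W) computes x @ W.T, but F.linear has no out= argument, so the layer caches W.T once as self.weight_t and uses torch.matmul(..., out=...) to write the partial product straight into the launcher's captured buffer. A tiny CPU-only sketch of that equivalence, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes; float32 on CPU just to show the algebra.
num_tokens, input_size, output_size = 16, 1024, 4096
x = torch.randn(num_tokens, input_size)
weight = torch.randn(output_size, input_size)   # nn.Linear-style layout
weight_t = weight.t()                           # cached once, like self.weight_t

# Preallocated buffer standing in for all_reduce_launcher.buffer.
buffer = torch.empty(64, output_size)
out = buffer[:num_tokens]
torch.matmul(x, weight_t, out=out)              # lands in the "captured" memory

assert torch.allclose(out, F.linear(x, weight), atol=1e-4)
```

When tensor parallelism is disabled (world size 1) there is no all-reduce and hence no captured buffer, so the original F.linear path is kept unchanged.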
2 changes: 2 additions & 0 deletions cacheflow/worker/controller.py
@@ -27,6 +27,7 @@ def __init__(
         dtype: str,
         seed: int,
         model_path: str,
+        max_num_batched_tokens: int,
     ) -> None:
         self.stage_id = stage_id
         self.stage_devices = stage_devices
@@ -57,6 +58,7 @@ def __init__(
             tensor_parallel_size=tensor_parallel_size,
             pipeline_parallel_size=pipeline_parallel_size,
             model_path=model_path,
+            max_num_batched_tokens=max_num_batched_tokens,
         )
         self.workers.append(worker)

7 changes: 6 additions & 1 deletion cacheflow/worker/worker.py
@@ -9,7 +9,9 @@
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
 from cacheflow.parallel_utils.parallel_state import (
-    initialize_model_parallel, get_tensor_model_parallel_world_size)
+    initialize_model_parallel,
+    initialize_all_reduce_launcher,
+    get_tensor_model_parallel_world_size)
 from cacheflow.utils import set_random_seed


@@ -27,6 +29,7 @@ def __init__(
         rank: int,
         world_size: int,
         model_path: str,
+        max_num_batched_tokens: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
     ) -> None:
@@ -44,6 +47,8 @@ def __init__(
         self.model = self.model.cuda()
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
+        initialize_all_reduce_launcher(
+            max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
         self.num_layers = self.model.config.num_hidden_layers
         assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
         self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
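initialize_all_reduce_launcher is called once per worker with the model's hidden size and dtype, so each GPU permanently holds a [max_num_batched_tokens, hidden_size] staging buffer in addition to the captured graphs. A rough cost estimate under assumed values (hidden_size 5120 is typical of a ~13B model; 2560 is the --max-num-batched-tokens default from server.py):

```python
# Back-of-the-envelope size of the persistent all-reduce buffer per GPU.
max_num_batched_tokens = 2560          # default from add_server_arguments
hidden_size = 5120                     # assumed; depends on the loaded model
bytes_per_elem = 2                     # half precision (--dtype half)

buffer_bytes = max_num_batched_tokens * hidden_size * bytes_per_elem
print(f"{buffer_bytes / 2**20:.1f} MiB")  # 25.0 MiB
```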
2 changes: 1 addition & 1 deletion simple_server.py
@@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
-        max_batch_size=args.max_batch_size,
+        max_num_batched_tokens=args.max_num_batched_tokens,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,