[fp8] support fp8 amp for hybrid parallel plugin #5975

Merged 3 commits on Aug 7, 2024
23 changes: 23 additions & 0 deletions colossalai/booster/plugin/fp8_hook.py
@@ -0,0 +1,23 @@
import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook


class FP8Hook(ColoParamOpHook):
def pre_forward(self, params) -> None:
pass

def post_forward(self, params) -> None:
pass

def pre_backward(self, params) -> None:
pass

def post_backward(self, params) -> None:
pass

def rewrite_op(self, func):
if func is F.linear:
return linear_fp8
return func
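
The new hook leaves all parameter pre/post callbacks as no-ops and only implements rewrite_op, which swaps torch.nn.functional.linear for the FP8 GEMM linear_fp8 whenever a hooked parameter participates in the call. Below is a minimal sketch of that rewrite pattern, using a hypothetical fake_linear_fp8 stand-in so it runs without an FP8-capable GPU:

import torch.nn.functional as F


def fake_linear_fp8(x, w, bias=None):
    # Stand-in for colossalai.quantization.fp8.linear_fp8; the real kernel
    # casts x and w to float8 before the matmul.
    return F.linear(x, w, bias)


class SketchFP8Hook:
    def rewrite_op(self, func):
        # Route F.linear through the FP8 path; leave every other op untouched.
        return fake_linear_fp8 if func is F.linear else func


hook = SketchFP8Hook()
assert hook.rewrite_op(F.linear) is fake_linear_fp8
assert hook.rewrite_op(F.relu) is F.relu
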
18 changes: 16 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -40,6 +40,7 @@
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

from .fp8_hook import FP8Hook
from .pp_plugin_base import PipelinePluginBase

SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -66,6 +67,7 @@ def __init__(
ddp_config: dict,
custom_policy: Policy,
overlap_allgather: bool = False,
use_fp8: bool = False,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
@@ -75,6 +77,7 @@
self.use_dpp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
self.use_fp8 = use_fp8

shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@@ -112,8 +115,12 @@ def __init__(
module = DDP(module, process_group=dp_group, **ddp_config)

super().__init__(module)
self.op_hooks = []
if overlap_allgather:
self.op_hook = ZeroOpHook()
self.op_hooks.append(ZeroOpHook())
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather or use_fp8:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
@@ -223,7 +230,11 @@ def _force_wait_all_gather(self):
wait_all_gather_handle(p)

def _wait_all_gather(self):
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
return (
ColoParamOpHookManager.use_hooks(*self.op_hooks)
if (self.overlap_allgather or self.use_fp8)
else nullcontext()
)


def get_param_info(optim: Optimizer):
@@ -1019,6 +1030,7 @@ def __init__(
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
super().__init__()

@@ -1063,6 +1075,7 @@ def __init__(
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.use_fp8 = use_fp8
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
@@ -1243,6 +1256,7 @@ def configure(
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0:
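
For context, a usage sketch (not part of this diff) of how the new flag would be enabled on the plugin; it assumes a torchrun-launched distributed environment, a GPU with FP8 support (compute capability >= 9.0), and omits the surrounding model/optimizer setup:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes torchrun has set the usual env vars
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    use_fp8=True,  # new flag from this PR: installs FP8Hook on the wrapped model
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader = ...  (defined elsewhere)
# model, optimizer, criterion, dataloader, _ = booster.boost(
#     model, optimizer, criterion, dataloader
# )
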
3 changes: 2 additions & 1 deletion colossalai/quantization/fp8.py
@@ -431,7 +431,8 @@ def forward(
if bias is not None:
assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
# ensure x and w are row-major
assert x.is_contiguous() and w.is_contiguous(), "Input and weight should be contiguous."
x = x.contiguous()
w = w.contiguous()
ctx.x_shape = x.shape
ctx.has_bias = bias is not None
ctx.out_dtype = x.dtype
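
The dropped assert is the only behavioral change here: tensors reaching the FP8 linear through the new hook path are not guaranteed to be contiguous (a transposed view, for example), so the kernel now normalizes the layout itself instead of failing, in line with the existing "ensure x and w are row-major" comment. A small illustration of the difference:

import torch

w = torch.rand(16, 32).t()    # transposed view: not row-major
assert not w.is_contiguous()  # the old assert would have rejected this input
w = w.contiguous()            # the new code materializes a row-major copy instead
assert w.is_contiguous()
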
2 changes: 2 additions & 0 deletions colossalai/tensor/colo_parameter.py
@@ -61,6 +61,8 @@ def __torch_function__(cls, func, types, args=..., kwargs=None):
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
with torch._C.DisableTorchFunction():
func = ColoParamOpHookManager.rewrite_op(func)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)
9 changes: 9 additions & 0 deletions colossalai/tensor/param_op_hook.py
@@ -30,6 +30,9 @@ def pre_backward(self, params: List[torch.Tensor]) -> None:
def post_backward(self, params: List[torch.Tensor]) -> None:
pass

def rewrite_op(self, func) -> Any:
return func


class ColoParamOpHookManager:
"""
@@ -101,6 +104,12 @@ def post_op(params: List[torch.Tensor], arg: Any) -> Any:
def has_hook() -> bool:
return len(ColoParamOpHookManager.hooks) > 0

@staticmethod
def rewrite_op(func) -> Any:
for hook in ColoParamOpHookManager.hooks:
func = hook.rewrite_op(func)
return func


class PreFwdPostBwd(torch.autograd.Function):
@staticmethod
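
ColoParamOpHookManager.rewrite_op threads the op through every registered hook in order, and ColoParameter.__torch_function__ (see the change above) calls it right before dispatch, so substitution only happens for ops that actually touch hooked parameters. A standalone sketch of the chaining behavior, with simplified stand-in hooks:

import torch.nn.functional as F


class NoOpHook:
    # Mirrors hooks that keep the base-class default rewrite_op (e.g. ZeroOpHook).
    def rewrite_op(self, func):
        return func


class SwapLinearHook:
    # Mirrors FP8Hook: substitute F.linear, pass everything else through.
    def rewrite_op(self, func):
        if func is F.linear:
            return lambda x, w, bias=None: F.linear(x, w, bias)
        return func


def rewrite_op(hooks, func):
    # Same loop as the new ColoParamOpHookManager.rewrite_op staticmethod.
    for hook in hooks:
        func = hook.rewrite_op(func)
    return func


assert rewrite_op([NoOpHook(), SwapLinearHook()], F.linear) is not F.linear
assert rewrite_op([NoOpHook()], F.relu) is F.relu
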
50 changes: 50 additions & 0 deletions tests/test_fp8/test_fp8_hook.py
@@ -0,0 +1,50 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

REPLACED = False
TRIGGERED = False


def new_linear_fp8(x, w, bias=None):
global TRIGGERED
TRIGGERED = True
return linear_fp8(x, w, bias)


class FP8TestHook(FP8Hook):
def rewrite_op(self, func):
func = super().rewrite_op(func)
if func is linear_fp8:
global REPLACED
REPLACED = True
return new_linear_fp8
return func


D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
# create tensors
w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
w.__class__ = ColoParameter
w.__init__(w, requires_grad=True)
hook = FP8TestHook()
with ColoParamOpHookManager.use_hooks(hook):
o = F.linear(x, w)
assert o.shape == (B, S, D_OUT)
assert REPLACED
assert TRIGGERED