python/paddle/distributed/auto_parallel/constants.py (6 additions, 0 deletions)
@@ -145,6 +145,12 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)

#########################################
# fused linear promotion configuration
#########################################
FUSEDLINEARPROMOTION = "fused_linear_promotion"
set_field_default_config(FUSEDLINEARPROMOTION, "enable", False)

#########################################
# fused passes configuration
#########################################
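The hunk above only registers the default (`enable=False`) for the new `fused_linear_promotion` category. A minimal usage sketch, not part of this PR, assuming `Strategy` accepts a plain config dict keyed by category name the same way the existing categories are looked up:

# Hypothetical sketch: keys mirror the category names registered in constants.py.
from paddle.distributed.auto_parallel.strategy import Strategy

config = {
    "fused_linear_promotion": {"enable": True},  # new category from this change
    "fused_passes": {
        "enable": True,
        "fused_passes_list": ["fuse_gemm_epilogue"],  # required by the gate below
    },
}
strategy = Strategy(config)
print(strategy.fused_linear_promotion.enable)  # True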
python/paddle/distributed/auto_parallel/static/parallelizer_v2.py (29 additions, 0 deletions)
@@ -348,6 +348,35 @@ def _apply_post_optimization(
)
sp_pass.apply([main_program], [startup_program], self._pass_context)

# apply fused linear promotion pass
if (
self.is_train
and self._strategy.fused_linear_promotion.enable
and self._strategy.fused_passes.enable
):
if (
len(self._strategy.fused_passes.fused_passes_list) > 0
and "fuse_gemm_epilogue"
in self._strategy.fused_passes.fused_passes_list
):
amp_config = None
if self._strategy.amp.enable:
amp_config = copy.deepcopy(self._strategy.amp.to_dict())
config = {}
config["dist_context"] = self._dist_context
config["global_rank"] = rank
config["enable_sp"] = self._strategy.sp_optimization.enable
config["params_grads"] = params_grads
config["amp_level"] = (
amp_config['level'] if amp_config is not None else "o0"
)
fused_linear_promotion_pass = new_pass(
"auto_parallel_fused_linear_promotion", config
)
fused_linear_promotion_pass.apply(
[main_program], [startup_program], self._pass_context
)

# data parallel optimization
if self._strategy.dp_optimization.enable:
config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
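For context, the new pass only fires when both switches are on and "fuse_gemm_epilogue" is in the fused-pass list. A hedged sketch of a strategy that satisfies that gate, assuming the usual fleet.auto API; names outside this diff are illustrative:

# Sketch only: turns on every switch the gate in _apply_post_optimization checks.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.fused_linear_promotion.enable = True  # new switch from this PR
strategy.fused_passes.enable = True  # outer gate
strategy.fused_passes.fused_passes_list = ["fuse_gemm_epilogue"]  # inner gate
strategy.amp.enable = True  # optional; amp_level is then taken from the amp config
strategy.amp.level = "o2"
# The strategy is then handed to the auto-parallel Engine as usual, e.g.:
# engine = auto.Engine(model, loss, optimizer, strategy=strategy)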
python/paddle/distributed/auto_parallel/strategy.py (11 additions, 0 deletions)
@@ -82,6 +82,12 @@ def __init__(self, config_dict=None):
super().__init__(category, config_dict)


class FusedLinearPromotionConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.FUSEDLINEARPROMOTION
super().__init__(category, config_dict)


class AMPConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.AMP
@@ -224,6 +230,11 @@ def __init__(self, config=None):
config_dict = self._config_dict.get(constants.FUSED_PASSES, None)
self.fused_passes = FusedPassesConfig(config_dict)

config_dict = self._config_dict.get(
constants.FUSEDLINEARPROMOTION, None
)
self.fused_linear_promotion = FusedLinearPromotionConfig(config_dict)

config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
self.dp_optimization = DPOptimizationConfig(config_dict)

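With the attribute wired into Strategy.__init__ above, the new config is reachable like any other category; a small sanity-check sketch, assuming the module path of the file in this hunk:

# Sketch: the new sub-config keeps the default (enable=False) registered in constants.py.
from paddle.distributed.auto_parallel.strategy import Strategy

strategy = Strategy()
print(strategy.fused_linear_promotion.to_dict())  # expected: {'enable': False}
strategy.fused_linear_promotion.enable = True  # flip it on programmatically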
python/paddle/distributed/passes/__init__.py (1 addition, 0 deletions)
@@ -22,6 +22,7 @@
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_fused_linear_promotion import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .auto_parallel_pipeline import * # noqa: F403
from .auto_parallel_sequence_parallel_optimization import * # noqa: F403
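The wildcard import matters because the pass module registers itself at import time, which is what lets new_pass resolve the string name used in parallelizer_v2.py. A hedged illustration, assuming the standard register_pass/new_pass machinery; the attrs here are placeholders:

# Illustration only: resolve the newly registered pass by its string name.
from paddle.distributed.passes import new_pass

# Placeholder attrs; the real call in parallelizer_v2.py supplies dist_context,
# global_rank, enable_sp, params_grads and amp_level.
promotion_pass = new_pass("auto_parallel_fused_linear_promotion", {})
print(type(promotion_pass).__name__)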