5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/cost/__init__.py
@@ -16,10 +16,15 @@
from .base_cost import Cost
from .base_cost import CommContext
from .base_cost import build_comm_desc
from .base_cost import build_comp_desc_from_op
from .base_cost import build_comp_desc_from_dist_op
from .base_cost import build_dp_costs
from .base_cost import build_comp_costs_from_descs
from .tensor_cost import TensorCost
from .estimate_cost import CostEstimator

from .comp_op_cost import MatmulV2OpCost
from .comp_op_cost import FillConstantBatchSizeLikeOpCost

from .comm_op_cost import SendOpCost
from .comm_op_cost import RecvOpCost
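The helpers re-exported here are what the distributed operator implementations further down pull in. A minimal sketch of the resulting import surface, for orientation only (the surrounding program that would use these objects is assumed to exist elsewhere):

# Sketch only: names taken from this diff; nothing further is implied about their signatures.
from paddle.distributed.auto_parallel.cost import (
    build_comp_desc_from_dist_op,
    build_comp_costs_from_descs,
    build_dp_costs,
    FillConstantBatchSizeLikeOpCost,
)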
18 changes: 0 additions & 18 deletions python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
@@ -357,24 +357,6 @@ def calc_time(self):
return 0


@register_op_cost
class FillConstantBatchSizeLikeGradOpCost(CompOpCost):
OP_TYPE = "fill_constant_batch_size_like_grad"

def __init__(self, op=None, op_desc=None, cluster=None):
super(FillConstantBatchSizeLikeGradOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)

# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0

def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0


@register_op_cost
class GatherOpCost(CompOpCost):
OP_TYPE = "gather"
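With FillConstantBatchSizeLikeGradOpCost removed outright (the op has no grad op, as the changes below make explicit), any code still referencing the class will fail. A hedged sanity check one could run after this change:

# Assumes a Paddle build that contains this change; the submodule path is from the diff above.
from paddle.distributed.auto_parallel.cost import comp_op_cost
assert hasattr(comp_op_cost, "FillConstantBatchSizeLikeOpCost")
assert not hasattr(comp_op_cost, "FillConstantBatchSizeLikeGradOpCost")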
71 changes: 71 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -31,6 +31,9 @@
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from ..cost import _g_op_cost_factory
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs

__op_not_need_param_init__ = ["while", "cond"]

@@ -99,6 +102,74 @@ def __init__(self, name):
self._forward_implemented = True
self._backward_implemented = True

def calc_cost(self, op_role, dist_op, ctx, cluster):
"""Calculate the cost by the op role."""
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost

def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
ctx, processes, desc_mapping,
cluster)
res_cost = [cost_mapping]

return res_cost

def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
backward_op = dist_op.serial_op
op_type = backward_op.type
cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
ctx, processes, desc_mapping,
cluster)
res.append(cost_mapping)

main_block = backward_op.block
vars = main_block.vars
need_gradient_allreduce = False
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
break

if need_gradient_allreduce:
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
var_dim_mapping = dist_attr.get_input_dims_mapping(
varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res

def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
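Taken together, the new methods give DistributedDefaultImpl0 a single entry point, calc_cost, that dispatches on the op role: the forward path returns a list with one comp-cost mapping built from the op descriptors, and the backward path may append data-parallel allreduce costs via build_dp_costs when the batch axis is sharded and the op has parameter gradients to synchronize. A minimal usage sketch, assuming dist_op, dist_context, and cluster already exist (the impl name string is also an assumption):

# Sketch only; dist_op, dist_context and cluster come from an existing auto-parallel setup.
from paddle.distributed.auto_parallel.operators.dist_default import DistributedDefaultImpl0
from paddle.distributed.fleet.meta_optimizers.common import OpRole

impl = DistributedDefaultImpl0("replicate_parallel")  # name string is illustrative
fwd_costs = impl.calc_cost(OpRole.Forward, dist_op, dist_context, cluster)
# -> [cost_mapping] built by build_comp_costs_from_descs
bwd_costs = impl.calc_cost(OpRole.Backward, dist_op, dist_context, cluster)
# -> [cost_mapping, ...] with allreduce costs appended by build_dp_costs when needed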
Changes in another file:
@@ -15,7 +15,7 @@
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import register_distributed_operator_impl, is_parameter_related
from .common import is_elementwise_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
@@ -32,6 +32,9 @@
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
from ..cost import _g_op_cost_factory
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs


class DistributedElementwise(DistributedOperatorImplContainer):
@@ -52,6 +55,74 @@ def __init__(self, name):
self._forward_implemented = False
self._backward_implemented = False

def calc_cost(self, op_role, dist_op, ctx, cluster):
"""Calculate the cost by the op role."""
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost

def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
ctx, processes, desc_mapping,
cluster)
res_cost = [cost_mapping]

return res_cost

def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
backward_op = dist_op.serial_op
op_type = backward_op.type
cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
ctx, processes, desc_mapping,
cluster)
res.append(cost_mapping)

main_block = backward_op.block
vars = main_block.vars
need_gradient_allreduce = False
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
break

if need_gradient_allreduce:
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
var_dim_mapping = dist_attr.get_input_dims_mapping(
varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res

def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if not is_elementwise_op(op_desc.type()):
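The elementwise implementation mirrors the default impl's cost logic line for line; the crux of both calc_bwd_cost bodies is the need_gradient_allreduce test on the batch axis. A hedged, standalone illustration of that test, with made-up mesh and dims-mapping values:

# Illustration only: a dims mapping assigns each tensor axis to a process-mesh axis, -1 meaning replicated.
mesh_topology = [2, 4]                # assumed 2x4 process mesh
var_dim_mapping = [0, -1]             # batch axis sharded over mesh axis 0
batch_size_axis = var_dim_mapping[0]  # 0
need_gradient_allreduce = batch_size_axis > -1 and mesh_topology[batch_size_axis] > 1
# True here: parameter gradients must be allreduced across the two data-parallel ranks,
# which is the cost that build_dp_costs adds to the result list.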
Changes in another file:
@@ -27,7 +27,12 @@
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from .dist_default import DistributedDefaultImpl0
from ..cost import FillConstantBatchSizeLikeOpCost
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from ..cost import AllreduceSumOpCost


class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer):
@@ -47,6 +52,29 @@ def __init__(self, name):
self._forward_implemented = True
self._backward_implemented = True

def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
raise ValueError(
"The fill_constant_batch_size_like has no grad op.")
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost

def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
FillConstantBatchSizeLikeOpCost, ctx, processes, desc_mapping,
cluster)

res_cost = [cost_mapping]
return res_cost

def is_input_compatible(self, dist_op):

return True
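Because fill_constant_batch_size_like has no grad op, calc_cost here only serves the forward role and raises for backward requests. A short hedged usage sketch (impl stands for the registered dist impl instance, whose class definition is collapsed in this diff; dist_op, dist_context and cluster are assumed to exist):

from paddle.distributed.fleet.meta_optimizers.common import OpRole

fwd_costs = impl.calc_cost(OpRole.Forward, dist_op, dist_context, cluster)  # [cost_mapping]
try:
    impl.calc_cost(OpRole.Backward, dist_op, dist_context, cluster)
except ValueError:
    pass  # "The fill_constant_batch_size_like has no grad op."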
Changes in another file:
@@ -55,4 +55,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS})
endif()
Changes in another file:
@@ -35,7 +35,6 @@
from paddle.distributed.auto_parallel.cost.comp_op_cost import EmbeddingGradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantBatchSizeLikeOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantBatchSizeLikeGradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import GatherOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import GeluOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import GeluGradOpCost
@@ -184,11 +183,6 @@ def test_comp_cost(self):
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)

op_cost = FillConstantBatchSizeLikeGradOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)

op_cost = GatherOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
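The deleted block only dropped the assertions for the removed grad cost class; the forward cost keeps the same assertion pattern used throughout this test, roughly (inside test_comp_cost, with the test's cluster object):

op_cost = FillConstantBatchSizeLikeOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)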