[auto parallel] Add pp lazy init, bug fix for xavier #60441

Merged: 3 commits, Jan 2, 2024
17 changes: 11 additions & 6 deletions paddle/fluid/pybind/eager_method.cc
@@ -1078,12 +1078,17 @@ static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
   EAGER_TRY
   paddle::Tensor* src_ptr =
       &(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
-  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
-                    true,
-                    platform::errors::InvalidArgument(
-                        "Tensor %s has not been initialized! please initialize "
-                        "src tensor before share_buffer_with to other.",
-                        self->tensor.name()));
+  if (!self->tensor.initialized()) {
+    PADDLE_ENFORCE(self->tensor.is_dist_tensor() &&
+                       !phi::distributed::IsCurRankInMesh(
+                           static_cast<phi::distributed::DistTensor*>(
+                               self->tensor.impl().get())
+                               ->process_mesh()),
+                   platform::errors::InvalidArgument(
+                       "Tensor %s has not been initialized! Please initialize "
+                       "src tensor before share_buffer_with to other.",
+                       self->tensor.name()));
+  }
   src_ptr->set_impl(self->tensor.impl());
   RETURN_PY_NONE

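The relaxed check above lets _share_underline_tensor_to accept an uninitialized source tensor, but only when that tensor is a DistTensor whose process mesh does not include the current rank, which is the pipeline-parallel lazy-init case where the data legitimately lives on another stage. Below is a minimal usage sketch of that situation, modeled on the new test and assuming a two-rank distributed launch (the layer size and mesh names are illustrative):

import paddle
import paddle.distributed as dist
from paddle import LazyGuard

# Each parameter is placed on a single-rank mesh, so on every rank one of the
# two parameters is never materialized locally (the pipeline-parallel case).
mesh0 = dist.ProcessMesh([0], dim_names=["x"])
mesh1 = dist.ProcessMesh([1], dim_names=["x"])

with LazyGuard():
    linear = paddle.nn.Linear(10, 10)
    linear.weight = dist.shard_tensor(linear.weight, mesh0, [dist.Replicate()])
    linear.bias = dist.shard_tensor(linear.bias, mesh1, [dist.Replicate()])

for param in linear.parameters():
    # Only the parameter whose mesh contains the current rank gets real storage;
    # the other stays uninitialized, which the relaxed check above is meant to allow.
    param.initialize()
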
6 changes: 6 additions & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -187,6 +187,12 @@ def shard_tensor(
         if isinstance(data, EagerParamBase):
 
             def lazy_init_hook(param, origin_hook):
+                for placement in param.placements:
+                    assert not placement.is_partial(), (
+                        "Lazy init not support partial reshard. Notice that: shard a param to partial "
+                        "won't save any memory, but will increase the communication cost!"
+                    )
+
                 # lazy init hook with randomness controlling
                 def _init_func(var, block):
                     # get the unique rng name
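The assertion added to lazy_init_hook rejects lazily initialized parameters that carry a Partial placement: materializing a partial shard saves no memory and only adds reduction traffic later. A hedged sketch of the now-rejected pattern follows, assuming a two-rank launch; dist.Partial is the standard placement type, and exactly when the assertion fires depends on when the hook is invoked.

import paddle
import paddle.distributed as dist
from paddle import LazyGuard

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

with LazyGuard():
    linear = paddle.nn.Linear(10, 10)
    # Replicate() or Shard(i) placements are fine for lazy init;
    # Partial() is the case the new assert is meant to reject.
    linear.weight = dist.shard_tensor(linear.weight, mesh, [dist.Partial()])

linear.weight.initialize()  # expected to trip the new AssertionError
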
18 changes: 13 additions & 5 deletions python/paddle/nn/initializer/xavier.py
@@ -105,6 +105,11 @@ def forward(self, var, block=None):
         if self._seed == 0:
             self._seed = block.program.random_seed
 
+        out_var_shape = (
+            var._local_shape
+            if (isinstance(var, framework.EagerParamBase) and var.is_dist())
+            else var.shape
+        )
         # to be compatible of fp16 initalizers
         if var.dtype == core.VarDesc.VarType.FP16 or (
             var.dtype == core.VarDesc.VarType.BF16 and not self._uniform
@@ -114,9 +119,7 @@ def forward(self, var, block=None):
                 name=unique_name.generate(
                     ".".join(['xavier_init', var.name, 'tmp'])
                 ),
-                shape=var._local_shape
-                if (isinstance(var, framework.EagerParamBase) and var.is_dist())
-                else var.shape,
+                shape=out_var_shape,
                 dtype=out_dtype,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
@@ -135,7 +138,7 @@ def forward(self, var, block=None):
             if self._uniform:
                 limit = math.sqrt(6.0 / float(fan_in + fan_out))
                 out_var = _C_ops.uniform(
-                    out_var.shape,
+                    out_var_shape,
                     out_dtype,
                     -limit,
                     limit,
@@ -147,7 +150,12 @@ def forward(self, var, block=None):
 
                 place = _current_expected_place()
                 out_var = _C_ops.gaussian(
-                    out_var.shape, 0.0, std, self._seed, out_dtype, place
+                    out_var_shape,
+                    0.0,
+                    std,
+                    self._seed,
+                    out_dtype,
+                    place,
                 )
 
             if var.dtype == core.VarDesc.VarType.FP16 or (
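The substantive fix in this file is that, for a distributed parameter, the uniform/gaussian kernels must be launched with the rank-local shard shape (var._local_shape) rather than the global shape reported by out_var.shape; the hoisted out_var_shape now carries that value to all three call sites. A small illustration of the distinction, assuming a two-rank launch and the private _local_shape helper used in the diff (the shapes are illustrative):

import paddle
import paddle.distributed as dist

# For a parameter sharded along axis 0 over two ranks, `shape` stays global
# while `_local_shape` is the per-rank shard the initializer has to fill.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
linear = paddle.nn.Linear(10, 10)
w = dist.shard_tensor(linear.weight, mesh, [dist.Shard(0)])

print(w.shape)         # [10, 10]  global (logical) shape
print(w._local_shape)  # [5, 10]   this rank's shard
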
81 changes: 61 additions & 20 deletions test/auto_parallel/semi_auto_parallel_lazy_init.py
@@ -22,40 +22,81 @@
 class TestSemiAutoParallelLazyInit:
     def __init__(self):
         self._backend = os.getenv("backend")
+        self._placements_type = os.getenv("_placements_type")
         self._seed = eval(os.getenv("seed"))
-        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+        if self._placements_type == "DP":
+            self._mesh_weight = dist.ProcessMesh([0, 1], dim_names=["x"])
+            self._mesh_bias = dist.ProcessMesh([0, 1], dim_names=["x"])
+            self._placements_weight = [dist.Replicate()]
+            self._placements_bias = [dist.Replicate()]
+        elif self._placements_type == "PP":
+            self._mesh_weight = dist.ProcessMesh([0], dim_names=["x"])
+            self._mesh_bias = dist.ProcessMesh([1], dim_names=["x"])
+            self._placements_weight = [dist.Replicate()]
+            self._placements_bias = [dist.Replicate()]
 
-    def test_replicate(self):
+    def test_different_xavier(self):
+        paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
+        weight_attr = paddle.framework.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal()
+        )
+        bias_attr = paddle.framework.ParamAttr(
+            initializer=paddle.nn.initializer.XavierUniform()
+        )
+        with LazyGuard():
+            linear = paddle.nn.Linear(
+                10, 10, weight_attr=weight_attr, bias_attr=bias_attr
+            )
+            linear.weight = dist.shard_tensor(
+                linear.weight, self._mesh_weight, self._placements_weight
+            )
+            linear.bias = dist.shard_tensor(
+                linear.bias, self._mesh_bias, self._placements_bias
+            )
+
+    def test_placements(self):
         paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
         with LazyGuard():
             linear = paddle.nn.Linear(10, 10)
             linear.weight = dist.shard_tensor(
-                linear.weight, self._mesh, [dist.Replicate()]
+                linear.weight, self._mesh_weight, self._placements_weight
             )
             linear.bias = dist.shard_tensor(
-                linear.bias, self._mesh, [dist.Replicate()]
+                linear.bias, self._mesh_bias, self._placements_bias
             )
         for param in linear.parameters():
             assert not param._is_initialized()
             param.initialize()
-            assert param._is_initialized()
 
-        local_weight_md5 = linear.weight._local_value()._md5sum()
-        mesh0 = dist.ProcessMesh([0], dim_names=["x"])
-        mesh1 = dist.ProcessMesh([1], dim_names=["x"])
-        tmp = paddle.distributed.auto_parallel.api.dtensor_from_local(
-            linear.weight._local_value(),
-            mesh0 if dist.get_rank() == 0 else mesh1,
-            [dist.Replicate()],
-        )
-        tmp = dist.reshard(
-            tmp, mesh1 if dist.get_rank() == 0 else mesh0, [dist.Replicate()]
-        )
-        tmp_md5 = tmp._local_value()._md5sum()
-        assert local_weight_md5 == tmp_md5
-
+        if self._placements_type == "DP":
+            assert linear.weight._is_initialized()
+            assert linear.bias._is_initialized()
+            local_weight_md5 = linear.weight._local_value()._md5sum()
+            mesh0 = dist.ProcessMesh([0], dim_names=["x"])
+            mesh1 = dist.ProcessMesh([1], dim_names=["x"])
+            tmp = paddle.distributed.auto_parallel.api.dtensor_from_local(
+                linear.weight._local_value(),
+                mesh0 if dist.get_rank() == 0 else mesh1,
+                [dist.Replicate()],
+            )
+            tmp = dist.reshard(
+                tmp,
+                mesh1 if dist.get_rank() == 0 else mesh0,
+                [dist.Replicate()],
+            )
+            tmp_md5 = tmp._local_value()._md5sum()
+            assert local_weight_md5 == tmp_md5
+        elif self._placements_type == "PP":
+            if dist.get_rank() == 0:
+                assert linear.weight._is_initialized()
+                assert not linear.bias._is_initialized()
+            else:
+                assert not linear.weight._is_initialized()
+                assert linear.bias._is_initialized()
+
     def run_test_case(self):
-        self.test_replicate()
+        self.test_placements()
+        self.test_different_xavier()
 
 
 if __name__ == '__main__':
5 changes: 4 additions & 1 deletion test/auto_parallel/test_semi_auto_parallel_lazy_init.py
@@ -27,7 +27,10 @@ def setUp(self):
             "dtype": "float32",
             "seed": "2023",
         }
-        self._changeable_envs = {"backend": ["cpu", "gpu"]}
+        self._changeable_envs = {
+            "backend": ["cpu", "gpu"],
+            "_placements_type": ["DP", "PP"],
+        }
 
     def test_lazy_init(self):
         envs_list = test_base.gen_product_envs_list(