Commit 20b9be6

[Tensor Parallelism] split fix bug (#33015)
1 parent a2a45d8 commit 20b9be6

File tree: 6 files changed, 20 additions and 5 deletions

python/paddle/distributed/collective.py

Lines changed: 11 additions & 0 deletions
@@ -977,6 +977,11 @@ def _parallel_linear(x,
                      group=None):
     """
     Parallel Linear
+
+    axis: the dimension of the parameter of the linear layer to split.
+    axis = 0: the row dimension
+    axis = 1: the column dimension
+
     """
     if group is not None and not group.is_member():
         return
@@ -1008,6 +1013,12 @@ def _parallel_linear(x,
     main_block = paddle.static.default_main_program().global_block()
     startup_block.vars[linear.weight.name].is_distributed = True
     main_block.vars[linear.weight.name].is_distributed = True
+    # set is_distributed for the split bias:
+    # if a linear layer is split by row, each rank holds the complete bias, which must be identical across ranks
+    # if a linear layer is split by column, the bias is split across ranks along with its weight
+    if axis == 1 and linear._bias_attr != False:
+        startup_block.vars[linear.bias.name].is_distributed = True
+        main_block.vars[linear.bias.name].is_distributed = True

     if not gather_out: return linear_out
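
The comment above is the heart of the fix and is easy to check with shapes. Below is a minimal, shape-only sketch in plain NumPy (sizes are made up; this is not the Paddle implementation) of how weight and bias partition under each axis:

import numpy as np

in_features, out_features, num_partitions = 8, 6, 2
weight = np.random.rand(in_features, out_features)
bias = np.random.rand(out_features)

# axis = 0 (row parallel): the weight is split along the input dimension.
# Every rank still produces the full output width, so each rank keeps an
# identical, complete copy of the bias; no is_distributed flag is needed.
row_shards = np.split(weight, num_partitions, axis=0)  # two (4, 6) shards
row_bias = bias                                        # (6,) on every rank

# axis = 1 (column parallel): the weight is split along the output dimension.
# Each rank produces a distinct slice of the output, so the bias is split the
# same way and must be marked is_distributed, as the hunk above now does.
col_shards = np.split(weight, num_partitions, axis=1)  # two (8, 3) shards
col_bias = np.split(bias, num_partitions)              # two (3,) shards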

python/paddle/distributed/fleet/base/distributed_strategy.py

(file mode changed from 100755 to 100644)
Lines changed: 1 addition & 1 deletion
@@ -814,7 +814,7 @@ def sharding_configs(self):
                 "sharding_segment_strategy": "segment_broadcast_MB",
                 "segment_broadcast_MB": 32,
                 "sharding_degree": 8,
-                "sharding_degree": 2,
+                "dp_degree": 2,
                 "gradient_merge_acc_step": 4,
             }
         """

python/paddle/fluid/contrib/mixed_precision/fp16_lists.py

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ def _update_list(self):
     'sign',
     'cast',
     'fused_bn_add_activation',
+    'c_identity',
 }

 # The set of ops that don't support fp16 calculation
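
'c_identity' is the collective identity op that the tensor-parallel split inserts, so AMP now treats it like the neighboring ops instead of leaving it out of the mixed-precision lists. For reference, a hedged sketch of how these lists can be extended from user code; AutoMixedPrecisionLists lives in the same contrib package, but treat the exact import path and keyword as assumptions:

from paddle.fluid.contrib.mixed_precision import AutoMixedPrecisionLists

# The built-in lists edited here can be augmented per run, e.g. forcing an
# additional op to run in fp16 without patching fp16_lists.py itself.
amp_lists = AutoMixedPrecisionLists(custom_white_list={'c_identity'})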

python/paddle/fluid/tests/unittests/column_parallel_linear_api.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def get_model(self, main_prog, startup_program, rank):
            axis=1,
            num_partitions=2,
            weight_attr=param_attr,
-           bias_attr=False, )
+           bias_attr=True, )

        return [linear_out]

python/paddle/fluid/tests/unittests/row_parallel_linear_api.py

Lines changed: 2 additions & 2 deletions
@@ -65,12 +65,12 @@ def get_model(self, main_prog, startup_program, rank):

        linear_out = paddle.distributed.split(
            data,
-           size=(1000, 8),
+           size=(1000, 16),
            operation='linear',
            axis=0,
            num_partitions=2,
            weight_attr=param_attr,
-           bias_attr=False, )
+           bias_attr=True, )

        return [linear_out]
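
Both test updates flip bias_attr from False to True so the new bias handling in collective.py is actually exercised: the column-parallel case (axis=1) creates a split bias that must carry is_distributed, while the row-parallel case keeps a complete bias on every rank. A sketch of the two calls side by side (two launched ranks and fleet initialization are assumed; sizes mirror the tests):

import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet

# assumed: started with two ranks, e.g. via paddle.distributed.launch
fleet.init(is_collective=True)
data = paddle.randn([4, 1000])

# column parallel: weight (1000, 16) -> one (1000, 8) shard per rank,
# bias (16,) -> one (8,) shard per rank, now flagged is_distributed
col_out = dist.split(data, size=(1000, 16), operation='linear',
                     axis=1, num_partitions=2, bias_attr=True)

# row parallel: weight (1000, 16) -> one (500, 16) shard per rank,
# bias (16,) stays complete and identical on both ranks
row_out = dist.split(data, size=(1000, 16), operation='linear',
                     axis=0, num_partitions=2, bias_attr=True)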

python/paddle/fluid/tests/unittests/test_collective_api_base.py

Lines changed: 4 additions & 1 deletion
@@ -154,7 +154,10 @@ def _run_cluster(self, model_file, envs):
         #update environment
         env0.update(envs)
         env1.update(envs)
-        tr_cmd = "%s %s"
+        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
+            tr_cmd = "%s -m coverage run --branch -p %s"
+        else:
+            tr_cmd = "%s %s"
         tr0_cmd = tr_cmd % (self._python_interp, model_file)
         tr1_cmd = tr_cmd % (self._python_interp, model_file)
         tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w")
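
The WITH_COVERAGE branch wraps each trainer process in coverage.py; --branch enables branch coverage, and -p (parallel mode) gives every process its own .coverage.* data file so the two trainers do not overwrite each other's results. The resulting commands, with illustrative values:

python_interp, model_file = "python", "row_parallel_linear_api.py"

# WITH_COVERAGE=OFF (default): plain interpreter invocation
print("%s %s" % (python_interp, model_file))
# -> python row_parallel_linear_api.py

# WITH_COVERAGE=ON: trainers run under coverage.py
print("%s -m coverage run --branch -p %s" % (python_interp, model_file))
# -> python -m coverage run --branch -p row_parallel_linear_api.py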
