
Commit e6aacd1

add trainer desc config to distributed strategy (#34457)
* add trainer desc config to distributed strategy
* code style modified
1 parent 41c4f72 commit e6aacd1

File tree

5 files changed: +136 -0 lines changed


paddle/fluid/framework/distributed_strategy.proto

Lines changed: 8 additions & 0 deletions
@@ -146,6 +146,13 @@ message AsyncConfig {
   optional int32 use_ps_gpu = 12 [ default = 0 ];
 }
 
+message TrainerDescConfig {
+  optional string dump_fields_path = 1;
+  repeated string dump_fields = 2;
+  repeated string dump_param = 3;
+  repeated string stat_var_names = 4;
+}
+
 message PipelineConfig {
   optional int32 micro_batch_size = 1 [ default = 1 ];
   optional int32 accumulate_steps = 2 [ default = 1 ];

@@ -206,6 +213,7 @@ message DistributedStrategy {
   optional ShardingConfig sharding_configs = 111;
   optional HybridConfig hybrid_configs = 112;
   optional TensorParallelConfig tensor_parallel_configs = 113;
+  optional TrainerDescConfig trainer_desc_configs = 114;
   optional BuildStrategy build_strategy = 201;
   optional ExecutionStrategy execution_strategy = 202;
   optional GradientScaleConfig gradient_scale_configs = 203;
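On the Python side, these proto fields are exposed through the trainer_desc_configs property added to DistributedStrategy in the next file, which accepts a plain dict whose keys match the proto field names. A minimal usage sketch; the variable, parameter, and stat-variable names below are placeholders, not values from this commit:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# keys must match the TrainerDescConfig proto fields; unknown keys are
# rejected by check_configs_key in distributed_strategy.py
strategy.trainer_desc_configs = {
    "dump_fields_path": "./dump_data",      # directory where dumped data is written
    "dump_fields": ["fc_0.tmp_0"],          # placeholder variable name
    "dump_param": ["fc_0.w_0"],             # placeholder parameter name
    "stat_var_names": ["total_ins_num"],    # placeholder stat variable name
}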

python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 39 additions & 0 deletions
@@ -360,6 +360,45 @@ def a_sync_configs(self, configs):
                           "a_sync_configs")
         assign_configs_value(self.strategy.a_sync_configs, configs)
 
+    @property
+    def trainer_desc_configs(self):
+        """
+        Set trainer desc configurations.
+
+        **Notes**:
+            dump_fields_path(str): the path where dumped fields are written
+
+            dump_fields(list(str)): the fields that you want to dump
+
+            dump_param(list(str)): the parameters that you want to dump
+
+            stat_var_names(list(str)): the names of stat variables
+
+        Examples:
+
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+            role_maker = fleet.PaddleCloudRoleMaker()
+            fleet.init(role_maker)
+
+            strategy = fleet.DistributedStrategy()
+            configs = {"dump_fields_path": "./dump_data", "dump_fields": ["xxx", "yyy"]}
+            strategy.trainer_desc_configs = configs
+
+            # code block for defining loss and local optimizer
+            # sgd = fleet.distributed_optimizer(optimizer, strategy)
+
+        """
+        return get_msg_dict(self.strategy.trainer_desc_configs)
+
+    @trainer_desc_configs.setter
+    @is_strict_auto
+    def trainer_desc_configs(self, configs):
+        check_configs_key(self.strategy.trainer_desc_configs, configs,
+                          "trainer_desc_configs")
+        assign_configs_value(self.strategy.trainer_desc_configs, configs)
+
     @property
     def amp(self):
         """

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 8 additions & 0 deletions
@@ -1476,6 +1476,14 @@ def minimize(self,
         context["graph_optimize_ops"] = optimize_ops
         context["graph_optimize_grads"] = params_grads
 
+        program = paddle.static.default_main_program()
+        opt_info = {}
+        opt_info["mpi_size"] = self.worker_num()
+        opt_info["mpi_rank"] = self.worker_index()
+        for k, v in self._user_defined_strategy.trainer_desc_configs.items():
+            opt_info[k] = v
+        program._fleet_opt = opt_info
+
         if self._runtime_handle is None:
             self._runtime_handle = RuntimeFactory()._create_runtime(context)
 
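With this change, a minimize() call made through fleet.distributed_optimizer() also records the trainer desc options, together with mpi_size and mpi_rank, on the main program's _fleet_opt attribute. A short sketch of reading those values back, assuming optimizer, avg_cost, and strategy were set up as in the docstring example above:

optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)

opt_info = paddle.static.default_main_program()._fleet_opt
print(opt_info["mpi_size"], opt_info["mpi_rank"])             # worker_num() / worker_index()
print(opt_info["dump_fields_path"], opt_info["dump_fields"])  # values from trainer_desc_configs

The new unit test below asserts the same behavior.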

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import unittest
+
+import paddle
+import paddle.distributed.fleet.base.role_maker as role_maker
+import paddle.fluid.transpiler.details.program_utils as pu
+
+paddle.enable_static()
+
+
+class TestDistStrategyTrainerDescConfig(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_PSERVER_NUMS"] = "2"
+        os.environ["PADDLE_TRAINERS_NUM"] = "2"
+        os.environ["POD_IP"] = "127.0.0.1"
+        os.environ["PADDLE_PORT"] = "36001"
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
+            "127.0.0.1:36001,127.0.0.2:36001"
+
+    def test_trainer_desc_config(self):
+        os.environ["TRAINING_ROLE"] = "TRAINER"
+        import paddle.distributed.fleet as fleet
+
+        fleet.init(role_maker.PaddleCloudRoleMaker())
+
+        x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32')
+        y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32')
+        cost = paddle.fluid.layers.square_error_cost(input=x, label=y)
+        avg_cost = paddle.fluid.layers.mean(cost)
+
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        config = {
+            "dump_fields_path": "dump_data",
+            "dump_fields": ["xxx", "yyy"],
+            "dump_param": []
+        }
+        strategy.trainer_desc_configs = config
+
+        optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+        program = paddle.static.default_main_program()
+        self.assertEqual(program._fleet_opt["dump_fields_path"], "dump_data")
+        self.assertEqual(len(program._fleet_opt["dump_fields"]), 2)
+        self.assertEqual(len(program._fleet_opt["dump_param"]), 0)
+        self.assertEqual(program._fleet_opt["mpi_size"],
+                         int(os.environ["PADDLE_TRAINERS_NUM"]))
+
+
+if __name__ == "__main__":
+    unittest.main()

python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py

Lines changed: 13 additions & 0 deletions
@@ -255,6 +255,19 @@ def test_a_sync_configs(self):
         strategy.a_sync_configs = configs
         self.assertEqual(strategy.a_sync_configs["k_steps"], 1000)
 
+    def test_trainer_desc_configs(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        configs = {
+            "dump_fields_path": "dump_data",
+            "dump_fields": ["xxx", "yyy"],
+            "dump_param": []
+        }
+        strategy.trainer_desc_configs = configs
+        self.assertEqual(strategy.trainer_desc_configs["dump_fields_path"],
+                         "dump_data")
+        self.assertEqual(len(strategy.trainer_desc_configs["dump_fields"]), 2)
+        self.assertEqual(len(strategy.trainer_desc_configs["dump_param"]), 0)
+
     def test_elastic(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.elastic = True
