Skip to content

Commit

Permalink
fix little error
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyoujiyi committed Aug 25, 2022
1 parent d23b74f commit 0f2192e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
9 changes: 8 additions & 1 deletion models/recall/ncf/config_fl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

runner:
sync_mode: "geo" # 可选, string: sync/async/geo
with_coodinator: 1
#with_coodinator: 1
geo_step: 100 # 可选, int, 在geo模式下控制本地的迭代次数
split_file_list: True # 可选, bool, 若每个节点上都拥有全量数据,则需设置为True
thread_num: 1 # 多线程配置
Expand All @@ -39,6 +39,13 @@ runner:
infer_load_path: "output_model_ncf"
infer_start_epoch: 2
infer_end_epoch: 3

need_dump: True
dump_fields_path: "/home/wangbin/the_one_ps/ziyoujiyi_PaddleRec/PaddleRec/models/recall/ncf"
dump_fields: ['item_input', 'user_input']
dump_param: []
local_sparse: ['embedding_0.w_0']
remote_sparse: ['embedding_1.w_0']

hyper_parameters:
optimizer:
Expand Down
1 change: 1 addition & 0 deletions models/recall/ncf/fl_ps_help.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* 在 PaddleRec/datasets/movielens_pinterest_NCF/fl_data 中新建目录 fl_test_data 和 fl_train_data,用于存放每个 client 上的训练数据集和测试数据集
* 在 PaddleRec/datasets/movielens_pinterest_NCF/fl_data 目录中执行: python gen_heter_data.py,生成 10 份训练数据
* 总样本数 4970844(按 1:4 补充负样本):0 - 518095,1 - 520165,2 - 373605,3 - 315550,4 - 483779,5 - 495635,6 - 402810,7 - 354590,8 - 262710,9 - 1243905
* 样本数据每一行表示:物品 id,用户 id,标签

# 3、运行命令
1. 不带 coordinator 版本
Expand Down
2 changes: 1 addition & 1 deletion tools/static_fl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def init_network(self):
self.model = get_model(self.config)
self.input_data = self.model.create_feeds()
self.metrics = self.model.net(self.input_data)
self.model.create_optimizer(get_strategy(self.config))
self.model.create_optimizer(get_strategy(self.config)) ## get_strategy
if self.pure_bf16:
self.model.optimizer.amp_init(self.place)

Expand Down
4 changes: 3 additions & 1 deletion tools/utils/static_ps/program_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def get_strategy(config):
"dump_fields_path": config.get("runner.dump_fields_path", ""),
"dump_fields": config.get("runner.dump_fields", []),
"dump_param": config.get("runner.dump_param", []),
"stat_var_names": config.get("stat_var_names", [])
"stat_var_names": config.get("stat_var_names", []),
"local_sparse": config.get("runner.local_sparse", []),
"remote_sparse": config.get("runner.remote_sparse", [])
}
print("strategy:", strategy.trainer_desc_configs)

Expand Down

0 comments on commit 0f2192e

Please sign in to comment.