forked from PaddlePaddle/PaddleRec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
712 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# PLE | ||
|
||
以下是本例的简要目录结构及说明: | ||
|
||
``` | ||
├── data # 文档 | ||
├── train #训练数据 | ||
├── train_data.txt | ||
├── test #测试数据 | ||
├── test_data.txt | ||
├── run.sh | ||
├── data_preparation.py | ||
├── __init__.py | ||
├── config.yaml #配置文件 | ||
├── census_reader.py #数据读取文件 | ||
├── model.py #模型文件 | ||
``` | ||
|
||
注:在阅读该示例前,建议您先了解以下内容: | ||
|
||
[paddlerec入门教程](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md) | ||
|
||
## 内容 | ||
|
||
- [模型简介](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/ple#模型简介) | ||
- [数据准备](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/ple#数据准备) | ||
- [运行环境](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/ple#运行环境) | ||
- [快速开始](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/ple#快速开始) | ||
- [论文复现](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/ple#论文复现) | ||
- [进阶使用](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/ple#进阶使用) | ||
- [FAQ](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/ple#FAQ) | ||
|
||
## 模型简介 | ||
|
||
多任务模型通过学习不同任务的联系和差异,可提高每个任务的学习效率和质量。但在多任务场景中经常出现跷跷板现象,即有些任务表现良好,有些任务表现变差。 论文[《Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations》](https://dl.acm.org/doi/abs/10.1145/3383313.3412236 ) ,论文提出了Progressive Layered Extraction (简称PLE),来解决多任务学习的跷跷板现象。 | ||
|
||
我们在Paddlepaddle定义PLE的网络结构,在开源数据集Census-income Data上验证模型效果。 | ||
|
||
若进行精度验证,请参考[论文复现](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/multitask/ple#论文复现)部分。 | ||
|
||
本项目支持功能 | ||
|
||
训练:单机CPU、单机单卡GPU、单机多卡GPU、本地模拟参数服务器训练、增量训练,配置请参考 [启动训练](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/train.md) | ||
预测:单机CPU、单机单卡GPU ;配置请参考[PaddleRec 离线预测](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/predict.md) | ||
|
||
## 数据准备 | ||
|
||
数据地址: [Census-income Data](https://archive.ics.uci.edu/ml/machine-learning-databases/census-income-mld/census.tar.gz ) | ||
|
||
|
||
生成的格式以逗号为分割点 | ||
|
||
``` | ||
0,0,73,0,0,0,0,1700.09,0,0 | ||
``` | ||
|
||
完整的大数据参考论文复现部分。 | ||
|
||
## 运行环境 | ||
|
||
PaddlePaddle>=1.7.2 | ||
|
||
python 2.7/3.5/3.6/3.7 | ||
|
||
PaddleRec >=0.1 | ||
|
||
os : windows/linux/macos | ||
|
||
## 快速开始 | ||
|
||
### 单机训练 | ||
|
||
CPU环境 | ||
|
||
在config.yaml文件中设置好设备,epochs等。 | ||
|
||
``` | ||
dataset: | ||
- name: dataset_train | ||
batch_size: 5 | ||
type: QueueDataset | ||
data_path: "{workspace}/data/train" | ||
data_converter: "{workspace}/census_reader.py" | ||
- name: dataset_infer | ||
batch_size: 5 | ||
type: QueueDataset | ||
data_path: "{workspace}/data/train" | ||
data_converter: "{workspace}/census_reader.py" | ||
``` | ||
|
||
### 单机预测 | ||
|
||
CPU环境 | ||
|
||
在config.yaml文件中设置好epochs、device等参数。 | ||
``` | ||
- name: infer_runner | ||
class: infer | ||
init_model_path: "increment/0" | ||
device: cpu | ||
``` | ||
|
||
## 论文复现 | ||
|
||
## 进阶使用 | ||
|
||
## FAQ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
from paddlerec.core.reader import ReaderBase | ||
|
||
|
||
class Reader(ReaderBase): | ||
def init(self): | ||
pass | ||
|
||
def generate_sample(self, line): | ||
""" | ||
Read the data line by line and process it as a dictionary | ||
""" | ||
|
||
def reader(): | ||
""" | ||
This function needs to be implemented by the user, based on data format | ||
""" | ||
l = line.strip().split(',') | ||
l = list(map(float, l)) | ||
label_income = [] | ||
label_marital = [] | ||
data = l[2:] | ||
if int(l[1]) == 0: | ||
label_income = [1, 0] | ||
elif int(l[1]) == 1: | ||
label_income = [0, 1] | ||
if int(l[0]) == 0: | ||
label_marital = [1, 0] | ||
elif int(l[0]) == 1: | ||
label_marital = [0, 1] | ||
# label_income = np.array(label_income) | ||
# label_marital = np.array(label_marital) | ||
feature_name = ["input", "label_income", "label_marital"] | ||
yield zip(feature_name, [data] + [label_income] + [label_marital]) | ||
|
||
return reader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
workspace: "models/multitask/ple" | ||
|
||
dataset: | ||
- name: dataset_train | ||
batch_size: 5 # or big data set 32 | ||
type: DataLoader # or QueueDataset | ||
data_path: "{workspace}/data/train" | ||
data_converter: "{workspace}/census_reader.py" | ||
- name: dataset_infer | ||
batch_size: 5 # or big data set 32 | ||
type: DataLoader # or QueueDataset | ||
data_path: "{workspace}/data/train" | ||
data_converter: "{workspace}/census_reader.py" | ||
|
||
hyper_parameters: | ||
feature_size: 499 | ||
task_num: 2 | ||
shared_num: 2 | ||
exp_per_task: 3 | ||
level_number: 1 | ||
expert_size: 16 | ||
tower_size: 8 | ||
optimizer: | ||
class: adam | ||
learning_rate: 0.001 | ||
strategy: async | ||
|
||
mode: [train_runner, infer_runner] | ||
|
||
runner: | ||
- name: train_runner | ||
class: train | ||
device: cpu # or gpu | ||
selected_gpus: "0" | ||
epochs: 10 | ||
save_checkpoint_interval: 1 | ||
save_inference_interval: 4 | ||
save_checkpoint_path: "increment_ple" | ||
save_inference_path: "inference" | ||
print_interval: 1 # big data set 10 | ||
phases: [train] | ||
- name: infer_runner | ||
class: infer | ||
init_model_path: "increment_ple/1" | ||
device: cpu # or gpu | ||
phases: [infer] | ||
|
||
phase: | ||
- name: train | ||
model: "{workspace}/model.py" | ||
dataset_name: dataset_train | ||
thread_num: 1 | ||
- name: infer | ||
model: "{workspace}/model.py" | ||
dataset_name: dataset_infer | ||
thread_num: 1 |
Oops, something went wrong.