Skip to content

Commit

Permalink
update reademe
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhen38 committed Jul 14, 2022
1 parent 46bdc5d commit 5fa156a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions models/multitask/metaheac/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ os : windows/linux/macos
# 动态图训练
python -u ../../../tools/trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml
# 动态图预测
python -u ./infer_meta.py -m config.yaml
python -u ./infer.py -m config.yaml
```

## 模型组网
Expand Down Expand Up @@ -87,18 +87,18 @@ cd ../../models/multitask/metaheac/ # 切回模型目录
python -u ../../../tools/trainer.py -m config_big.yaml
# 动态图预测
# step2: infer 此时test数据集为hot
python -u ./infer_meta.py -m config_big.yaml
python -u ./infer.py -m config_big.yaml
# step3:修改config_big.yaml文件中test_data_dir的路径为cold
# python -u ./infer_meta.py -m config.yaml
# python -u ./infer.py -m config.yaml
```

## infer说明
### 数据集说明
为了测试模型在不同规模的内容定向推广任务上的表现,将数据集根据内容定向推广任务给定的候选集大小进行了划分,分为大于T和小于T两部分。将腾讯广告大赛2018的Look-alike数据集中的T设置为4000,其中hot数据集中候选集大于T,cold数据集中候选集小于T.
### infer_meta.py说明
infer_meta.py是用于元学习模型infer的tool,在使用中主要有以下几点需要注意:
### infer.py说明
infer.py是用于元学习模型infer的tool,在使用中主要有以下几点需要注意:
1. 在对模型进行infer时(train时也可使用这样的操作),可以将runner.infer_batch_size注释掉,这样将禁用DataLoader的自动组batch功能,进而可以使用自定义的组batch方式.
2. 由于元学习在infer时需要先对特定任务的少量数据集进行训练,因此在infer_meta.py的infer_dataloader中每次接收单个子任务的全量infer数据集(包括训练数据和测试数据).
2. 由于元学习在infer时需要先对特定任务的少量数据集进行训练,因此在infer.py的infer_dataloader中每次接收单个子任务的全量infer数据集(包括训练数据和测试数据).
3. 实际组batch在infer.py中进行,在获取到单个子任务的数据后,获取config中的batch_size参数,对训练数据和测试数据进行组batch,并分别调用dygraph_model.py中的infer_train_forward和infer_forward进行训练和测试.
4. 和普通infer不同,由于需要对单个子任务进行少量数据的train和test,对于每个子任务来说加载的都是train阶段训练好的泛化模型.
5. 在对单个子任务infer时,创建了局部的paddle.metric.Auc("ROC"),可以查看每个子任务的AUC指标,在全局metric中维护包含所有子任务的AUC指标.
Expand Down
2 changes: 1 addition & 1 deletion tools/static_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main(args):
reader_type = config.get("runner.reader_type", "DataLoader")
use_fleet = config.get("runner.use_fleet", False)
seed = config.get("runner.seed", 12345)
paddle.seed(12345)
paddle.seed(seed)
use_save_data = config.get("runner.use_save_data", False)
os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1))
logger.info("**************common.configs**********")
Expand Down

0 comments on commit 5fa156a

Please sign in to comment.