diff --git a/models/multitask/metaheac/readme.md b/models/multitask/metaheac/readme.md index e433c95db..ad5602a5e 100644 --- a/models/multitask/metaheac/readme.md +++ b/models/multitask/metaheac/readme.md @@ -55,7 +55,7 @@ os : windows/linux/macos # 动态图训练 python -u ../../../tools/trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml # 动态图预测 -python -u ./infer_meta.py -m config.yaml +python -u ./infer.py -m config.yaml ``` ## 模型组网 @@ -87,18 +87,18 @@ cd ../../models/multitask/metaheac/ # 切回模型目录 python -u ../../../tools/trainer.py -m config_big.yaml # 动态图预测 # step2: infer 此时test数据集为hot -python -u ./infer_meta.py -m config_big.yaml +python -u ./infer.py -m config_big.yaml # step3:修改config_big.yaml文件中test_data_dir的路径为cold -# python -u ./infer_meta.py -m config.yaml +# python -u ./infer.py -m config.yaml ``` ## infer说明 ### 数据集说明 为了测试模型在不同规模的内容定向推广任务上的表现,将数据集根据内容定向推广任务给定的候选集大小进行了划分,分为大于T和小于T两部分。将腾讯广告大赛2018的Look-alike数据集中的T设置为4000,其中hot数据集中候选集大于T,cold数据集中候选集小于T. -### infer_meta.py说明 -infer_meta.py是用于元学习模型infer的tool,在使用中主要有以下几点需要注意: +### infer.py说明 +infer.py是用于元学习模型infer的tool,在使用中主要有以下几点需要注意: 1. 在对模型进行infer时(train时也可使用这样的操作),可以将runner.infer_batch_size注释掉,这样将禁用DataLoader的自动组batch功能,进而可以使用自定义的组batch方式. -2. 由于元学习在infer时需要先对特定任务的少量数据集进行训练,因此在infer_meta.py的infer_dataloader中每次接收单个子任务的全量infer数据集(包括训练数据和测试数据). +2. 由于元学习在infer时需要先对特定任务的少量数据集进行训练,因此在infer.py的infer_dataloader中每次接收单个子任务的全量infer数据集(包括训练数据和测试数据). 3. 实际组batch在infer.py中进行,在获取到单个子任务的数据后,获取config中的batch_size参数,对训练数据和测试数据进行组batch,并分别调用dygraph_model.py中的infer_train_forward和infer_forward进行训练和测试. 4. 和普通infer不同,由于需要对单个子任务进行少量数据的train和test,对于每个子任务来说加载的都是train阶段训练好的泛化模型. 5. 在对单个子任务infer时,创建了局部的paddle.metric.Auc("ROC"),可以查看每个子任务的AUC指标,在全局metric中维护包含所有子任务的AUC指标. diff --git a/tools/static_trainer.py b/tools/static_trainer.py index 8492dcd53..af54c29c8 100644 --- a/tools/static_trainer.py +++ b/tools/static_trainer.py @@ -86,7 +86,7 @@ def main(args): reader_type = config.get("runner.reader_type", "DataLoader") use_fleet = config.get("runner.use_fleet", False) seed = config.get("runner.seed", 12345) - paddle.seed(12345) + paddle.seed(seed) use_save_data = config.get("runner.use_save_data", False) os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1)) logger.info("**************common.configs**********")