diff --git a/doc/inference.md b/doc/inference.md index 2c124795a..9162db882 100644 --- a/doc/inference.md +++ b/doc/inference.md @@ -12,8 +12,8 @@ runner: ... # use inference save model use_inference: True # 静态图训练时保存为inference model - save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # inference model 的feed参数的名字 - save_inference_fetch_varnames: ["cast_0.tmp_0"] # inference model 的fetch参数的名字 + save_inference_feed_varnames: ["C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # inference model 的feed参数的名字 + save_inference_fetch_varnames: ["sigmoid_0.tmp_0"] # inference model 的fetch参数的名字 ``` 3. 启动静态图训练 ```bash @@ -54,7 +54,7 @@ python -u ../../../tools/to_static.py -m config.yaml # 在最后输出的list中,去除第一个np.array,即label部分。 yield output_list[1:] ``` -将inference预测得到的prediction预测值和数据集中的label对比,使用另外的脚本计算auc即可得到和静态图相同的auc结果。 +将inference预测得到的prediction预测值和数据集中的label对比,使用另外的脚本计算auc指标即可。 ## 将保存的模型使用Inference预测库进行服务端部署 paddlerec提供tools/paddle_infer.py脚本,供您方便的使用inference预测库高效的对模型进行预测。 diff --git a/models/rank/wide_deep/config.yaml b/models/rank/wide_deep/config.yaml index b87b356be..0bfc989cd 100755 --- a/models/rank/wide_deep/config.yaml +++ b/models/rank/wide_deep/config.yaml @@ -22,7 +22,7 @@ runner: train_batch_size: 2 epochs: 3 print_interval: 2 - #model_init_path: "output_model_wide_deep/2" # init model + # model_init_path: "output_model_wide_deep/2" # init model model_save_path: "output_model_wide_deep" test_data_dir: "data/sample_data/train" infer_reader_path: "criteo_reader" # importlib format @@ -32,8 +32,8 @@ runner: infer_end_epoch: 3 #use inference save model use_inference: False - save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] - save_inference_fetch_varnames: ["cast_0.tmp_0"] + save_inference_feed_varnames: ["C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] + save_inference_fetch_varnames: ["sigmoid_0.tmp_0"] # hyper parameters of user-defined network hyper_parameters: