Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MrChengmo committed Jan 20, 2021
2 parents 62b463c + f25c8a7 commit 760fda5
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 10 deletions.
6 changes: 3 additions & 3 deletions models/rank/dnn/benchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ runner:
use_gpu: 0

model_path: "static_model.py"
#reader_type: "QueueDataset" # DataLoader / QueueDataset / RecDataset
reader_type: "QueueDataset" # DataLoader / QueueDataset / RecDataset
pipe_command: "python benchmark_reader.py"
dataset_debug: False
split_file_list: False

train_batch_size: 1000
train_data_dir: "data/sample_data/train"
train_reader_path: "reader"
train_data_dir: "train_data"
train_reader_path: "benchmark_reader"
model_save_path: "model"

infer_batch_size: 2
Expand Down
81 changes: 81 additions & 0 deletions models/rank/dnn/criteo_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np

from paddle.io import IterableDataset


class RecDataset(IterableDataset):
def __init__(self, file_list, config):
super(RecDataset, self).__init__()
self.file_list = file_list
self.init()

def init(self):
from operator import mul
padding = 0
sparse_slots = "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
self.sparse_slots = sparse_slots.strip().split(" ")
self.dense_slots = ["dense_feature"]
self.dense_slots_shape = [13]
self.slots = self.sparse_slots + self.dense_slots
self.slot2index = {}
self.visit = {}
for i in range(len(self.slots)):
self.slot2index[self.slots[i]] = i
self.visit[self.slots[i]] = False
self.padding = padding

def __iter__(self):
full_lines = []
self.data = []
for file in self.file_list:
with open(file, "r") as rf:
for l in rf:
line = l.strip().split(" ")
output = [(i, []) for i in self.slots]
for i in line:
slot_feasign = i.split(":")
slot = slot_feasign[0]
if slot not in self.slots:
continue
if slot in self.sparse_slots:
feasign = int(slot_feasign[1])
else:
feasign = float(slot_feasign[1])
output[self.slot2index[slot]][1].append(feasign)
self.visit[slot] = True
for i in self.visit:
slot = i
if not self.visit[slot]:
if i in self.dense_slots:
output[self.slot2index[i]][1].extend(
[self.padding] *
self.dense_slots_shape[self.slot2index[i]])
else:
output[self.slot2index[i]][1].extend(
[self.padding])
else:
self.visit[slot] = False
# sparse
output_list = []
for key, value in output[:-1]:
output_list.append(np.array(value))
# dense
output_list.append(
np.array(output[-1][1]).astype("float32"))
# list
yield output_list
13 changes: 7 additions & 6 deletions models/rank/wide_deep/static_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,18 @@ def net(self, input, is_infer=False):

self.predict = predict_2d

auc, batch_auc, _ = paddle.static.layers.auc(input=self.predict,
label=self.label_input,
num_thresholds=2**12,
slide_steps=20)
auc, batch_auc, _ = paddle.static.auc(input=self.predict,
label=self.label_input,
num_thresholds=2**12,
slide_steps=20)
self.inference_target_var = auc
if is_infer:
fetch_dict = {'auc': auc}
return fetch_dict

cost = paddle.nn.functional.cross_entropy(
input=predict_2d, label=self.label_input)
cost = paddle.nn.functional.log_loss(
input=pred, label=paddle.cast(
self.label_input, dtype="float32"))
avg_cost = paddle.mean(x=cost)
self._cost = avg_cost

Expand Down
2 changes: 1 addition & 1 deletion tools/utils/utils_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def create_data_loader(config, place, mode="train"):
logger.info("reader path:{}".format(reader_path))
from importlib import import_module
reader_class = import_module(reader_path)
dataset = reader_class.RecDataset(file_list)
dataset = reader_class.RecDataset(file_list, config=config)
loader = DataLoader(
dataset, batch_size=batch_size, places=place, drop_last=True)
return loader
Expand Down

0 comments on commit 760fda5

Please sign in to comment.