Skip to content

Commit

Permalink
cast
Browse files Browse the repository at this point in the history
  • Loading branch information
esythan committed Aug 10, 2022
1 parent ea60e40 commit 533eaf3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 5 additions & 3 deletions models/rank/slot_dnn/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def __init__(self,
self.layer_sizes = layer_sizes
self._init_range = 0.2

self.entry = paddle.distributed.ShowClickEntry("show", "click")

sizes = [emb_dim * slot_num] + self.layer_sizes + [1]
acts = ["relu" for _ in range(len(self.layer_sizes))] + [None]
scales = []
Expand All @@ -54,10 +52,14 @@ def __init__(self,
self.add_sublayer('act_%d' % i, act)
self._mlp_layers.append(act)

def forward(self, slot_inputs):
def forward(self, show, click, slot_inputs):
self.all_vars = []
embs = []
self.inference_feed_vars = []
show_cast = paddle.cast(show, dtype='float32')
click_cast = paddle.cast(click, dtype='float32')
self.entry = paddle.distributed.ShowClickEntry(show_cast.name,
click_cast.name)
for s_input in slot_inputs:
emb = paddle.static.nn.sparse_embedding(
input=s_input,
Expand Down
4 changes: 3 additions & 1 deletion models/rank/slot_dnn/static_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def create_feeds(self, is_infer=False):
return feeds_list

def net(self, input, is_infer=False):
self.show_input = input[0]
self.label_input = input[1]
self.slot_inputs = input[2:]

Expand All @@ -67,7 +68,8 @@ def net(self, input, is_infer=False):
self.layer_sizes,
sync_mode=self.sync_mode)

self.predict = dnn_model.forward(self.slot_inputs)
self.predict = dnn_model.forward(self.show_input, self.label_input,
self.slot_inputs)

# self.all_vars = input + dnn_model.all_vars
self.all_vars = dnn_model.all_vars
Expand Down

0 comments on commit 533eaf3

Please sign in to comment.