Add FGCNN model #784
Changes from 12 commits
@@ -0,0 +1,3 @@
wget --no-check-certificate https://paddlerec.bj.bcebos.com/datasets/fgcnn/datapro.zip
unzip -o datapro.zip
echo "Complete data download."
@@ -0,0 +1,51 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# global settings

runner:
  train_data_dir: "data/trainlite"
  train_reader_path: "reader"  # importlib format
  use_gpu: True
Review comment: With the demo data, GPU should be disabled by default, and the demo training run should finish within one minute.
  use_auc: True
  train_batch_size: 10
  epochs: 1
  print_interval: 10
  # model_init_path: "output_model_all_fgcnn/1"  # init model
  model_save_path: "output_model_sample_fgcnn"
  test_data_dir: "data/testlite"
  infer_reader_path: "reader"  # importlib format
  infer_batch_size: 10
  infer_load_path: "output_model_sample_fgcnn"
  infer_start_epoch: 0
  infer_end_epoch: 1

# hyper parameters of user-defined network
hyper_parameters:
  # optimizer config
  optimizer:
    class: Adam
    learning_rate: 0.001
  sparse_inputs_slots: 26
  sparse_feature_size: 1000000
  feature_name: ['I1','I2','I3','I4','I5','I6','I7','I8','I9','I10','I11','I12','I13','C1','C2','C3','C4','C5','C6','C7','C8','C9','C10','C11','C12','C13','C14','C15','C16','C17','C18','C19','C20','C21','C22','C23','C24','C25','C26']
  dense_inputs_slots: 13
  feature_dim: 20
  conv_kernel_width: [9, 9, 9, 9]
  conv_filters: [38, 40, 42, 44]
  new_maps: [3, 3, 3, 3]
  pooling_width: [2, 2, 2, 2]
  stride: [1, 1]
  dnn_hidden_units: [100, 100, 100]
  dnn_dropout: 0.0
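The conv_filters, conv_kernel_width, pooling_width, and new_maps settings drive FGCNN's feature-generation branch. A rough, illustrative sketch of the first generation layer, under the assumption that pooling runs over the field axis and a recombination layer emits new_maps[0] = 3 new embeddings per pooled field (the actual net.py may organize this differently):

import paddle
import paddle.nn as nn

batch, fields, emb_dim = 8, 26, 20                     # 26 sparse fields, feature_dim 20
x = paddle.randn([batch, 1, fields, emb_dim])          # embedding matrix as a 1-channel image
conv = nn.Conv2D(1, 38, kernel_size=(9, 1), padding=(4, 0))    # conv_filters[0], conv_kernel_width[0]
pool = nn.MaxPool2D(kernel_size=(2, 1))                # pooling_width[0] halves the field axis
h = pool(conv(x))
print(h.shape)                                         # [8, 38, 13, 20]
recomb = nn.Linear(38 * 13 * emb_dim, 3 * 13 * emb_dim)        # new_maps[0] = 3
new_feats = recomb(h.transpose([0, 2, 3, 1]).reshape([batch, -1]))
print(new_feats.reshape([batch, 3 * 13, emb_dim]).shape)       # 39 generated feature embeddings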
@@ -0,0 +1,51 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# global settings

runner:
  train_data_dir: "data/train"
Review comment: Full-data training should read its data from the datasets directory.
  train_reader_path: "reader"  # importlib format
  use_gpu: True
  use_auc: True
  train_batch_size: 2000
  epochs: 2
Review comment: The epoch count for full-data training is inconsistent with what the README describes.
  print_interval: 2000
  # model_init_path: "output_model_all_fgcnn/1"  # init model
  model_save_path: "output_model_all_fgcnn"
  test_data_dir: "data/test"
  infer_reader_path: "reader"  # importlib format
  infer_batch_size: 5000
  infer_load_path: "output_model_all_fgcnn"
  infer_start_epoch: 0
  infer_end_epoch: 2

# hyper parameters of user-defined network
hyper_parameters:
  # optimizer config
  optimizer:
    class: Adam
    learning_rate: 0.001
  sparse_inputs_slots: 26
  sparse_feature_size: 1000000
  feature_name: ['I1','I2','I3','I4','I5','I6','I7','I8','I9','I10','I11','I12','I13','C1','C2','C3','C4','C5','C6','C7','C8','C9','C10','C11','C12','C13','C14','C15','C16','C17','C18','C19','C20','C21','C22','C23','C24','C25','C26']
  dense_inputs_slots: 13
  feature_dim: 20
  conv_kernel_width: [9, 9, 9, 9]
  conv_filters: [38, 40, 42, 44]
  new_maps: [3, 3, 3, 3]
  pooling_width: [2, 2, 2, 2]
  stride: [1, 1]
  dnn_hidden_units: [1000, 1000, 1000]
  dnn_dropout: 0.0
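The dygraph model below reads these settings through dotted keys such as hyper_parameters.feature_dim. A minimal, illustrative lookup over the parsed YAML (PaddleRec ships its own config utilities, so this helper only approximates that behavior and is not the project's API):

import yaml

def get_by_dotted_key(cfg, dotted_key, default=None):
    """Walk a nested dict with an 'a.b.c' style key, returning default on a miss."""
    node = cfg
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

print(get_by_dotted_key(cfg, "hyper_parameters.feature_dim", 20))   # 20
print(get_by_dotted_key(cfg, "runner.train_batch_size"))            # 2000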
@@ -0,0 +1,102 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import net
import numpy as np


class DygraphModel():
    # define model
    def create_model(self, config):
        sparse_input_slot = config.get('hyper_parameters.sparse_inputs_slots')
        dense_input_slot = config.get('hyper_parameters.dense_inputs_slots')
        sparse_feature_size = config.get(
            "hyper_parameters.sparse_feature_size")
        feature_name = config.get("hyper_parameters.feature_name")
        feature_dim = config.get("hyper_parameters.feature_dim", 20)
        conv_kernel_width = config.get("hyper_parameters.conv_kernel_width",
                                       (7, 7, 7, 7))
        conv_filters = config.get("hyper_parameters.conv_filters",
                                  (14, 16, 18, 20))
        new_maps = config.get("hyper_parameters.new_maps", (3, 3, 3, 3))
        pooling_width = config.get("hyper_parameters.pooling_width",
                                   (2, 2, 2, 2))
        stride = config.get("hyper_parameters.stride", (1, 1))
        dnn_hidden_units = config.get("hyper_parameters.dnn_hidden_units",
                                      (128, ))
        dnn_dropout = config.get("hyper_parameters.dnn_dropout", 0.0)
        fgcnn_model = net.FGCNN(
            sparse_input_slot, sparse_feature_size, feature_name, feature_dim,
            dense_input_slot, conv_kernel_width, conv_filters, new_maps,
            pooling_width, stride, dnn_hidden_units, dnn_dropout)

        return fgcnn_model

    # define feeds which unpack a batch into inputs and label
    def create_feeds(self, batch_data, config):
        inputs = batch_data[0]
        label = batch_data[1]
        return label, inputs

    # define loss function by predicts and label
    def create_loss(self, y_pred, label):
        loss = nn.functional.log_loss(
            y_pred, label=paddle.cast(
                label, dtype="float32"))
        avg_cost = paddle.mean(x=loss)
        return avg_cost

    # define optimizer
    def create_optimizer(self, dy_model, config):
        lr = config.get("hyper_parameters.optimizer.learning_rate", 1e-3)
        optimizer = paddle.optimizer.Adam(
            parameters=dy_model.parameters(), learning_rate=lr)
        return optimizer

    # define metrics: track AUC (ROC) during training and inference
    def create_metrics(self):
        metrics_list_name = ["auc"]
        auc_metric = paddle.metric.Auc("ROC")
        metrics_list = [auc_metric]
        return metrics_list, metrics_list_name

    # construct train forward phase
    def train_forward(self, dy_model, metrics_list, batch_data, config):
        # dense vectors
Review comment: What does this comment mean?
        label, inputs = self.create_feeds(batch_data, config)
        pred = dy_model.forward(inputs)
        loss = self.create_loss(pred, label)
        # update metrics
        predict_2d = paddle.concat(x=[1 - pred, pred], axis=1)
        metrics_list[0].update(preds=predict_2d.numpy(), labels=label.numpy())
        # print_dict format: {'loss': loss}
        print_dict = {'loss': loss}
        return loss, metrics_list, print_dict

    def infer_forward(self, dy_model, metrics_list, batch_data, config):
        label, inputs = self.create_feeds(batch_data, config)
        pred = dy_model.forward(inputs)
        loss = self.create_loss(pred, label)
        # update metrics
        predict_2d = paddle.concat(x=[1 - pred, pred], axis=1)
        metrics_list[0].update(preds=predict_2d.numpy(), labels=label.numpy())
        # print_dict format: {'loss': loss}
        print_dict = {'loss': loss}
        return metrics_list, print_dict
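For reference, a small self-contained sketch of the AUC bookkeeping that train_forward and infer_forward rely on, using dummy predictions (paddle.metric.Auc accumulates over successive update calls):

import numpy as np
import paddle

auc = paddle.metric.Auc("ROC")
pred = paddle.to_tensor(np.array([[0.2], [0.8], [0.6]], dtype="float32"))   # P(y=1)
label = paddle.to_tensor(np.array([[0], [1], [1]], dtype="int64"))
predict_2d = paddle.concat(x=[1 - pred, pred], axis=1)                      # [P(y=0), P(y=1)]
auc.update(preds=predict_2d.numpy(), labels=label.numpy())
print(auc.accumulate())                                                      # running ROC-AUC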
Review comment: Rename to run.sh.