5 changes: 3 additions & 2 deletions paddle/inference/tests/book/CMakeLists.txt
@@ -25,9 +25,10 @@ function(inference_test TARGET_NAME)
endfunction(inference_test)

inference_test(fit_a_line)
inference_test(recognize_digits ARGS mlp)
inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles)
inference_test(rnn_encoder_decoder)
inference_test(recognize_digits ARGS mlp)
inference_test(recommender_system)
inference_test(rnn_encoder_decoder)
inference_test(understand_sentiment)
inference_test(word2vec)
5 changes: 3 additions & 2 deletions paddle/inference/tests/book/test_helper.h
@@ -91,7 +91,7 @@ template <typename Place, bool IsCombined = false>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
// 1. Define place, executor, scope and inference_program
// 1. Define place, executor, scope
auto place = Place();
auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope();
@@ -101,7 +101,8 @@ void TestInference(const std::string& dirname,
if (IsCombined) {
// All parameters are saved in a single file.
// Hard-coding the file names of program and parameters in unittest.
// Users are free to specify different filename.
// Users are free to specify a different filename
// (provided the filenames are changed accordingly in the Python API, io.py).
std::string prog_filename = "__model_combined__";
std::string param_filename = "__params_combined__";
inference_program = paddle::inference::Load(executor,
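The combined-file branch above expects the Python side to have saved the program and parameters under the same hard-coded names. A minimal sketch of the save side, assuming hypothetical model_filename/params_filename arguments on fluid.io.save_inference_model (io.py hard-codes these names at the time of this diff):

import paddle.v2.fluid as fluid

# Sketch only: persist the program and all parameters as a combined pair of
# files whose names match the ones hard-coded in test_helper.h. The
# model_filename/params_filename keyword arguments are assumptions, not the
# io.py API of this PR.
def save_combined_model(save_dirname, feed_names, fetch_vars, exe):
    fluid.io.save_inference_model(
        save_dirname,
        feed_names,   # e.g. ['firstw', 'secondw', 'thirdw', 'forthw']
        fetch_vars,   # e.g. [predict_word]
        exe,
        model_filename="__model_combined__",
        params_filename="__params_combined__")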
68 changes: 68 additions & 0 deletions paddle/inference/tests/book/test_inference_word2vec.cc
@@ -0,0 +1,68 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "gflags/gflags.h"
#include "test_helper.h"

DEFINE_string(dirname, "", "Directory of the inference model.");

TEST(inference, word2vec) {
if (FLAGS_dirname.empty()) {
LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
}

LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
std::string dirname = FLAGS_dirname;

// 0. Call `paddle::framework::InitDevices()` to initialize all the devices.
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc

paddle::framework::LoDTensor first_word, second_word, third_word, fourth_word;
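// LoD {{0, 1}} marks one sequence spanning rows [0, 1): each of the four
// inputs holds a single word id.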
paddle::framework::LoD lod{{0, 1}};
int64_t dict_size = 2072;  // Hard-coded size of the dictionary

SetupLoDTensor(first_word, lod, static_cast<int64_t>(0), dict_size);
SetupLoDTensor(second_word, lod, static_cast<int64_t>(0), dict_size);
SetupLoDTensor(third_word, lod, static_cast<int64_t>(0), dict_size);
SetupLoDTensor(fourth_word, lod, static_cast<int64_t>(0), dict_size);

std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&first_word);
cpu_feeds.push_back(&second_word);
cpu_feeds.push_back(&third_word);
cpu_feeds.push_back(&fourth_word);

paddle::framework::LoDTensor output1;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);

// Run inference on CPU
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.lod();
LOG(INFO) << output1.dims();

#ifdef PADDLE_WITH_CUDA
paddle::framework::LoDTensor output2;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
cpu_fetchs2.push_back(&output2);

// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.lod();
LOG(INFO) << output2.dims();

CheckError<float>(output1, output2);
#endif
}
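The directory passed via --dirname is the one produced by the Python script in the next file, which saves to "word2vec.inference.model". A minimal driver, reusing the train function defined below (the import path is illustrative):

# Sketch: generate the model directory that test_inference_word2vec.cc
# consumes via --dirname, using train() from test_word2vec.py below.
from test_word2vec import train

train(use_cuda=False, is_sparse=False, parallel=False,
      save_dirname="word2vec.inference.model")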
81 changes: 73 additions & 8 deletions python/paddle/v2/fluid/tests/book/test_word2vec.py
@@ -1,6 +1,5 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# # Licensed under the Apache License, Version 2.0 (the "License");
Contributor review comment: "# #" -> "#", and add an empty line above. This can be fixed in the NMT PR.

# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
@@ -16,14 +15,67 @@
import paddle.v2.fluid as fluid
import unittest
import os
import numpy as np
import math
import sys


def main(use_cuda, is_sparse, parallel):
if use_cuda and not fluid.core.is_compiled_with_cuda():
def create_random_lodtensor(lod, place, low, high):
data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
res = fluid.LoDTensor()
res.set(data, place)
res.set_lod([lod])
return res


def infer(use_cuda, save_dirname=None):
if save_dirname is None:
return

place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)

# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be fed
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict) - 1

# Set up the input by creating 4 words, plus the LoD information required
# by lookup_table_op: lod = [0, 1] marks one sequence spanning rows [0, 1).
lod = [0, 1]
first_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
second_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
third_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size)

assert feed_target_names[0] == 'firstw'
assert feed_target_names[1] == 'secondw'
assert feed_target_names[2] == 'thirdw'
assert feed_target_names[3] == 'forthw'

# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results = exe.run(inference_program,
feed={
feed_target_names[0]: first_word,
feed_target_names[1]: second_word,
feed_target_names[2]: third_word,
feed_target_names[3]: fourth_word
},
fetch_list=fetch_targets,
return_numpy=False)
print(results[0].lod())
np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape)
print("Inference results: ", np_data)


def train(use_cuda, is_sparse, parallel, save_dirname):
PASS_NUM = 100
EMBED_SIZE = 32
HIDDEN_SIZE = 256
@@ -67,7 +119,7 @@ def __network__(words):
act='softmax')
cost = fluid.layers.cross_entropy(input=predict_word, label=words[4])
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
return avg_cost, predict_word

word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
@@ -79,13 +131,13 @@ def __network__(words):
next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')

if not parallel:
avg_cost = __network__(
avg_cost, predict_word = __network__(
[first_word, second_word, third_word, forth_word, next_word])
else:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
avg_cost = __network__(
avg_cost, predict_word = __network__(
map(pd.read_input, [
first_word, second_word, third_word, forth_word, next_word
]))
@@ -113,13 +165,25 @@ def __network__(words):
feed=feeder.feed(data),
fetch_list=[avg_cost])
if avg_cost_np[0] < 5.0:
if save_dirname is not None:
fluid.io.save_inference_model(save_dirname, [
'firstw', 'secondw', 'thirdw', 'forthw'
], [predict_word], exe)
return
if math.isnan(float(avg_cost_np[0])):
sys.exit("got NaN loss, training failed.")

raise AssertionError("Cost is too large {0:2.2}".format(avg_cost_np[0]))


def main(use_cuda, is_sparse, parallel):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
save_dirname = "word2vec.inference.model"
train(use_cuda, is_sparse, parallel, save_dirname)
infer(use_cuda, save_dirname)


FULL_TEST = os.getenv('FULL_TEST',
'0').lower() in ['true', '1', 't', 'y', 'yes', 'on']
SKIP_REASON = "Only run minimum number of tests in CI server, to make CI faster"
@@ -142,7 +206,8 @@ def __impl__(*args, **kwargs):
with fluid.program_guard(prog, startup_prog):
main(use_cuda=use_cuda, is_sparse=is_sparse, parallel=parallel)

if use_cuda and is_sparse and parallel:
# run only 2 cases on CI: use_cuda True and False (dense, non-parallel)
if not is_sparse and not parallel:
fn = __impl__
else:
# skip the other test when on CI server
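The loop that feeds __impl__ into unittest is collapsed out of this diff; a rough reconstruction of its shape, with an assumed helper name:

# Assumed structure of the elided test-injection code: one test per
# (use_cuda, is_sparse, parallel) combination; all but the minimal ones are
# skipped on the CI server unless FULL_TEST is set.
import itertools

for use_cuda, is_sparse, parallel in itertools.product([False, True],
                                                       repeat=3):
    inject_test_method(use_cuda, is_sparse, parallel)  # hypothetical name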