
Commit fcd74e0

QiJun authored and reyoung committed
add book04.word2vec train test (#5002)

* init
* ensure ids in lookup table op must be a column vector
* add book4 configuration in test_layers
* debug test_book4
* add test_word2vec
* follow comments
* follow comments
1 parent 40e7caf commit fcd74e0

File tree

8 files changed, +282 −10 lines changed


paddle/framework/var_desc.cc

+4
@@ -18,6 +18,10 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); }
+
+void VarDescBind::SetType(VarDesc::VarType type) { desc_.set_type(type); }
+
 void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
   VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
 }

paddle/framework/var_desc.h

+2-2
@@ -75,9 +75,9 @@ class VarDescBind {
 
   int32_t GetLodLevel() const;
 
-  VarDesc::VarType GetType() const { return desc_.type(); }
+  VarDesc::VarType GetType() const;
 
-  void SetType(VarDesc::VarType type) { desc_.set_type(type); }
+  void SetType(VarDesc::VarType type);
 
   bool Persistable() const { return desc_.persistable(); }
 
paddle/pybind/protobuf.cc

+1
@@ -257,6 +257,7 @@ void BindOpDesc(py::module &m) {
       .def("block_attr", &OpDescBind::GetBlockAttr)
       .def("check_attrs", &OpDescBind::CheckAttrs)
       .def("infer_shape", &OpDescBind::InferShape)
+      .def("infer_var_type", &OpDescBind::InferVarType)
       .def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes {
         const OpDesc *desc = op_desc.Proto();
         PADDLE_ENFORCE(desc->IsInitialized(),

python/paddle/v2/framework/framework.py

+4-3
@@ -53,8 +53,8 @@ def __init__(self,
         if is_new_var:
             self.desc.set_data_type(dtype)
         else:
-            old_dtype = self.data_type()
-            if dtype != old_shape:
+            old_dtype = self.data_type
+            if dtype != old_dtype:
                 raise ValueError("Variable {0} has been created before. "
                                  "The previous data type is {1}; the new "
                                  "data type is {2}. They are not "
@@ -191,7 +191,6 @@ def __init__(self,
                 "`type` to initilized an Operator can not be None.")
         self.desc.set_type(type)
         proto = OpProtoHolder.instance().get_op_proto(type)
-
         if inputs is not None:
             given = set()
             need = set()
@@ -206,6 +205,7 @@ def __init__(self,
                         str(e) for e in given)))
 
             for in_proto in proto.inputs:
+
                 in_argus = inputs[in_proto.name]
                 if not isinstance(in_argus, list):
                     in_argus = [in_argus]
@@ -257,6 +257,7 @@ def __init__(self,
 
         self.desc.check_attrs()
         if type not in {'feed', 'fetch'}:
+            self.desc.infer_var_type(self.block.desc)
             self.desc.infer_shape(self.block.desc)
 
     def __str__(self):

python/paddle/v2/framework/layer_helper.py

+1-4
@@ -120,10 +120,7 @@ def create_parameter(self, attr, shape, dtype, suffix='w'):
         if attr['name'] is None:
             attr['name'] = unique_name(".".join([self.name, suffix]))
         self.init_program.global_block().create_parameter(
-            name=attr['name'],
-            dtype=dtype,
-            shape=shape,
-            init_attr=attr['init_attr'])
+            dtype=dtype, shape=shape, **attr)
         return self.program.global_block().create_parameter(
             name=attr['name'], dtype=dtype, shape=shape)
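Note on the `**attr` change above: the attribute dict (its `name` and, when present, `init_attr` keys) is now forwarded wholesale to the init program's `create_parameter`. A minimal self-contained sketch of what the expansion amounts to, using a stand-in function and an attr dict shaped like the ones in the tests below (not the framework's real implementation):

    # Hypothetical attr dict, mirroring the ones used in the tests below.
    attr = {
        'name': 'shared_w',
        'init_attr': {'type': 'uniform_random', 'min': -1.0, 'max': 1.0}
    }

    def create_parameter(dtype, shape, name=None, init_attr=None):
        # stand-in for init_program.global_block().create_parameter
        return (name, dtype, shape, init_attr)

    # The **attr form and the fully spelled-out form are equivalent:
    assert create_parameter('float32', [10, 32], **attr) == \
        create_parameter('float32', [10, 32],
                         name=attr['name'], init_attr=attr['init_attr'])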

python/paddle/v2/framework/layers.py

+34-1
@@ -3,7 +3,9 @@
 from paddle.v2.framework.framework import OpProtoHolder, Variable
 import re
 
-__all__ = ['fc', 'data', 'cross_entropy', 'conv2d', 'pool2d']
+__all__ = [
+    'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat'
+]
 
 
 def fc(input,
@@ -55,6 +57,24 @@ def fc(input,
     return helper.append_activation(pre_activation)
 
 
+def embedding(input,
+              size,
+              data_type='float32',
+              param_attr=None,
+              program=None,
+              init_program=None):
+    helper = LayerHelper('embedding', **locals())
+    w = helper.create_parameter(
+        attr=helper.param_attr, shape=size, dtype=data_type)
+    tmp = helper.create_tmp_variable(data_type)
+    helper.append_op(
+        type='lookup_table',
+        inputs={'Ids': input,
+                'W': w},
+        outputs={'Out': tmp})
+    return tmp
+
+
 def data(name,
          shape,
          data_type='float32',
@@ -122,6 +142,19 @@ def func(**kwargs):
 _create_op_func_('mul')
 
 
+def concat(input, axis, program=None, init_program=None):
+    helper = LayerHelper('concat', **locals())
+    if not isinstance(input, list) and not isinstance(input, tuple):
+        input = [input]
+    out = helper.create_tmp_variable(dtype=input[0].data_type)
+    helper.append_op(
+        type='concat',
+        inputs={'X': input},
+        outputs={'Out': [out]},
+        attrs={'axis': axis})
+    return out
+
+
 def cross_entropy(input, label, **kwargs):
     helper = LayerHelper('cross_entropy', **kwargs)
     out = helper.create_tmp_variable(dtype=input.data_type)
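For orientation, the two new layers compose as in the tests below; a minimal usage sketch under the same `Program` plumbing that `data` and `fc` already use (the vocabulary and embedding sizes here are illustrative):

    import paddle.v2.framework.layers as layers
    from paddle.v2.framework.framework import Program

    program = Program()
    word = layers.data(name='w', shape=[1], data_type='int32', program=program)
    # look up a 32-dim row of a [vocab, 32] table for each id
    emb = layers.embedding(
        input=word, size=[10000, 32], data_type='float32', program=program)
    # concat joins a list of variables along the given axis
    vec = layers.concat(input=[emb, emb], axis=1, program=program)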

python/paddle/v2/framework/tests/test_layers.py

+71
@@ -88,6 +88,77 @@ def test_recognize_digits_conv(self):
 
         print str(program)
 
+    def test_word_embedding(self):
+        program = Program()
+        dict_size = 10000
+        embed_size = 32
+        first_word = layers.data(
+            name='firstw', shape=[1], data_type='int32', program=program)
+        second_word = layers.data(
+            name='secondw', shape=[1], data_type='int32', program=program)
+        third_word = layers.data(
+            name='thirdw', shape=[1], data_type='int32', program=program)
+        forth_word = layers.data(
+            name='forthw', shape=[1], data_type='int32', program=program)
+        next_word = layers.data(
+            name='nextw', shape=[1], data_type='int32', program=program)
+
+        embed_param_attr_1 = {
+            'name': 'shared_w',
+            'init_attr': {
+                'max': 1.0,
+                'type': 'uniform_random',
+                'min': -1.0
+            }
+        }
+        embed_param_attr_2 = {'name': 'shared_w'}
+
+        embed_first = layers.embedding(
+            input=first_word,
+            size=[dict_size, embed_size],
+            data_type='float32',
+            param_attr=embed_param_attr_1,
+            program=program)
+        embed_second = layers.embedding(
+            input=second_word,
+            size=[dict_size, embed_size],
+            data_type='float32',
+            param_attr=embed_param_attr_2,
+            program=program)
+
+        embed_third = layers.embedding(
+            input=third_word,
+            size=[dict_size, embed_size],
+            data_type='float32',
+            param_attr=embed_param_attr_2,
+            program=program)
+        embed_forth = layers.embedding(
+            input=forth_word,
+            size=[dict_size, embed_size],
+            data_type='float32',
+            param_attr=embed_param_attr_2,
+            program=program)
+
+        concat_embed = layers.concat(
+            input=[embed_first, embed_second, embed_third, embed_forth],
+            axis=1,
+            program=program)
+
+        hidden1 = layers.fc(input=concat_embed,
+                            size=256,
+                            act='sigmoid',
+                            program=program)
+        predict_word = layers.fc(input=hidden1,
+                                 size=dict_size,
+                                 act='softmax',
+                                 program=program)
+        cost = layers.cross_entropy(
+            input=predict_word, label=next_word, program=program)
+        avg_cost = layers.mean(x=cost, program=program)
+        self.assertIsNotNone(avg_cost)
+
+        print str(program)
+
 
 if __name__ == '__main__':
     unittest.main()
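A detail worth noting in `test_word_embedding`: all four `embedding` calls pass `param_attr` dicts whose `name` is `'shared_w'`, so they resolve to one shared lookup table rather than four separate ones; only the first call supplies `init_attr`. A rough sketch of the name-keyed reuse this relies on (stand-in code, not the framework's implementation):

    # Parameters are keyed by name; creating one under an existing name
    # hands back the same variable instead of a fresh one.
    params = {}

    def get_or_create(name, shape):
        if name not in params:
            params[name] = ['param', name, shape]  # stand-in for a Variable
        return params[name]

    w1 = get_or_create('shared_w', [10000, 32])
    w2 = get_or_create('shared_w', [10000, 32])
    assert w1 is w2  # the four embeddings train a single table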
python/paddle/v2/framework/tests/test_word2vec.py

+165

@@ -0,0 +1,165 @@
+import paddle.v2 as paddle
+import paddle.v2.framework.layers as layers
+import paddle.v2.framework.core as core
+import paddle.v2.framework.optimizer as optimizer
+
+from paddle.v2.framework.framework import Program, g_program
+from paddle.v2.framework.executor import Executor
+
+import numpy as np
+
+init_program = Program()
+program = Program()
+
+embed_size = 32
+hidden_size = 256
+N = 5
+batch_size = 32
+
+word_dict = paddle.dataset.imikolov.build_dict()
+dict_size = len(word_dict)
+
+first_word = layers.data(
+    name='firstw',
+    shape=[1],
+    data_type='int32',
+    program=program,
+    init_program=init_program)
+second_word = layers.data(
+    name='secondw',
+    shape=[1],
+    data_type='int32',
+    program=program,
+    init_program=init_program)
+third_word = layers.data(
+    name='thirdw',
+    shape=[1],
+    data_type='int32',
+    program=program,
+    init_program=init_program)
+forth_word = layers.data(
+    name='forthw',
+    shape=[1],
+    data_type='int32',
+    program=program,
+    init_program=init_program)
+next_word = layers.data(
+    name='nextw',
+    shape=[1],
+    data_type='int32',
+    program=program,
+    init_program=init_program)
+
+embed_param_attr_1 = {
+    'name': 'shared_w',
+    'init_attr': {
+        'max': 1.0,
+        'type': 'uniform_random',
+        'min': -1.0
+    }
+}
+embed_param_attr_2 = {'name': 'shared_w'}
+
+embed_first = layers.embedding(
+    input=first_word,
+    size=[dict_size, embed_size],
+    data_type='float32',
+    param_attr=embed_param_attr_1,
+    program=program,
+    init_program=init_program)
+embed_second = layers.embedding(
+    input=second_word,
+    size=[dict_size, embed_size],
+    data_type='float32',
+    param_attr=embed_param_attr_2,
+    program=program,
+    init_program=init_program)
+
+embed_third = layers.embedding(
+    input=third_word,
+    size=[dict_size, embed_size],
+    data_type='float32',
+    param_attr=embed_param_attr_2,
+    program=program,
+    init_program=init_program)
+embed_forth = layers.embedding(
+    input=forth_word,
+    size=[dict_size, embed_size],
+    data_type='float32',
+    param_attr=embed_param_attr_2,
+    program=program,
+    init_program=init_program)
+
+concat_embed = layers.concat(
+    input=[embed_first, embed_second, embed_third, embed_forth],
+    axis=1,
+    program=program,
+    init_program=init_program)
+
+hidden1 = layers.fc(input=concat_embed,
+                    size=hidden_size,
+                    act='sigmoid',
+                    program=program,
+                    init_program=init_program)
+predict_word = layers.fc(input=hidden1,
+                         size=dict_size,
+                         act='softmax',
+                         program=program,
+                         init_program=init_program)
+cost = layers.cross_entropy(
+    input=predict_word,
+    label=next_word,
+    program=program,
+    init_program=init_program)
+avg_cost = layers.mean(x=cost, program=program, init_program=init_program)
+
+sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
+opts = sgd_optimizer.minimize(avg_cost)
+
+train_reader = paddle.batch(
+    paddle.dataset.imikolov.train(word_dict, N), batch_size)
+
+place = core.CPUPlace()
+exe = Executor(place)
+
+exe.run(init_program, feed={}, fetch_list=[])
+PASS_NUM = 100
+for pass_id in range(PASS_NUM):
+    for data in train_reader():
+        input_data = [[data_idx[idx] for data_idx in data] for idx in xrange(5)]
+        input_data = map(lambda x: np.array(x).astype("int32"), input_data)
+        input_data = map(lambda x: np.expand_dims(x, axis=1), input_data)
+
+        first_data = input_data[0]
+        first_tensor = core.LoDTensor()
+        first_tensor.set(first_data, place)
+
+        second_data = input_data[0]
+        second_tensor = core.LoDTensor()
+        second_tensor.set(second_data, place)
+
+        third_data = input_data[0]
+        third_tensor = core.LoDTensor()
+        third_tensor.set(third_data, place)
+
+        forth_data = input_data[0]
+        forth_tensor = core.LoDTensor()
+        forth_tensor.set(forth_data, place)
+
+        next_data = input_data[0]
+        next_tensor = core.LoDTensor()
+        next_tensor.set(next_data, place)
+
+        outs = exe.run(program,
+                       feed={
+                           'firstw': first_tensor,
+                           'secondw': second_tensor,
+                           'thirdw': third_tensor,
+                           'forthw': forth_tensor,
+                           'nextw': next_tensor
+                       },
+                       fetch_list=[avg_cost])
+        out = np.array(outs[0])
+        if out[0] < 10.0:
+            exit(0)  # if avg cost less than 10.0, we think our code is good.
+exit(1)
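One thing to watch in the feed-preparation block above: `first_data` through `next_data` are all assigned `input_data[0]`, so every placeholder is fed the first word column even though the reader yields five distinct columns per batch. If the intent was one column per placeholder, the loop body would presumably look more like this sketch (not part of the commit):

    # Assumed intent: feed column i of the batch to the i-th placeholder.
    names = ['firstw', 'secondw', 'thirdw', 'forthw', 'nextw']
    feed = {}
    for i, name in enumerate(names):
        t = core.LoDTensor()
        t.set(input_data[i], place)  # column i, not always input_data[0]
        feed[name] = t
    outs = exe.run(program, feed=feed, fetch_list=[avg_cost])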
