-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Add label semantic examples with new Fluid api #10368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add label semantic examples with new Fluid api #10368
Conversation
|
||
|
||
def inference_network(): | ||
word = fluid.layers.data( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we put data layer definitions into lstm_net
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
mark_dict_len = 2 | ||
|
||
def lstm_net(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, | ||
**ignored): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need **ignored
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed, thanks.
return crf_decode | ||
|
||
def train_network(): | ||
mix_hidden_lr = 1e-3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some of the constants in this file is upper case snake case (BATCH_SIZE), some are snake case (mix_hidden_lr). Could you change all constants to upper case snake case (BATCH_SIZE)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks.
event.batch_id + 1, avg_cost)) | ||
|
||
if avg_cost > 0.01: # Low threshold for speeding up CI | ||
trainer.params.save(save_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry that the API changed a little, please use trainer.save_params
(https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py#L103)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
def infer(use_cuda, save_path): | ||
params = fluid.Params(save_path) | ||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() | ||
inferencer = fluid.Inferencer(inference_network, params, place=place) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry that the API changed a little (param.py is removed), please use
https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py#L108 and https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py#L118
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Helin for the review. I just made a few changes. Can you please take a look?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One comment, LGTM otherwise. Thank you!
if (event.batch_id % 10) == 0: | ||
avg_cost = trainer.test(reader=test_reader) | ||
|
||
print('BatchID {1:04}, Loss {2:2.2}'.format(event.batch_id + 1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it should be BatchID {0:04}, Loss {1:2.2}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
… simplify_fluid_labelrole
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. The new API is so clean.
Issues:
train_network()
andinference_network()
.