
Commit 236af56

Author: Yibing Liu
Parents: f94109d + e114aad

separate index tensor from candidate tensors in multiplex_op

File tree: 11 files changed, +427 −51 lines


doc/faq/index_cn.rst (+122 −1)
@@ -390,4 +390,125 @@ The model parameter file saved by PaddlePaddle consists of a 16-byte header and the network parameters

* If the earliest error reported is a network-communication problem, it is most likely a port conflict caused by running in non-exclusive mode. Contact the OP to check whether the current MPI cluster supports submission with the resource=full parameter; if it does, resubmit with this parameter added and change the job port.

* If the current MPI cluster does not support exclusive job mode, contact the OP about switching to another cluster or upgrading the current one.

19. How to output multiple layers in PaddlePaddle
------------------------------------------------------------

* Pass the layers to be output as the :code:`output_layer` argument of the :code:`paddle.inference.Inference()` interface:

  .. code-block:: python

      inferer = paddle.inference.Inference(output_layer=[layer1, layer2], parameters=parameters)

* Specify which fields to output. Taking the :code:`value` field as an example:

  .. code-block:: python

      out = inferer.infer(input=data_batch, flatten_result=False, field=["value"])

  With :code:`flatten_result=False`, the result is a :code:`list` whose length equals the number of output fields; each element of that :code:`list` is itself a :code:`list` holding the corresponding field of every output layer, and each field result is a :code:`numpy.array` (see the sketch below). The default value of :code:`flatten_result` is :code:`True`; in that case PaddlePaddle concatenates the results of all output layers row by row for each field, and the program cannot run if the corresponding dimensions of the output layers' :code:`numpy.array` results do not match for that field.
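  For illustration only, a minimal sketch of how such a result could be indexed, assuming the two output layers and the :code:`data_batch` from the snippets above (variable names are illustrative):

  .. code-block:: python

      # out is a list with one entry per requested field; only "value" was
      # requested here, so out[0] holds the per-layer results.
      value_results = out[0]

      # value_results is a list with one numpy.array per output layer.
      layer1_value = value_results[0]  # "value" output of layer1
      layer2_value = value_results[1]  # "value" output of layer2

      print(layer1_value.shape, layer2_value.shape)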
20. How to use the :code:`name` argument of :code:`paddle.layer.memory`
--------------------------------------------------------------------------------

* :code:`paddle.layer.memory` is used to fetch a particular layer's output at the previous time step; that layer is specified through the :code:`name` argument. In other words, :code:`paddle.layer.memory` associates itself with the layer whose :code:`name` has the same value and takes that layer's output at the previous time step as its own output at the current time step (see the sketch below).

* Every layer in PaddlePaddle has a unique name, which the user can set via the :code:`name` argument; if it is not set explicitly, PaddlePaddle assigns one automatically. :code:`paddle.layer.memory` is not a real layer; its own name is set via the :code:`memory_name` argument and is likewise assigned automatically when not set explicitly. The :code:`name` argument of :code:`paddle.layer.memory` specifies the layer it is associated with and must be set explicitly by the user.
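As a minimal sketch only (the layer names and sizes are illustrative, and :code:`seq_input` is assumed to be a sequence input layer defined elsewhere):

.. code-block:: python

    def step(input):
        # Fetch the previous-time-step output of the layer named "rnn_state".
        mem = paddle.layer.memory(name="rnn_state", size=128)
        # This fc layer is given name="rnn_state", so the memory above is
        # associated with it and returns its output at the next time step.
        out = paddle.layer.fc(input=[input, mem], size=128, name="rnn_state")
        return out

    rnn = paddle.layer.recurrent_group(step=step, input=seq_input)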
21. How to use dropout
------------------------------

* There are two ways to use dropout in PaddlePaddle:

  * Set :code:`drop_rate` in the layer's :code:`layer_attr`. Taking :code:`paddle.layer.fc` as an example:

    .. code-block:: python

        fc = paddle.layer.fc(input=input, layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=0.5))

  * Use :code:`paddle.layer.dropout`. Taking :code:`paddle.layer.fc` as an example:

    .. code-block:: python

        fc = paddle.layer.fc(input=input)
        drop_fc = paddle.layer.dropout(input=fc, dropout_rate=0.5)

* :code:`paddle.layer.dropout` actually uses :code:`paddle.layer.add_to` and applies dropout by setting :code:`drop_rate` inside that layer in the first way described above. This approach consumes more memory.

* PaddlePaddle implements dropout inside the activation function rather than inside the layer.

* :code:`paddle.layer.lstmemory`, :code:`paddle.layer.grumemory` and :code:`paddle.layer.recurrent` do not apply activations to their outputs in the usual way, so dropout cannot be enabled for these layers by setting :code:`drop_rate` as in the first approach. To use dropout with these layers, use the second approach, i.e. :code:`paddle.layer.dropout`.
22. How to set learning rate annealing
---------------------------------------------

Set :code:`learning_rate_schedule` and its related arguments in the optimizer. Taking the Adam algorithm as an example:

.. code-block:: python

    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        learning_rate_decay_a=0.5,
        learning_rate_decay_b=0.75,
        learning_rate_schedule="poly",)

PaddlePaddle currently supports 8 kinds of learning_rate_schedule; they and the corresponding learning-rate formulas are listed below (a small standalone sketch of these formulas follows the list):

* "constant"

  lr = learning_rate

* "poly"

  lr = learning_rate * pow(1 + learning_rate_decay_a * num_samples_processed, -learning_rate_decay_b)

  where num_samples_processed is the number of training samples processed so far (the same below).

* "caffe_poly"

  lr = learning_rate * pow(1.0 - num_samples_processed / learning_rate_decay_a, learning_rate_decay_b)

* "exp"

  lr = learning_rate * pow(learning_rate_decay_a, num_samples_processed / learning_rate_decay_b)

* "discexp"

  lr = learning_rate * pow(learning_rate_decay_a, floor(num_samples_processed / learning_rate_decay_b))

* "linear"

  lr = max(learning_rate - learning_rate_decay_a * num_samples_processed, learning_rate_decay_b)

* "manual"

  This schedule anneals the learning rate in a piecewise fashion according to the number of training samples processed. With this learning_rate_schedule, the user sets the piecewise decay factors through the :code:`learning_rate_args` argument, and the current learning rate is the product of the configured :code:`learning_rate` and the current decay factor. Taking Adam as an example:

  .. code-block:: python

      optimizer = paddle.optimizer.Adam(
          learning_rate=1e-3,
          learning_rate_schedule="manual",
          learning_rate_args="1000:1.0,2000:0.9,3000:0.8",)

  In this example, the learning rate is :code:`1e-3 * 1.0` while the number of processed samples is no more than 1000, :code:`1e-3 * 0.9` while it is greater than 1000 and no more than 2000, and :code:`1e-3 * 0.8` once it exceeds 2000.

* "pass_manual"

  This schedule anneals the learning rate in a piecewise fashion according to the number of training passes completed. With this learning_rate_schedule, the user sets the piecewise decay factors through the :code:`learning_rate_args` argument, and the current learning rate is the product of the configured :code:`learning_rate` and the current decay factor. Taking Adam as an example:

  .. code-block:: python

      optimizer = paddle.optimizer.Adam(
          learning_rate=1e-3,
          learning_rate_schedule="pass_manual",
          learning_rate_args="1:1.0,2:0.9,3:0.8",)

  In this example, the learning rate is :code:`1e-3 * 1.0` while the number of completed passes is no more than 1, :code:`1e-3 * 0.9` while it is greater than 1 and no more than 2, and :code:`1e-3 * 0.8` once it exceeds 2.
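Purely as an illustration of the formulas above (not PaddlePaddle's internal implementation; the function names are made up for the sketch), the "poly" and "linear" schedules can be evaluated in plain Python:

.. code-block:: python

    def poly_lr(learning_rate, decay_a, decay_b, num_samples_processed):
        # lr = learning_rate * pow(1 + a * n, -b)
        return learning_rate * (1 + decay_a * num_samples_processed) ** (-decay_b)

    def linear_lr(learning_rate, decay_a, decay_b, num_samples_processed):
        # lr = max(learning_rate - a * n, b)
        return max(learning_rate - decay_a * num_samples_processed, decay_b)

    # Learning rate after 10000 processed samples with the "poly" settings
    # used in the optimizer example above.
    print(poly_lr(1e-3, 0.5, 0.75, 10000))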
23. What to do about a :code:`Duplicated layer name` error
-----------------------------------------------------------------

This error usually occurs because the :code:`name` argument of two or more different layers has been set to the same value. When it appears, first find the layers whose :code:`name` values are identical, then give these layers distinct :code:`name` values.
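A minimal sketch of the fix (the layer and variable names are illustrative only):

.. code-block:: python

    # Error: both layers were given the same name "fc_layer".
    # fc1 = paddle.layer.fc(input=data, size=128, name="fc_layer")
    # fc2 = paddle.layer.fc(input=data, size=128, name="fc_layer")

    # Fix: give each layer a distinct name, or omit name and let
    # PaddlePaddle assign unique names automatically.
    fc1 = paddle.layer.fc(input=data, size=128, name="fc_layer_a")
    fc2 = paddle.layer.fc(input=data, size=128, name="fc_layer_b")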

paddle/operators/crop_op.h (+3 −3)
@@ -38,10 +38,10 @@ class CropKernel : public framework::OpKernel {
     auto out_stride = framework::stride(out->dims());
     auto offsets = context.Attr<std::vector<int>>("offsets");
     PADDLE_ENFORCE_EQ(
-        x->dims().size(), offsets.size(),
+        x->dims().size(), static_cast<int64_t>(offsets.size()),
         "Offsets size should be equal to dimension size of input tensor.");
     int64_t offset = 0;
-    for (int i = 0; i < offsets.size(); ++i) {
+    for (size_t i = 0; i < offsets.size(); ++i) {
       offset += (x_stride[i] * offsets[i]);
     }
     StridedMemcpy<T>(context.device_context(), x_data + offset, x_stride,
@@ -57,7 +57,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
   d_x->mutable_data<T>(context.GetPlace());
   auto offsets = context.Attr<std::vector<int>>("offsets");
   Eigen::array<std::pair<int, int>, D> paddings;
-  for (int i = 0; i < D; ++i) {
+  for (size_t i = 0; i < D; ++i) {
     paddings[i].first = offsets[i];
     paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i];
   }

paddle/operators/math/math_function.cc (+26)
@@ -48,6 +48,32 @@ void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
               beta, C, ldc);
 }
 
+template <>
+void gemm<platform::CPUPlace, float>(const platform::DeviceContext& context,
+                                     const bool transA, const bool transB,
+                                     const int M, const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const int lda, const float* B,
+                                     const int ldb, const float beta, float* C,
+                                     const int ldc) {
+  cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
+              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
+              lda, B, ldb, beta, C, ldc);
+}
+
+template <>
+void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
+                                      const bool transA, const bool transB,
+                                      const int M, const int N, const int K,
+                                      const double alpha, const double* A,
+                                      const int lda, const double* B,
+                                      const int ldb, const double beta,
+                                      double* C, const int ldc) {
+  cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
+              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
+              lda, B, ldb, beta, C, ldc);
+}
+
 template <>
 void matmul<platform::CPUPlace, float>(
     const platform::DeviceContext& context, const framework::Tensor& matrix_a,

paddle/operators/math/math_function.cu (+36)
@@ -63,6 +63,42 @@ void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
 }
 
+template <>
+void gemm<platform::GPUPlace, float>(const platform::DeviceContext& context,
+                                     const bool transA, const bool transB,
+                                     const int M, const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const int lda, const float* B,
+                                     const int ldb, const float beta, float* C,
+                                     const int ldc) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
+      reinterpret_cast<const platform::CUDADeviceContext&>(context)
+          .cublas_handle(),
+      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+}
+
+template <>
+void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
+                                      const bool transA, const bool transB,
+                                      const int M, const int N, const int K,
+                                      const double alpha, const double* A,
+                                      const int lda, const double* B,
+                                      const int ldb, const double beta,
+                                      double* C, const int ldc) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
+      reinterpret_cast<const platform::CUDADeviceContext&>(context)
+          .cublas_handle(),
+      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+}
+
 template <>
 void matmul<platform::GPUPlace, float>(
     const platform::DeviceContext& context, const framework::Tensor& matrix_a,

paddle/operators/math/math_function.h (+7)
@@ -70,6 +70,13 @@ void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
           const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
           const T alpha, const T* A, const T* B, const T beta, T* C);
 
+// gemm wrapper with stride args for matrices that are not contiguous in memory
+template <typename Place, typename T>
+void gemm(const platform::DeviceContext& context, const bool transA,
+          const bool transB, const int M, const int N, const int K,
+          const T alpha, const T* A, const int lda, const T* B, const int ldb,
+          const T beta, T* C, const int ldc);
+
 // matrix multiply with continuous memory
 template <typename Place, typename T>
 void matmul(const platform::DeviceContext& context,
