
Commit 1cc8aad

committed: "fix the ci problem"
1 parent abc66c9, commit 1cc8aad

File tree: 7 files changed, +143 −177 lines

paddle/fluid/operators/multi_dot_op.cc

Lines changed: 18 additions & 52 deletions
@@ -63,17 +63,11 @@ inline framework::DDim ComputeAndCheckShape(
   // If the last tensor is 1D of size n view it as a column vector (n, 1)
   if (last_dim.size() == 1) {
     last_dim = framework::make_ddim({static_cast<int>(last_dim[0]), 1});
-    if (is_vector) {
-      out_dim = framework::make_ddim({1});
-    } else {
-      out_dim = framework::make_ddim({first_dim[0]});
-    }
+    out_dim = is_vector ? framework::make_ddim({1})
+                        : framework::make_ddim({first_dim[0]});
   } else {
-    if (is_vector) {
-      out_dim = framework::make_ddim({last_dim[1]});
-    } else {
-      out_dim = framework::make_ddim({first_dim[0], last_dim[1]});
-    }
+    out_dim = is_vector ? framework::make_ddim({last_dim[1]})
+                        : framework::make_ddim({first_dim[0], last_dim[1]});
   }

   auto width = first_dim[1];
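
For context on the rule this hunk compacts: multi_dot follows numpy's np.linalg.multi_dot convention, where a 1-D first argument is treated as a row vector, a 1-D last argument as a column vector, and the singleton dimensions are dropped from the output. A quick numpy sketch of that convention (shapes here are illustrative only, not taken from the op):

import numpy as np

A = np.ones((2, 8))
B = np.ones((8, 3))
v = np.ones(3)   # 1-D last arg  -> column vector (3, 1), then squeezed
w = np.ones(2)   # 1-D first arg -> row vector (1, 2), then squeezed

print(np.linalg.multi_dot([A, B, v]).shape)     # (2,)
print(np.linalg.multi_dot([w, A, B]).shape)     # (3,)
print(np.linalg.multi_dot([w, A, B, v]).shape)  # () -- a scalar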
@@ -83,21 +77,21 @@ inline framework::DDim ComputeAndCheckShape(
         "the input tensor of multi_dot op must be 2D."));

     const auto& tmp_dim = inputs_dims[i];
-    PADDLE_ENFORCE_EQ(tmp_dim[0], width,
-                      platform::errors::InvalidArgument(
-                          "the input tensor of multi_dot op must be 2D."));
+    PADDLE_ENFORCE_EQ(
+        tmp_dim[0], width,
+        platform::errors::InvalidArgument(
+            "the input matrix does not meet the multiplication requirements."));
     width = tmp_dim[1];
   }
-  PADDLE_ENFORCE_EQ(last_dim[0], width,
-                    platform::errors::InvalidArgument(
-                        "the input tensor of multi_dot op must be 2D."));
+
+  PADDLE_ENFORCE_EQ(
+      last_dim[0], width,
+      platform::errors::InvalidArgument(
+          "the input matrix does not meet the multiplication requirements."));

   return out_dim;
 }

-/**
- * @brief the matrix multiplication
- */
 template <typename DeviceContext, typename T>
 inline framework::Tensor MatMul(const framework::ExecutionContext& ctx,
                                 const framework::Tensor& matrix_a,
@@ -109,8 +103,8 @@ inline framework::Tensor MatMul(const framework::ExecutionContext& ctx,

   framework::Tensor matrix_c;
   framework::DDim c_dim = framework::make_ddim({a_dim[0], b_dim[1]});
-  matrix_c.mutable_data<T>(place, c_dim[0] * c_dim[1] * sizeof(T));
   matrix_c.Resize(c_dim);
+  matrix_c.mutable_data<T>(place);

   auto mat_dim_a = math::CreateMatrixDescriptor(a_dim, 0, false);
   auto mat_dim_b = math::CreateMatrixDescriptor(b_dim, 0, false);
@@ -330,27 +324,23 @@ class MultiDotKernel : public framework::OpKernel<T> {
     const auto Ka = ins_dims[0][1];
     const auto Nb = ins_dims[1][1];
     const auto Nc = ins_dims[2][1];
-    const uint64_t cost1 =
-        Ma * Nb * (Ka + Nc);  // Ma * Ka * Nb + Ma * Nb * Nc;
-    const uint64_t cost2 =
-        Ka * Nc * (Nb + Ma);  // Ka * Nb * Nc + Ma * Ka * Nc;
+    const uint64_t cost1 = Ma * Nb * (Ka + Nc);
+    const uint64_t cost2 = Ka * Nc * (Nb + Ma);
     auto mat_dim_a = math::CreateMatrixDescriptor(ins_dims[0], 0, false);
     auto mat_dim_b = math::CreateMatrixDescriptor(ins_dims[1], 0, false);
     auto mat_dim_c = math::CreateMatrixDescriptor(ins_dims[2], 0, false);
     if (cost1 < cost2) {
       framework::Tensor tmp_out;
       tmp_out.mutable_data<T>(place, Ma * Nb * sizeof(T));
-      framework::DDim tmp_dim = ins_dims[0];
-      tmp_dim[1] = Nb;
+      framework::DDim tmp_dim = framework::make_ddim({Ma, Nb});
       blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, &tmp_out,
                   T(0));
       auto mat_dim_tmp = math::CreateMatrixDescriptor(tmp_dim, 0, false);
       blas.MatMul(tmp_out, mat_dim_tmp, *ins[2], mat_dim_c, scale, out, T(0));
     } else {
       framework::Tensor tmp_out;
       tmp_out.mutable_data<T>(place, Ka * Nc * sizeof(T));
-      framework::DDim tmp_dim = ins_dims[1];
-      tmp_dim[1] = Nc;
+      framework::DDim tmp_dim = framework::make_ddim({Ka, Nc});
       blas.MatMul(*ins[1], mat_dim_b, *ins[2], mat_dim_c, scale, &tmp_out,
                   T(0));
       auto mat_dim_tmp = math::CreateMatrixDescriptor(tmp_dim, 0, false);
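
The simplified cost1/cost2 lines implement the standard flop-count comparison between the two parenthesizations of A*B*C: cost1 counts (A*B)*C and cost2 counts A*(B*C). A plain-Python sanity check, with shapes chosen only to make the difference obvious:

# A: (Ma, Ka), B: (Ka, Nb), C: (Nb, Nc) -- illustrative values, not the kernel's.
Ma, Ka, Nb, Nc = 100, 5, 100, 5

cost1 = Ma * Nb * (Ka + Nc)   # (A*B)*C: Ma*Ka*Nb + Ma*Nb*Nc
cost2 = Ka * Nc * (Nb + Ma)   # A*(B*C): Ka*Nb*Nc + Ma*Ka*Nc

print(cost1, cost2)  # 100000 5000 -> A*(B*C) is 20x cheaper here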
@@ -361,7 +351,6 @@ class MultiDotKernel : public framework::OpKernel<T> {
     const auto tmp = MultiDotMatChainOrder<DeviceContext, T>(
         ctx, ins, ins_dims, false, &results);
     auto out_dim = out->dims();
-    // TensorCopy(tmp, place, ctx.device_context(), out);
     *out = tmp;
     out->Resize(out_dim);
   }
@@ -473,7 +462,6 @@ class MultiDotGradKernel : public framework::OpKernel<T> {
       dB.mutable_data<T>(ctx.GetPlace());

       CalcGrad(ctx, dout, *A, *B, dout_dim, a_dim, b_dim, &dA, &dB);
-
       MatChainMulGrad(ctx, dA, dx, ins, dA.dims(), ins_dims, order, i, right,
                       results);
       MatChainMulGrad(ctx, dB, dx, ins, dB.dims(), ins_dims, order, left, j,
@@ -489,7 +477,6 @@ class MultiDotGradKernel : public framework::OpKernel<T> {
     auto order = GetOrder(ins, ins_dims);
     auto n = ins.size();
     std::vector<framework::Tensor> results(n * n);
-    // call the forward, get the itermediate result
     MatChainMul<DeviceContext, T>(ctx, ins, ins_dims, order, 0, n - 1, true,
                                   &results);
     MatChainMulGrad(ctx, dout, dx, ins, dout_dim, ins_dims, order, 0, n - 1,
@@ -548,21 +535,10 @@ class MultiDotGradKernel : public framework::OpKernel<T> {
       tmp_out.mutable_data<T>(place);
       tmp_dout.Resize({mat_dim_dout.height_, Nb});
       tmp_dout.mutable_data<T>(place);
-      // tmp_out = A * B
       blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, alpha, &tmp_out,
                   T(0));
-
-      /*
-       * dC = dout * transpose(tmp_out)
-       * tmp_dout = dout * transpose(C)
-       */
       CalcGrad(ctx, dout, tmp_out, *ins[2], dout_dim, tmp_out.dims(),
                ins_dims[2], &tmp_dout, dx[2]);
-
-      /*
-       * dA = tmp_dout * transpose(B)
-       * dB = tmp_dout * transpose(A)
-       */
       CalcGrad(ctx, tmp_dout, *ins[0], *ins[1], tmp_dout.dims(), ins_dims[0],
                ins_dims[1], dx[0], dx[1]);
     } else {
@@ -573,18 +549,8 @@ class MultiDotGradKernel : public framework::OpKernel<T> {
       tmp_dout.mutable_data<T>(place);
       blas.MatMul(*ins[1], mat_dim_b, *ins[2], mat_dim_c, alpha, &tmp_out,
                   T(0));
-
-      /*
-       * dA = dout * transpose(tmp_out)
-       * tmp_out = dout * transpose(A)
-       */
       CalcGrad(ctx, dout, *ins[0], tmp_out, dout_dim, ins_dims[0],
               tmp_dout.dims(), dx[0], &tmp_dout);
-
-      /*
-       * dB = tmp_dout * transpose(C)
-       * dC = tmp_dout * transpose(B)
-       */
       CalcGrad(ctx, tmp_dout, *ins[1], *ins[2], tmp_dout.dims(), ins_dims[1],
                ins_dims[2], dx[1], dx[2]);
     }
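
The block comments deleted in the last two hunks recorded the identities the CalcGrad calls implement. They follow from the product rule: for C = A·B with upstream gradient dC, dA = dC·Bᵀ and dB = Aᵀ·dC. A minimal numpy check of that identity (names are illustrative, not Paddle's):

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 2))
dC = rng.standard_normal((3, 2))   # arbitrary upstream gradient

dA = dC @ B.T                      # gradient w.r.t. A
dB = A.T @ dC                      # gradient w.r.t. B

# Finite-difference check on one entry of A; matmul is linear in A,
# so the quotient matches dA[0, 0] up to rounding.
eps = 1e-6
A2 = A.copy()
A2[0, 0] += eps
numeric = (((A2 @ B) - (A @ B)) * dC).sum() / eps
print(np.isclose(numeric, dA[0, 0]))   # True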

python/paddle/fluid/tests/unittests/test_multi_dot_op.py

Lines changed: 49 additions & 49 deletions
@@ -31,7 +31,7 @@ def setUp(self):
         self.get_inputs_and_outputs()

     def get_dtype(self):
-        return "float32"
+        return "float64"

     def get_inputs_and_outputs(self):
         self.A = np.random.random((2, 8)).astype(self.dtype)
@@ -43,8 +43,8 @@ def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')


 class TestMultiDotOpDouble(TestMultiDotOp):
@@ -62,9 +62,9 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')


 #A*(B*C)
@@ -77,9 +77,9 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')


 class TestMultiDotOp4Mat(TestMultiDotOp):
@@ -95,10 +95,10 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x3'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')
+        self.check_grad(['x3'], 'Out')


 class TestMultiDotOpFirst1D(TestMultiDotOp):
@@ -118,9 +118,9 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')


 class TestMultiDotOp4MatFirst1D(TestMultiDotOp):
@@ -136,10 +136,10 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x3'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')
+        self.check_grad(['x3'], 'Out')


 class TestMultiDotOpLast1D(TestMultiDotOp):
@@ -159,9 +159,9 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')


 class TestMultiDotOp4MatLast1D(TestMultiDotOp):
@@ -177,10 +177,10 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x3'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')
+        self.check_grad(['x3'], 'Out')


 class TestMultiDotOpFirstAndLast1D(TestMultiDotOp):
@@ -191,8 +191,8 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')


 class TestMultiDotOp3MatFirstAndLast1D(TestMultiDotOp):
@@ -204,9 +204,9 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')


 class TestMultiDotOp4MatFirstAndLast1D(TestMultiDotOp):
@@ -222,10 +222,10 @@ def get_inputs_and_outputs(self):
         self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x1'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x2'], 'Out', max_relative_error=1e-3)
-        self.check_grad(['x3'], 'Out', max_relative_error=1e-3)
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')
+        self.check_grad(['x3'], 'Out')


 #####python API test#######
@@ -236,41 +236,41 @@ def test_errors(self):
         input1 = 12
         self.assertRaises(TypeError, paddle.multi_dot, [input1, input1])

-        # The inputs dtype of multi_dot must be float32, float64 or float16.
+        # The inputs dtype of multi_dot must be float64, float64 or float16.
         input2 = fluid.layers.data(
             name='input2', shape=[10, 10], dtype="int32")
         self.assertRaises(TypeError, paddle.multi_dot, [input2, input2])

         # the number of tensor must be larger than 1
-        x0 = fluid.data(name='x0', shape=[3, 2], dtype="float32")
+        x0 = fluid.data(name='x0', shape=[3, 2], dtype="float64")
         self.assertRaises(ValueError, paddle.multi_dot, [x0])

         #the first tensor must be 1D or 2D
-        x1 = fluid.data(name='x1', shape=[3, 2, 3], dtype="float32")
-        x2 = fluid.data(name='x2', shape=[3, 2], dtype="float32")
+        x1 = fluid.data(name='x1', shape=[3, 2, 3], dtype="float64")
+        x2 = fluid.data(name='x2', shape=[3, 2], dtype="float64")
         self.assertRaises(ValueError, paddle.multi_dot, [x1, x2])

         #the last tensor must be 1D or 2D
-        x3 = fluid.data(name='x3', shape=[3, 2], dtype="float32")
-        x4 = fluid.data(name='x4', shape=[3, 2, 2], dtype="float32")
+        x3 = fluid.data(name='x3', shape=[3, 2], dtype="float64")
+        x4 = fluid.data(name='x4', shape=[3, 2, 2], dtype="float64")
         self.assertRaises(ValueError, paddle.multi_dot, [x3, x4])

         #the tensor must be 2D, except first and last tensor
-        x5 = fluid.data(name='x5', shape=[3, 2], dtype="float32")
-        x6 = fluid.data(name='x6', shape=[2], dtype="float32")
-        x7 = fluid.data(name='x7', shape=[2, 2], dtype="float32")
+        x5 = fluid.data(name='x5', shape=[3, 2], dtype="float64")
+        x6 = fluid.data(name='x6', shape=[2], dtype="float64")
+        x7 = fluid.data(name='x7', shape=[2, 2], dtype="float64")
         self.assertRaises(ValueError, paddle.multi_dot, [x5, x6, x7])


 class API_TestMultiDot(unittest.TestCase):
     def test_out(self):
         with fluid.program_guard(fluid.Program()):
-            x0 = fluid.data(name='x0', shape=[3, 2], dtype="float32")
-            x1 = fluid.data(name='x1', shape=[2, 3], dtype='float32')
+            x0 = fluid.data(name='x0', shape=[3, 2], dtype="float64")
+            x1 = fluid.data(name='x1', shape=[2, 3], dtype='float64')
             result = paddle.multi_dot([x0, x1])
             exe = fluid.Executor(fluid.CPUPlace())
-            data1 = np.random.rand(3, 2).astype("float32")
-            data2 = np.random.rand(2, 3).astype("float32")
+            data1 = np.random.rand(3, 2).astype("float64")
+            data2 = np.random.rand(2, 3).astype("float64")
             np_res = exe.run(feed={'x0': data1,
                                    'x1': data2},
                              fetch_list=[result])
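
The float32-to-float64 switch is what allows every max_relative_error=1e-3 override above to be dropped: the rounding floor of a finite-difference gradient check scales with machine epsilon, so double precision fits comfortably inside check_grad's default tolerance. A small numpy illustration of the effect (the function and shapes are made up for the demo, not taken from the tests):

import numpy as np

def fd_rel_error(dtype, eps=1e-4):
    rng = np.random.default_rng(0)
    x = rng.random((4, 8)).astype(dtype)
    W = rng.random((8, 3)).astype(dtype)
    f = lambda t: (t @ W).sum()
    xp = x.copy()
    xp[0, 0] += dtype(eps)
    numeric = (f(xp) - f(x)) / eps
    analytic = W[0, :].sum()   # exact d f / d x[0, 0]
    return abs(numeric - analytic) / abs(analytic)

print(fd_rel_error(np.float32))  # noticeably large: needs a loosened tolerance
print(fd_rel_error(np.float64))  # orders of magnitude smaller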

python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py

Lines changed: 1 addition & 2 deletions
@@ -65,7 +65,6 @@
     'rank_loss',
     'sequence_conv',
     'smooth_l1_loss',
-    'spectral_norm',
-    'multi_dot',
+    'spectral_norm'
 ]
 # yapf: enable

python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py

Lines changed: 1 addition & 2 deletions
@@ -76,8 +76,7 @@
     'trilinear_interp_v2', \
     'var_conv_2d', \
     'warpctc', \
-    'bilateral_slice',
-    'multi_dot'
+    'bilateral_slice'
 ]

 NO_FP16_CHECK_GRAD_OP_LIST = [

python/paddle/linalg.py

Lines changed: 1 addition & 1 deletion
@@ -22,6 +22,6 @@
     'cholesky',  #noqa
     'norm',
     'inv',
+    'matrix_power',
     'multi_dot'
-    'matrix_power'
 ]
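
This last hunk is very likely the CI problem the commit message refers to: in the old list, 'multi_dot' and 'matrix_power' were adjacent string literals with no comma between them, so Python silently concatenated them into the single bogus export 'multi_dotmatrix_power'. Moving 'matrix_power' above, with its trailing comma, restores two separate entries. A minimal illustration of the pitfall:

broken = [
    'inv',
    'multi_dot'       # missing comma: adjacent literals concatenate silently
    'matrix_power'
]
fixed = [
    'inv',
    'matrix_power',
    'multi_dot'
]
print(broken)  # ['inv', 'multi_dotmatrix_power']
print(fixed)   # ['inv', 'matrix_power', 'multi_dot']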
