@@ -11,7 +11,12 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
11
11
int kH ,
12
12
int dT ,
13
13
int dW ,
14
- int dH ) {
14
+ int dH ,
15
+ int padT ,
16
+ int padW ,
17
+ int padH ,
18
+ bool ceil_mode )
19
+ {
15
20
long nslices ;
16
21
long itime ;
17
22
long iheight ;
@@ -49,14 +54,46 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
49
54
input -> size [dimt ], input -> size [dimh ], input -> size [dimw ],
50
55
kT , kH , kW );
51
56
57
+ // The second argument is argNumber... here is the index of padH.
58
+ THArgCheck (kT /2 >= padT && kW /2 >= padW && kH /2 >= padH , 11 ,
59
+ "pad should not be greater than half of kernel size, but got "
60
+ "padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d" ,
61
+ padT , padW , padH , kT , kW , kH );
62
+
52
63
/* sizes */
53
64
nslices = input -> size [dimN ];
54
65
itime = input -> size [dimt ];
55
66
iheight = input -> size [dimh ];
56
67
iwidth = input -> size [dimw ];
57
- otime = (itime - kT ) / dT + 1 ;
58
- oheight = (iheight - kH ) / dH + 1 ;
59
- owidth = (iwidth - kW ) / dW + 1 ;
68
+
69
+ if (ceil_mode ) {
70
+ otime = (long )(ceil ((float )(itime - kT + 2 * padT ) / dT )) + 1 ;
71
+ oheight = (long )(ceil ((float )(iheight - kH + 2 * padH ) / dH )) + 1 ;
72
+ owidth = (long )(ceil ((float )(iwidth - kW + 2 * padW ) / dW )) + 1 ;
73
+ }
74
+ else
75
+ {
76
+ otime = (long )(floor ((float )(itime - kT + 2 * padT ) / dT )) + 1 ;
77
+ oheight = (long )(floor ((float )(iheight - kH + 2 * padH ) / dH )) + 1 ;
78
+ owidth = (long )(floor ((float )(iwidth - kW + 2 * padW ) / dW )) + 1 ;
79
+ }
80
+
81
+ if (padT || padW || padH )
82
+ {
83
+ // ensure that the last pooling starts inside the image
84
+ // needed to avoid problems in ceil mode
85
+ if ((otime - 1 )* dT >= itime + padT )
86
+ -- otime ;
87
+ if ((oheight - 1 )* dH >= iheight + padH )
88
+ -- oheight ;
89
+ if ((owidth - 1 )* dW >= iwidth + padW )
90
+ -- owidth ;
91
+ }
92
+
93
+ if (otime < 1 || owidth < 1 || oheight < 1 )
94
+ THError ("Given input size: (%dx%dx%dx%d). "
95
+ "Calculated output size: (%dx%dx%dx%d). Output size is too small" ,
96
+ nslices ,itime ,iheight ,iwidth ,nslices ,otime ,oheight ,owidth );
60
97
61
98
if (gradOutput != NULL ) {
62
99
THNN_CHECK_DIM_SIZE (gradOutput , ndim , dimN , nslices );
@@ -81,43 +118,69 @@ static void THNN_(VolumetricAveragePooling_updateOutput_frame)(
81
118
int kH ,
82
119
int dT ,
83
120
int dW ,
84
- int dH )
121
+ int dH ,
122
+ int padT ,
123
+ int padW ,
124
+ int padH ,
125
+ bool count_include_pad )
85
126
{
86
127
long k ;
87
128
#pragma omp parallel for private(k)
88
129
for (k = 0 ; k < nslices ; k ++ )
89
130
{
90
- /* loop over output */
91
131
long i , j , ti ;
132
+
133
+ /* local pointers. */
134
+ real * ip = input_p + k * itime * iwidth * iheight ;
135
+ real * op = output_p + k * otime * owidth * oheight ;
136
+ for (i = 0 ; i < otime * oheight * owidth ; ++ i )
137
+ * (op + i ) = 0 ;
138
+
139
+ /* loop over output */
92
140
for (ti = 0 ; ti < otime ; ti ++ )
93
141
{
94
142
for (i = 0 ; i < oheight ; i ++ )
95
143
{
96
144
for (j = 0 ; j < owidth ; j ++ )
97
145
{
98
- /* local pointers */
99
- real * ip = input_p + k * itime * iwidth * iheight
100
- + ti * iwidth * iheight * dT + i * iwidth * dH + j * dW ;
101
- real * op = output_p + k * otime * owidth * oheight
102
- + ti * owidth * oheight + i * owidth + j ;
146
+ /* compute pool range. */
147
+ long tstart = ti * dT - padT ;
148
+ long hstart = i * dH - padH ;
149
+ long wstart = j * dW - padW ;
150
+ long tend = fminf (tstart + kT , itime + padT );
151
+ long hend = fminf (hstart + kH , iheight + padH );
152
+ long wend = fminf (wstart + kW , iwidth + padW );
153
+ long pool_size = (tend - tstart ) * (hend - hstart ) * (wend - wstart );
154
+ tstart = fmaxf (tstart , 0 );
155
+ hstart = fmaxf (hstart , 0 );
156
+ wstart = fmaxf (wstart , 0 );
157
+ tend = fmin (tend , itime );
158
+ hend = fmin (hend , iheight );
159
+ wend = fmin (wend , iwidth );
160
+
161
+ int divide_factor ;
162
+ if (count_include_pad )
163
+ divide_factor = pool_size ;
164
+ else
165
+ divide_factor = (tend - tstart ) * (hend - hstart ) * (wend - wstart );
103
166
104
167
/* compute local sum: */
105
168
real sum = 0.0 ;
106
- int x , y , z ;
169
+ long x , y , z ;
107
170
108
- for (z = 0 ; z < kT ; z ++ )
171
+ for (z = tstart ; z < tend ; z ++ )
109
172
{
110
- for (y = 0 ; y < kH ; y ++ )
173
+ for (y = hstart ; y < hend ; y ++ )
111
174
{
112
- for (x = 0 ; x < kW ; x ++ )
175
+ for (x = wstart ; x < wend ; x ++ )
113
176
{
114
177
sum += * (ip + z * iwidth * iheight + y * iwidth + x );
115
178
}
116
179
}
117
180
}
118
181
119
182
/* set output to local max */
120
- * op = sum / ( kT * kW * kH ) ;
183
+ * op ++ + = sum / divide_factor ;
121
184
}
122
185
}
123
186
}
@@ -133,7 +196,12 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
133
196
int kH ,
134
197
int dT ,
135
198
int dW ,
136
- int dH )
199
+ int dH ,
200
+ int padT ,
201
+ int padW ,
202
+ int padH ,
203
+ bool ceil_mode ,
204
+ bool count_include_pad )
137
205
{
138
206
long nslices ;
139
207
long itime ;
@@ -147,7 +215,7 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
147
215
148
216
THNN_ (VolumetricAveragePooling_shapeCheck )(
149
217
state , input , NULL , kT , kW , kH ,
150
- dT , dW , dH );
218
+ dT , dW , dH , padT , padW , padH , ceil_mode );
151
219
152
220
int dimN = 0 ;
153
221
int dimt = 1 ;
@@ -167,9 +235,29 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
167
235
itime = input -> size [dimt ];
168
236
iheight = input -> size [dimh ];
169
237
iwidth = input -> size [dimw ];
170
- otime = (itime - kT ) / dT + 1 ;
171
- oheight = (iheight - kH ) / dH + 1 ;
172
- owidth = (iwidth - kW ) / dW + 1 ;
238
+ if (ceil_mode )
239
+ {
240
+ otime = (long )(ceil ((float )(itime - kT + 2 * padT ) / dT )) + 1 ;
241
+ oheight = (long )(ceil ((float )(iheight - kH + 2 * padH ) / dH )) + 1 ;
242
+ owidth = (long )(ceil ((float )(iwidth - kW + 2 * padW ) / dW )) + 1 ;
243
+ }
244
+ else
245
+ {
246
+ otime = (long )(floor ((float )(itime - kT + 2 * padT ) / dT )) + 1 ;
247
+ oheight = (long )(floor ((float )(iheight - kH + 2 * padH ) / dH )) + 1 ;
248
+ owidth = (long )(floor ((float )(iwidth - kW + 2 * padW ) / dW )) + 1 ;
249
+ }
250
+ if (padT || padH || padW )
251
+ {
252
+ // ensure that the last pooling starts inside the image
253
+ // needed to avoid problems in ceil mode
254
+ if ((otime - 1 )* dT >= itime + padT )
255
+ -- otime ;
256
+ if ((oheight - 1 )* dH >= iheight + padH )
257
+ -- oheight ;
258
+ if ((owidth - 1 )* dW >= iwidth + padW )
259
+ -- owidth ;
260
+ }
173
261
174
262
/* get contiguous input */
175
263
input = THTensor_ (newContiguous )(input );
@@ -187,7 +275,9 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
187
275
itime , iwidth , iheight ,
188
276
otime , owidth , oheight ,
189
277
kT , kW , kH ,
190
- dT , dW , dH
278
+ dT , dW , dH ,
279
+ padT , padW , padH ,
280
+ count_include_pad
191
281
);
192
282
}
193
283
else /* batch mode */
@@ -212,7 +302,9 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
212
302
itime , iwidth , iheight ,
213
303
otime , owidth , oheight ,
214
304
kT , kW , kH ,
215
- dT , dW , dH
305
+ dT , dW , dH ,
306
+ padT , padW , padH ,
307
+ count_include_pad
216
308
);
217
309
}
218
310
}
@@ -236,36 +328,62 @@ static void THNN_(VolumetricAveragePooling_updateGradInput_frame)(
236
328
int kH ,
237
329
int dT ,
238
330
int dW ,
239
- int dH )
331
+ int dH ,
332
+ int padT ,
333
+ int padW ,
334
+ int padH ,
335
+ bool count_include_pad )
240
336
{
241
337
long k ;
242
338
#pragma omp parallel for private(k)
243
339
for (k = 0 ; k < nslices ; k ++ )
244
340
{
245
- /* loop over output */
246
341
long i , j , ti ;
342
+
343
+ /* local pointers */
344
+ real * ip = gradInput_p + k * itime * iwidth * iheight ;
345
+ real * op = gradOutput_p + k * otime * owidth * oheight ;
346
+ for (i = 0 ; i < itime * iwidth * iheight ; i ++ )
347
+ * (ip + i ) = 0 ;
348
+
349
+ /* loop over output */
247
350
for (ti = 0 ; ti < otime ; ti ++ )
248
351
{
249
352
for (i = 0 ; i < oheight ; i ++ )
250
353
{
251
354
for (j = 0 ; j < owidth ; j ++ )
252
355
{
253
- /* local pointers */
254
- real * ip = gradInput_p + k * itime * iwidth * iheight
255
- + ti * iwidth * iheight * dT + i * iwidth * dH + j * dW ;
256
- real * op = gradOutput_p + k * otime * owidth * oheight
257
- + ti * owidth * oheight + i * owidth + j ;
356
+ long tstart = ti * dT - padT ;
357
+ long hstart = i * dH - padH ;
358
+ long wstart = j * dW - padW ;
359
+ long tend = fminf (tstart + kT , itime + padT );
360
+ long hend = fminf (hstart + kH , iheight + padH );
361
+ long wend = fminf (wstart + kW , iwidth + padW );
362
+ long pool_size = (tend - tstart ) * (hend - hstart ) * (wend - wstart );
363
+ tstart = fmaxf (tstart , 0 );
364
+ hstart = fmaxf (hstart , 0 );
365
+ wstart = fmaxf (wstart , 0 );
366
+ tend = fminf (tend , itime );
367
+ hend = fminf (hend , iheight );
368
+ wend = fminf (wend , iwidth );
369
+
370
+ long divide_factor ;
371
+ if (count_include_pad )
372
+ divide_factor = pool_size ;
373
+ else
374
+ divide_factor = (tend - tstart ) * (hend - hstart ) * (wend - wstart );
258
375
259
376
/* scatter gradients out to footprint: */
260
- real val = * op / (kT * kW * kH );
261
- int x ,y ,z ;
262
- for (z = 0 ; z < kT ; z ++ )
377
+ real val = * op ++ ;
378
+
379
+ long x ,y ,z ;
380
+ for (z = tstart ; z < tend ; z ++ )
263
381
{
264
- for (y = 0 ; y < kH ; y ++ )
382
+ for (y = hstart ; y < hend ; y ++ )
265
383
{
266
- for (x = 0 ; x < kW ; x ++ )
384
+ for (x = wstart ; x < wend ; x ++ )
267
385
{
268
- * (ip + z * iwidth * iheight + y * iwidth + x ) += val ;
386
+ * (ip + z * iheight * iwidth + y * iwidth + x ) += val / divide_factor ;
269
387
}
270
388
}
271
389
}
@@ -285,15 +403,20 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
285
403
int kH ,
286
404
int dT ,
287
405
int dW ,
288
- int dH )
406
+ int dH ,
407
+ int padT ,
408
+ int padW ,
409
+ int padH ,
410
+ bool ceil_mode ,
411
+ bool count_include_pad )
289
412
{
290
- int nslices ;
291
- int itime ;
292
- int iheight ;
293
- int iwidth ;
294
- int otime ;
295
- int oheight ;
296
- int owidth ;
413
+ long nslices ;
414
+ long itime ;
415
+ long iheight ;
416
+ long iwidth ;
417
+ long otime ;
418
+ long oheight ;
419
+ long owidth ;
297
420
real * gradInput_data ;
298
421
real * gradOutput_data ;
299
422
@@ -304,7 +427,7 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
304
427
305
428
THNN_ (VolumetricAveragePooling_shapeCheck )(
306
429
state , input , gradOutput , kT , kW , kH ,
307
- dT , dW , dH );
430
+ dT , dW , dH , padT , padW , padH , ceil_mode );
308
431
309
432
/* get contiguous gradOutput */
310
433
gradOutput = THTensor_ (newContiguous )(gradOutput );
@@ -342,7 +465,9 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
342
465
itime , iwidth , iheight ,
343
466
otime , owidth , oheight ,
344
467
kT , kW , kH ,
345
- dT , dW , dH
468
+ dT , dW , dH ,
469
+ padT , padW , padH ,
470
+ count_include_pad
346
471
);
347
472
}
348
473
else /* batch mode */
@@ -361,7 +486,9 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
361
486
itime , iwidth , iheight ,
362
487
otime , owidth , oheight ,
363
488
kT , kW , kH ,
364
- dT , dW , dH
489
+ dT , dW , dH ,
490
+ padT , padW , padH ,
491
+ count_include_pad
365
492
);
366
493
}
367
494
}
0 commit comments