Skip to content

Commit b5949d8

Browse files
houseroad authored and soumith committed
Adding implicit padding for 3d average pooling
1 parent e4c05c2 commit b5949d8

File tree

2 files changed

+180
-49
lines changed

2 files changed

+180
-49
lines changed

generic/THNN.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,14 +1249,18 @@ TH_API void THNN_(VolumetricAveragePooling_updateOutput)(
12491249
THTensor *input,
12501250
THTensor *output,
12511251
int kT, int kW, int kH,
1252-
int dT, int dW, int dH);
1252+
int dT, int dW, int dH,
1253+
int padT, int padW, int padH,
1254+
bool ceil_mode, bool count_include_pad);
12531255
TH_API void THNN_(VolumetricAveragePooling_updateGradInput)(
12541256
THNNState *state,
12551257
THTensor *input,
12561258
THTensor *gradOutput,
12571259
THTensor *gradInput,
12581260
int kT, int kW, int kH,
1259-
int dT, int dW, int dH);
1261+
int dT, int dW, int dH,
1262+
int padT, int padW, int padH,
1263+
bool ceil_mode, bool count_include_pad);
12601264

12611265
TH_API void THNN_(VolumetricConvolution_updateOutput)(
12621266
THNNState *state,

generic/VolumetricAveragePooling.c

Lines changed: 174 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
1111
int kH,
1212
int dT,
1313
int dW,
14-
int dH) {
14+
int dH,
15+
int padT,
16+
int padW,
17+
int padH,
18+
bool ceil_mode)
19+
{
1520
long nslices;
1621
long itime;
1722
long iheight;
@@ -49,14 +54,46 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
4954
input->size[dimt], input->size[dimh], input->size[dimw],
5055
kT, kH, kW);
5156

57+
// The second argument of THArgCheck is the argument number; here it is the index of padH.
58+
THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 11,
59+
"pad should not be greater than half of kernel size, but got "
60+
"padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d",
61+
padT, padW, padH, kT, kW, kH);
62+
5263
/* sizes */
5364
nslices = input->size[dimN];
5465
itime = input->size[dimt];
5566
iheight = input->size[dimh];
5667
iwidth = input->size[dimw];
57-
otime = (itime - kT) / dT + 1;
58-
oheight = (iheight - kH) / dH + 1;
59-
owidth = (iwidth - kW) / dW + 1;
68+
69+
if (ceil_mode) {
70+
otime = (long)(ceil((float)(itime - kT + 2*padT) / dT)) + 1;
71+
oheight = (long)(ceil((float)(iheight - kH + 2*padH) / dH)) + 1;
72+
owidth = (long)(ceil((float)(iwidth - kW + 2*padW) / dW)) + 1;
73+
}
74+
else
75+
{
76+
otime = (long)(floor((float)(itime - kT + 2*padT) / dT)) + 1;
77+
oheight = (long)(floor((float)(iheight - kH + 2*padH) / dH)) + 1;
78+
owidth = (long)(floor((float)(iwidth - kW + 2*padW) / dW)) + 1;
79+
}
80+
81+
if (padT || padW || padH)
82+
{
83+
// ensure that the last pooling starts inside the image
84+
// needed to avoid problems in ceil mode
85+
if ((otime - 1)*dT >= itime + padT)
86+
--otime;
87+
if ((oheight - 1)*dH >= iheight + padH)
88+
--oheight;
89+
if ((owidth - 1)*dW >= iwidth + padW)
90+
--owidth;
91+
}
92+
93+
if (otime < 1 || owidth < 1 || oheight < 1)
94+
THError("Given input size: (%dx%dx%dx%d). "
95+
"Calculated output size: (%dx%dx%dx%d). Output size is too small",
96+
nslices,itime,iheight,iwidth,nslices,otime,oheight,owidth);
6097

6198
if (gradOutput != NULL) {
6299
THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimN, nslices);
@@ -81,43 +118,69 @@ static void THNN_(VolumetricAveragePooling_updateOutput_frame)(
81118
int kH,
82119
int dT,
83120
int dW,
84-
int dH)
121+
int dH,
122+
int padT,
123+
int padW,
124+
int padH,
125+
bool count_include_pad)
85126
{
86127
long k;
87128
#pragma omp parallel for private(k)
88129
for (k = 0; k < nslices; k++)
89130
{
90-
/* loop over output */
91131
long i, j, ti;
132+
133+
/* local pointers. */
134+
real *ip = input_p + k * itime * iwidth * iheight;
135+
real *op = output_p + k * otime * owidth * oheight;
136+
for (i = 0; i < otime * oheight * owidth; ++i)
137+
*(op + i) = 0;
138+
139+
/* loop over output */
92140
for (ti = 0; ti < otime; ti++)
93141
{
94142
for (i = 0; i < oheight; i++)
95143
{
96144
for (j = 0; j < owidth; j++)
97145
{
98-
/* local pointers */
99-
real *ip = input_p + k * itime * iwidth * iheight
100-
+ ti * iwidth * iheight * dT + i * iwidth * dH + j * dW;
101-
real *op = output_p + k * otime * owidth * oheight
102-
+ ti * owidth * oheight + i * owidth + j;
146+
/* compute pool range. */
147+
long tstart = ti * dT - padT;
148+
long hstart = i * dH - padH;
149+
long wstart = j * dW - padW;
150+
long tend = fminf(tstart + kT, itime + padT);
151+
long hend = fminf(hstart + kH, iheight + padH);
152+
long wend = fminf(wstart + kW, iwidth + padW);
153+
long pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
154+
tstart = fmaxf(tstart, 0);
155+
hstart = fmaxf(hstart, 0);
156+
wstart = fmaxf(wstart, 0);
157+
tend = fmin(tend, itime);
158+
hend = fmin(hend, iheight);
159+
wend = fmin(wend, iwidth);
160+
161+
int divide_factor;
162+
if (count_include_pad)
163+
divide_factor = pool_size;
164+
else
165+
divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
103166

104167
/* compute local sum: */
105168
real sum = 0.0;
106-
int x, y, z;
169+
long x, y, z;
107170

108-
for (z=0; z < kT; z++)
171+
for (z = tstart; z < tend; z++)
109172
{
110-
for (y = 0; y < kH; y++)
173+
for (y = hstart; y < hend; y++)
111174
{
112-
for (x = 0; x < kW; x++)
175+
for (x = wstart; x < wend; x++)
113176
{
114177
sum += *(ip + z * iwidth * iheight + y * iwidth + x);
115178
}
116179
}
117180
}
118181

119182
/* set output to local average */
120-
*op = sum / (kT * kW * kH);
183+
*op++ += sum / divide_factor;
121184
}
122185
}
123186
}
@@ -133,7 +196,12 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
133196
int kH,
134197
int dT,
135198
int dW,
136-
int dH)
199+
int dH,
200+
int padT,
201+
int padW,
202+
int padH,
203+
bool ceil_mode,
204+
bool count_include_pad)
137205
{
138206
long nslices;
139207
long itime;
@@ -147,7 +215,7 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
147215

148216
THNN_(VolumetricAveragePooling_shapeCheck)(
149217
state, input, NULL, kT, kW, kH,
150-
dT, dW, dH);
218+
dT, dW, dH, padT, padW, padH, ceil_mode);
151219

152220
int dimN = 0;
153221
int dimt = 1;
@@ -167,9 +235,29 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
167235
itime = input->size[dimt];
168236
iheight = input->size[dimh];
169237
iwidth = input->size[dimw];
170-
otime = (itime - kT) / dT + 1;
171-
oheight = (iheight - kH) / dH + 1;
172-
owidth = (iwidth - kW) / dW + 1;
238+
if (ceil_mode)
239+
{
240+
otime = (long)(ceil((float)(itime - kT + 2*padT) / dT)) + 1;
241+
oheight = (long)(ceil((float)(iheight - kH + 2*padH) / dH)) + 1;
242+
owidth = (long)(ceil((float)(iwidth - kW + 2*padW) / dW)) + 1;
243+
}
244+
else
245+
{
246+
otime = (long)(floor((float)(itime - kT + 2*padT) / dT)) + 1;
247+
oheight = (long)(floor((float)(iheight - kH + 2*padH) / dH)) + 1;
248+
owidth = (long)(floor((float)(iwidth - kW + 2*padW) / dW)) + 1;
249+
}
250+
if (padT || padH || padW)
251+
{
252+
// ensure that the last pooling starts inside the image
253+
// needed to avoid problems in ceil mode
254+
if ((otime - 1)*dT >= itime + padT)
255+
--otime;
256+
if ((oheight - 1)*dH >= iheight + padH)
257+
--oheight;
258+
if ((owidth - 1)*dW >= iwidth + padW)
259+
--owidth;
260+
}
173261

174262
/* get contiguous input */
175263
input = THTensor_(newContiguous)(input);
@@ -187,7 +275,9 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
187275
itime, iwidth, iheight,
188276
otime, owidth, oheight,
189277
kT, kW, kH,
190-
dT, dW, dH
278+
dT, dW, dH,
279+
padT, padW, padH,
280+
count_include_pad
191281
);
192282
}
193283
else /* batch mode */
@@ -212,7 +302,9 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
212302
itime, iwidth, iheight,
213303
otime, owidth, oheight,
214304
kT, kW, kH,
215-
dT, dW, dH
305+
dT, dW, dH,
306+
padT, padW, padH,
307+
count_include_pad
216308
);
217309
}
218310
}
@@ -236,36 +328,62 @@ static void THNN_(VolumetricAveragePooling_updateGradInput_frame)(
236328
int kH,
237329
int dT,
238330
int dW,
239-
int dH)
331+
int dH,
332+
int padT,
333+
int padW,
334+
int padH,
335+
bool count_include_pad)
240336
{
241337
long k;
242338
#pragma omp parallel for private(k)
243339
for (k = 0; k < nslices; k++)
244340
{
245-
/* loop over output */
246341
long i, j, ti;
342+
343+
/* local pointers */
344+
real *ip = gradInput_p + k * itime * iwidth * iheight;
345+
real *op = gradOutput_p + k * otime * owidth * oheight;
346+
for (i = 0; i < itime*iwidth*iheight; i++)
347+
*(ip + i) = 0;
348+
349+
/* loop over output */
247350
for (ti = 0; ti < otime; ti++)
248351
{
249352
for (i = 0; i < oheight; i++)
250353
{
251354
for (j = 0; j < owidth; j++)
252355
{
253-
/* local pointers */
254-
real *ip = gradInput_p + k * itime * iwidth * iheight
255-
+ ti * iwidth * iheight * dT + i * iwidth * dH + j * dW;
256-
real *op = gradOutput_p + k * otime * owidth * oheight
257-
+ ti * owidth * oheight + i * owidth + j;
356+
long tstart = ti * dT - padT;
357+
long hstart = i * dH - padH;
358+
long wstart = j * dW - padW;
359+
long tend = fminf(tstart + kT, itime + padT);
360+
long hend = fminf(hstart + kH, iheight + padH);
361+
long wend = fminf(wstart + kW, iwidth + padW);
362+
long pool_size = (tend -tstart) * (hend - hstart) * (wend - wstart);
363+
tstart = fmaxf(tstart, 0);
364+
hstart = fmaxf(hstart, 0);
365+
wstart = fmaxf(wstart, 0);
366+
tend = fminf(tend, itime);
367+
hend = fminf(hend, iheight);
368+
wend = fminf(wend, iwidth);
369+
370+
long divide_factor;
371+
if (count_include_pad)
372+
divide_factor = pool_size;
373+
else
374+
divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
258375

259376
/* scatter gradients out to footprint: */
260-
real val = *op / (kT * kW * kH);
261-
int x,y,z;
262-
for (z=0; z < kT; z++)
377+
real val = *op++;
378+
379+
long x,y,z;
380+
for (z = tstart; z < tend; z++)
263381
{
264-
for (y = 0; y < kH; y++)
382+
for (y = hstart; y < hend; y++)
265383
{
266-
for (x = 0; x < kW; x++)
384+
for (x = wstart; x < wend; x++)
267385
{
268-
*(ip + z * iwidth * iheight + y * iwidth + x) += val;
386+
*(ip + z * iheight * iwidth + y * iwidth + x) += val / divide_factor;
269387
}
270388
}
271389
}
@@ -285,15 +403,20 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
285403
int kH,
286404
int dT,
287405
int dW,
288-
int dH)
406+
int dH,
407+
int padT,
408+
int padW,
409+
int padH,
410+
bool ceil_mode,
411+
bool count_include_pad)
289412
{
290-
int nslices;
291-
int itime;
292-
int iheight;
293-
int iwidth;
294-
int otime;
295-
int oheight;
296-
int owidth;
413+
long nslices;
414+
long itime;
415+
long iheight;
416+
long iwidth;
417+
long otime;
418+
long oheight;
419+
long owidth;
297420
real *gradInput_data;
298421
real *gradOutput_data;
299422

@@ -304,7 +427,7 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
304427

305428
THNN_(VolumetricAveragePooling_shapeCheck)(
306429
state, input, gradOutput, kT, kW, kH,
307-
dT, dW, dH);
430+
dT, dW, dH, padT, padW, padH, ceil_mode);
308431

309432
/* get contiguous gradOutput */
310433
gradOutput = THTensor_(newContiguous)(gradOutput);
@@ -342,7 +465,9 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
342465
itime, iwidth, iheight,
343466
otime, owidth, oheight,
344467
kT, kW, kH,
345-
dT, dW, dH
468+
dT, dW, dH,
469+
padT, padW, padH,
470+
count_include_pad
346471
);
347472
}
348473
else /* batch mode */
@@ -361,7 +486,9 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
361486
itime, iwidth, iheight,
362487
otime, owidth, oheight,
363488
kT, kW, kH,
364-
dT, dW, dH
489+
dT, dW, dH,
490+
padT, padW, padH,
491+
count_include_pad
365492
);
366493
}
367494
}

0 commit comments

Comments
 (0)