Skip to content

Commit

Permalink
suport dual stride for pooling (#2202)
Browse files Browse the repository at this point in the history
* add dual stride for pooling

* fix lint

* update
  • Loading branch information
tornadomeet authored and antinucleon committed May 22, 2016
1 parent ec5538b commit 6523a7a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion mshadow
15 changes: 8 additions & 7 deletions src/operator/pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,15 @@ class PoolingOp : public Operator {
Tensor<xpu, 4> data = in_data[pool_enum::kData].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> out = out_data[pool_enum::kOut].get<xpu, 4, real_t>(s);
mshadow::Shape<2> out_shape = Shape2(out.shape_[2], out.shape_[3]);
// TODO(bing): dual stride in mshadow
CHECK_EQ(param_.stride[0], param_.stride[1])
<< "Only same stride is supported now";
if (param_.pool_type == pool_enum::kMaxPooling || param_.pool_type == pool_enum::kSumPooling) {
Assign(out,
req[pool_enum::kOut],
pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
out_shape,
param_.kernel[0],
param_.kernel[1],
param_.stride[0]));
param_.stride[0],
param_.stride[1]));
} else if (param_.pool_type == pool_enum::kAvgPooling) {
Assign(out,
req[pool_enum::kOut],
Expand All @@ -95,7 +93,8 @@ class PoolingOp : public Operator {
out_shape,
param_.kernel[0],
param_.kernel[1],
param_.stride[0]));
param_.stride[0],
param_.stride[1]));
}
}

Expand Down Expand Up @@ -129,7 +128,8 @@ class PoolingOp : public Operator {
pad(grad, 0, 0),
param_.kernel[0],
param_.kernel[1],
param_.stride[0]),
param_.stride[0],
param_.stride[1]),
in_shape,
param_.pad[0],
param_.pad[1]));
Expand All @@ -141,7 +141,8 @@ class PoolingOp : public Operator {
pad(grad, 0, 0),
param_.kernel[0],
param_.kernel[1],
param_.stride[0]),
param_.stride[0],
param_.stride[1]),
in_shape,
param_.pad[0],
param_.pad[1]));
Expand Down
3 changes: 3 additions & 0 deletions src/operator/upsampling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ class UpSamplingNearestOp : public Operator {
in_shape,
scale,
scale,
scale,
scale));
} else {
Assign(input_grad, req[i],
pool<mshadow::red::sum>(slice<1>(grad, begin, end),
in_shape,
scale,
scale,
scale,
scale));
}
begin = end;
Expand All @@ -151,6 +153,7 @@ class UpSamplingNearestOp : public Operator {
in_shape,
param_.scale,
param_.scale,
param_.scale,
param_.scale));
}
}
Expand Down

0 comments on commit 6523a7a

Please sign in to comment.