Skip to content

Commit f6bbf35

Browse files
Merge pull request #1090 from NiklasGustafsson/bugs
Fixed array-based overloads of max_pool{123}d.
2 parents 27ba19c + 59d3c4d commit f6bbf35

File tree

5 files changed

+32
-6
lines changed

5 files changed

+32
-6
lines changed

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ __Bug Fixes__:
1515
#1041 Running example code got error in Windows 10<br/>
1616
#1064 Inplace operators create an alias<br/>
1717
#1084 Module.zero_grad() does not work<br/>
18+
#1089 max_pool2d overload creates tensor with incorrect shape<br/>
1819

1920
## NuGet Version 0.100.4
2021

src/TorchSharp/NN/Pooling/MaxPool1D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public static Tensor max_pool1d(Tensor input, long kernelSize, long? stride = nu
8989
long? padding = null, long? dilation = null, bool ceil_mode = false)
9090
{
9191
var kernelSizes = new long[] { kernelSize };
92-
var strides = new long[] { stride ?? 1 };
92+
var strides = new long[] { stride ?? kernelSize };
9393
var paddings = new long[] { padding ?? 0 };
9494
var dilations = new long[] { dilation ?? 1 };
9595
unsafe {
@@ -121,7 +121,7 @@ public static (Tensor output, Tensor indices) max_pool1d_with_indices(Tensor inp
121121
long? padding = null, long? dilation = null, bool ceil_mode = false)
122122
{
123123
var kernelSizes = new long[] { kernelSize };
124-
var strides = new long[] { stride ?? 1 };
124+
var strides = new long[] { stride ?? kernelSize };
125125
var paddings = new long[] { padding ?? 0 };
126126
var dilations = new long[] { dilation ?? 1 };
127127
IntPtr[] ptrArray;

src/TorchSharp/NN/Pooling/MaxPool2D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public static partial class functional
132132
public static Tensor max_pool2d(Tensor input, long[] kernelSize, long[] strides = null,
133133
long[] padding = null, long[] dilation = null, bool ceil_mode = false)
134134
{
135-
strides = strides ?? kernelSize.Select(x => 1L).ToArray();
135+
strides = strides ?? kernelSize;
136136
padding = padding ?? kernelSize.Select(x => 0L).ToArray();
137137
dilation = dilation ?? kernelSize.Select(x => 1L).ToArray();
138138
unsafe {
@@ -232,7 +232,7 @@ public static unsafe Tensor max_pool2d(Tensor input, (long, long) kernelSize, (l
232232
public static (Tensor output, Tensor indices) max_pool2d_with_indices(Tensor input, long[] kernelSize, long[] strides = null,
233233
long[] padding = null, long[] dilation = null, bool ceil_mode = false)
234234
{
235-
strides = strides ?? kernelSize.Select(x => 1L).ToArray();
235+
strides = strides ?? kernelSize;
236236
padding = padding ?? kernelSize.Select(x => 0L).ToArray();
237237
dilation = dilation ?? kernelSize.Select(x => 1L).ToArray();
238238
IntPtr[] ptrArray;

src/TorchSharp/NN/Pooling/MaxPool3D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public static partial class functional
114114
public static Tensor max_pool3d(Tensor input, long[] kernelSize, long[] strides = null,
115115
long[] padding = null, long[] dilation = null, bool ceil_mode = false)
116116
{
117-
strides = strides ?? kernelSize.Select(x => 1L).ToArray();
117+
strides = strides ?? kernelSize;
118118
padding = padding ?? kernelSize.Select(x => 0L).ToArray();
119119
dilation = dilation ?? kernelSize.Select(x => 1L).ToArray();
120120
unsafe {
@@ -145,7 +145,7 @@ public static Tensor max_pool3d(Tensor input, long[] kernelSize, long[] strides
145145
public static (Tensor output, Tensor indices) max_pool3d_with_indices(Tensor input, long[] kernelSize, long[] strides = null,
146146
long[] padding = null, long[] dilation = null, bool ceil_mode = false)
147147
{
148-
strides = strides ?? kernelSize.Select(x => 1L).ToArray();
148+
strides = strides ?? kernelSize;
149149
padding = padding ?? kernelSize.Select(x => 0L).ToArray();
150150
dilation = dilation ?? kernelSize.Select(x => 1L).ToArray();
151151
IntPtr[] ptrArray;

test/TorchSharpTest/TestTorchTensorBugs.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,5 +1155,30 @@ static void Validate1057()
11551155
}
11561156
}
11571157
}
1158+
1159+
[Fact]
1160+
public void Validate1089_2d()
1161+
{
1162+
var t = torch.zeros(1, 6, 28, 28);
1163+
var expectedShape = new long[] { 1, 6, 14, 14 };
1164+
1165+
Assert.Multiple(
1166+
() => Assert.Equal(expectedShape, functional.max_pool2d(t, 2).shape),
1167+
() => Assert.Equal(expectedShape, functional.max_pool2d(t, ( 2, 2 )).shape),
1168+
() => Assert.Equal(expectedShape, functional.max_pool2d(t, new long[] { 2, 2 }).shape)
1169+
);
1170+
1171+
Assert.Equal(expectedShape, functional.max_pool2d_with_indices(t, new long[] { 2, 2 }).output.shape);
1172+
}
1173+
1174+
[Fact]
1175+
public void Validate1089_3d()
1176+
{
1177+
var t = torch.zeros(new long[] { 1, 6, 28, 28, 28 });
1178+
var expectedShape = new long[] { 1, 6, 14, 14, 14 };
1179+
1180+
Assert.Equal(expectedShape, functional.max_pool3d(t, new long[] { 2, 2, 2 }).shape);
1181+
Assert.Equal(expectedShape, functional.max_pool3d_with_indices(t, new long[] { 2, 2, 2 }).output.shape);
1182+
}
11581183
}
11591184
}

0 commit comments

Comments
 (0)