We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
GlobalAvgPool
GlobalMaxPool
1 parent 67c90d1 commit 4fa67d2Copy full SHA for 4fa67d2
torchsparse/nn/modules/pooling.py
@@ -1,3 +1,4 @@
1
+import torch
2
from torch import nn
3
4
from torchsparse import SparseTensor
@@ -8,11 +9,11 @@
8
9
10
class GlobalAvgPool(nn.Module):
11
- def forward(self, input: SparseTensor) -> SparseTensor:
12
+ def forward(self, input: SparseTensor) -> torch.Tensor:
13
return F.global_avg_pool(input)
14
15
16
class GlobalMaxPool(nn.Module):
17
18
19
return F.global_max_pool(input)
0 commit comments