Skip to content

Commit 4fa67d2

Browse files
author
Zhijian Liu
authored
Fix type annotation of GlobalAvgPool and GlobalMaxPool (#154)
1 parent 67c90d1 commit 4fa67d2

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchsparse/nn/modules/pooling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
from torch import nn
23

34
from torchsparse import SparseTensor
@@ -8,11 +9,11 @@
89

910
class GlobalAvgPool(nn.Module):
1011

11-
def forward(self, input: SparseTensor) -> SparseTensor:
12+
def forward(self, input: SparseTensor) -> torch.Tensor:
1213
return F.global_avg_pool(input)
1314

1415

1516
class GlobalMaxPool(nn.Module):
1617

17-
def forward(self, input: SparseTensor) -> SparseTensor:
18+
def forward(self, input: SparseTensor) -> torch.Tensor:
1819
return F.global_max_pool(input)

0 commit comments

Comments
 (0)