I think there are some mistakes #18

Open
William20234 opened this issue Dec 18, 2023 · 1 comment

Comments

@William20234

After applying the Channel Attention Module, it may be better to apply a convolution layer that maps the channels back to their original number (usually 3 channels), instead of applying the Spatial Attention Module immediately. Otherwise, the Spatial Attention Module doesn't make sense.
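For reference, here is a minimal sketch of the two CBAM modules under discussion, assuming the formulation from the CBAM paper (Woo et al., 2018): channel attention passes average- and max-pooled descriptors through a shared MLP, and spatial attention convolves the channel-wise mean and max maps with a 7x7 kernel. The repository's own `ChannelAttention` and `SpatialAttention` may differ in details such as the reduction ratio (16 here is an assumption). In this formulation the spatial module pools over the channel axis before its convolution, so it accepts inputs with any number of channels.

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """CBAM-style channel attention (sketch; reduction ratio 16 is an assumption)."""

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP implemented with 1x1 convolutions.
        self.mlp = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_planes // ratio, in_planes, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Returns an (N, C, 1, 1) weight map that broadcasts over H and W.
        return self.sigmoid(self.mlp(self.avg_pool(x)) + self.mlp(self.max_pool(x)))


class SpatialAttention(nn.Module):
    """CBAM-style spatial attention (sketch)."""

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool across the channel axis, so the input channel count is arbitrary.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        # Returns an (N, 1, H, W) weight map that broadcasts over channels.
        return self.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))


if __name__ == "__main__":
    x = torch.randn(2, 256, 14, 14)
    ca, sa = ChannelAttention(256), SpatialAttention()
    out = ca(x) * x      # channel attention
    out = sa(out) * out  # spatial attention applied directly afterwards
    print(out.shape)     # torch.Size([2, 256, 14, 14])
```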

@William20234

My suggestion:
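import torch.nn as nn

# ChannelAttention and SpatialAttention are assumed to be the CBAM modules
# defined elsewhere in this repository.


class Bottleneck(nn.Module):

    def __init__(self, inplanes, planes, stride=1, downsample=None, expansion=4):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * expansion)
        self.relu = nn.ReLU(inplace=True)

        # Proposed addition: a 1x1 convolution that maps the expanded channels
        # back to the block's input channel count before spatial attention.
        self.conv4 = nn.Conv2d(planes * expansion, inplanes, kernel_size=1, bias=False)

        self.ca = ChannelAttention(planes * expansion)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Channel attention on the expanded features, then project back to
        # `inplanes` channels, then spatial attention.
        out = self.ca(out) * out
        out = self.conv4(out)
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        # NOTE: `out` now has `inplanes` channels, so the addition only works
        # when `residual` has the same shape. A standard ResNet `downsample`
        # projects to `planes * expansion` channels, so the shapes differ
        # whenever inplanes != planes * expansion.
        out += residual

        out = self.relu(out)

        return out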
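A quick shape check of the suggested block can be run without the rest of the repository. This self-contained sketch uses compact CBAM-style modules as stand-ins for the repository's `ChannelAttention` and `SpatialAttention` (an assumption), and `SuggestedBottleneck` is a hypothetical name for the block proposed above, with the same layers grouped into an `nn.Sequential`. It checks two cases: an identity shortcut where `inplanes == planes * expansion`, and a projection shortcut like the one ResNet uses in the first block of a stage, where `conv4` leaves `out` with `inplanes` channels while `residual` has `planes * expansion` channels.

```python
import torch
import torch.nn as nn


# Compact CBAM-style attention modules (assumed formulation, not the repo's code).
class ChannelAttention(nn.Module):
    def __init__(self, channels, ratio=16):
        super().__init__()
        # Shared MLP applied to both pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, channels // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // ratio, channels, 1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(torch.mean(x, dim=(2, 3), keepdim=True))
        mx = self.mlp(torch.amax(x, dim=(2, 3), keepdim=True))
        return torch.sigmoid(avg + mx)  # (N, C, 1, 1)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        # Channel-wise mean and max maps, stacked as a 2-channel input.
        pooled = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
        return torch.sigmoid(self.conv(pooled))  # (N, 1, H, W)


class SuggestedBottleneck(nn.Module):
    """Hypothetical name for the proposed block: CA -> 1x1 conv to `inplanes` -> SA."""

    def __init__(self, inplanes, planes, stride=1, downsample=None, expansion=4):
        super().__init__()
        width = planes * expansion
        self.body = nn.Sequential(
            nn.Conv2d(inplanes, planes, 1, bias=False), nn.BatchNorm2d(planes), nn.ReLU(),
            nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes), nn.ReLU(),
            nn.Conv2d(planes, width, 1, bias=False), nn.BatchNorm2d(width),
        )
        self.ca = ChannelAttention(width)
        self.conv4 = nn.Conv2d(width, inplanes, 1, bias=False)
        self.sa = SpatialAttention()
        self.downsample = downsample
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        out = self.body(x)
        out = self.ca(out) * out
        out = self.conv4(out)  # now `inplanes` channels
        out = self.sa(out) * out
        return self.relu(out + residual)


if __name__ == "__main__":
    # Case 1: identity shortcut, inplanes == planes * expansion.
    block = SuggestedBottleneck(256, 64)
    print(block(torch.randn(1, 256, 14, 14)).shape)  # torch.Size([1, 256, 14, 14])

    # Case 2: projection shortcut mapping 64 -> 256 channels.
    down = nn.Sequential(nn.Conv2d(64, 256, 1, bias=False), nn.BatchNorm2d(256))
    block = SuggestedBottleneck(64, 64, downsample=down)
    try:
        block(torch.randn(1, 64, 14, 14))
    except RuntimeError as err:
        print("shape mismatch:", err)
```

In the first case the output shape matches the input; in the second, the residual addition raises a `RuntimeError` because `out` has 64 channels while `residual` has 256.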
