Skip to content

Commit

Permalink
fix: Fixes SmoothGradCAMpp (#204)
Browse files Browse the repository at this point in the history
* fix: Fixed SmoothGradCAMpp

* style: Specified hook typing

* test: Updates unittest
  • Loading branch information
frgfm authored Jan 14, 2023
1 parent 8278ded commit c24a858
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 16 deletions.
18 changes: 14 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,18 @@ def mock_img_tensor():
except ConnectionError:
img_tensor = torch.rand((1, 3, 224, 224))

img_tensor.requires_grad_(True)
return img_tensor


@pytest.fixture(scope="session")
def mock_video_tensor():
return torch.rand((1, 3, 8, 16, 16))
return torch.rand((1, 3, 8, 16, 16), requires_grad=True)


@pytest.fixture(scope="session")
def mock_video_model():
return nn.Sequential(
model = nn.Sequential(
nn.Sequential(
nn.Conv3d(3, 8, 3, padding=1),
nn.ReLU(),
Expand All @@ -44,11 +45,14 @@ def mock_video_model():
nn.Flatten(1),
nn.Linear(16, 1),
)
for p in model:
p.requires_grad_(False)
return model


@pytest.fixture(scope="session")
def mock_img_model():
return nn.Sequential(
model = nn.Sequential(
nn.Sequential(
nn.Conv2d(3, 8, 3, padding=1),
nn.ReLU(),
Expand All @@ -59,11 +63,14 @@ def mock_img_model():
nn.Flatten(1),
nn.Linear(16, 1),
)
for p in model:
p.requires_grad_(False)
return model


@pytest.fixture(scope="session")
def mock_fullyconv_model():
return nn.Sequential(
model = nn.Sequential(
nn.Sequential(
nn.Conv2d(3, 8, 3, padding=1),
nn.ReLU(),
Expand All @@ -74,3 +81,6 @@ def mock_fullyconv_model():
nn.Conv2d(16, 1, 1),
nn.Flatten(1),
)
for p in model:
p.requires_grad_(False)
return model
4 changes: 4 additions & 0 deletions tests/test_methods_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

def test_base_cam_constructor(mock_img_model):
model = mobilenet_v2(pretrained=False).eval()
for p in model.parameters():
p.requires_grad_(False)
# Check that multiple target layers is disabled for base CAM
with pytest.raises(ValueError):
activation.CAM(model, ["classifier.1", "classifier.2"])
Expand Down Expand Up @@ -38,6 +40,8 @@ def _verify_cam(activation_map, output_size):
)
def test_img_cams(cam_name, target_layer, fc_layer, num_samples, output_size, batch_size, mock_img_tensor):
model = mobilenet_v2(pretrained=False).eval()
for p in model.parameters():
p.requires_grad_(False)
kwargs = {}
# Speed up testing by reducing the number of samples
if isinstance(num_samples, int):
Expand Down
24 changes: 14 additions & 10 deletions tests/test_methods_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def _verify_cam(activation_map, output_size):
)
def test_img_cams(cam_name, target_layer, output_size, batch_size, mock_img_tensor):
model = mobilenet_v2(pretrained=False).eval()
for p in model.parameters():
p.requires_grad_(False)

target_layer = target_layer(model) if callable(target_layer) else target_layer
# Hook the corresponding layer in the model
Expand All @@ -38,16 +40,18 @@ def test_img_cams(cam_name, target_layer, output_size, batch_size, mock_img_tens
# Multiple class indices
_verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size))

# Inplace model
model = nn.Sequential(
nn.Conv2d(3, 8, 3, padding=1),
nn.ReLU(),
nn.Conv2d(8, 8, 3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(1),
nn.Linear(8, 10),
)
# Inplace model
model = nn.Sequential(
nn.Conv2d(3, 8, 3, padding=1),
nn.ReLU(),
nn.Conv2d(8, 8, 3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(1),
nn.Linear(8, 10),
)
for p in model.parameters():
p.requires_grad_(False)

# Hook before the inplace ops
with gradient.__dict__[cam_name](model, "2") as extractor:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_methods_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
def test_locate_candidate_layer(mock_img_model):
# ResNet-18
mod = resnet18().eval()
for p in mod.parameters():
p.requires_grad_(False)
assert _utils.locate_candidate_layer(mod) == "layer4"

# Mobilenet V3 Large
mod = mobilenet_v3_large().eval()
for p in mod.parameters():
p.requires_grad_(False)
assert _utils.locate_candidate_layer(mod) == "features"

# Custom model
Expand All @@ -24,6 +28,8 @@ def test_locate_linear_layer(mock_img_model):

# ResNet-18
mod = resnet18().eval()
for p in mod.parameters():
p.requires_grad_(False)
assert _utils.locate_linear_layer(mod) == "fc"

# Custom model
Expand Down
2 changes: 1 addition & 1 deletion torchcam/methods/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _resolve_layer_name(self, target_layer: nn.Module) -> str:

return target_name

def _hook_a(self, module: nn.Module, input: Tensor, output: Tensor, idx: int = 0) -> None:
def _hook_a(self, module: nn.Module, input: Tuple[Tensor, ...], output: Tensor, idx: int = 0) -> None:
"""Activation hook."""
if self._hooks_enabled:
self.hook_a[idx] = output.data
Expand Down
3 changes: 2 additions & 1 deletion torchcam/methods/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _store_grad(self, grad: Tensor, idx: int = 0) -> None:
if self._hooks_enabled:
self.hook_g[idx] = grad.data

def _hook_g(self, module: nn.Module, input: Tensor, output: Tensor, idx: int = 0) -> None:
def _hook_g(self, module: nn.Module, input: Tuple[Tensor, ...], output: Tensor, idx: int = 0) -> None:
"""Gradient hook"""
if self._hooks_enabled:
self.hook_handles.append(output.register_hook(partial(self._store_grad, idx=idx)))
Expand Down Expand Up @@ -275,6 +275,7 @@ def _get_weights(
for _idx in range(self.num_samples):
# Add noise
noisy_input = self._input + self._distrib.sample(self._input.size()).to(device=self._input.device)
noisy_input.requires_grad_(True)
# Forward & Backward
out = self.model(noisy_input)
self.model.zero_grad()
Expand Down

0 comments on commit c24a858

Please sign in to comment.