Skip to content

Commit

Permalink
feat: Removes model hooks when the context manager exits (#198)
Browse files Browse the repository at this point in the history
* feat: Removed hooks when extractor gets deleted

* test: Adds unittest for destructor

* refactor: Implemented a context manager instead

* test: Updates unittests

* refactor: Refactors scripts

* docs: Updates usage example
  • Loading branch information
frgfm authored Dec 21, 2022
1 parent 95539fd commit e7d4644
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 123 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ from torchvision.models import resnet18
from torchcam.methods import SmoothGradCAMpp

model = resnet18(pretrained=True).eval()
cam_extractor = SmoothGradCAMpp(model)
# Get your input
img = read_image("path/to/your/image.png")
# Preprocess it for your chosen model
input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Preprocess your data and feed it to the model
out = model(input_tensor.unsqueeze(0))
# Retrieve the CAM by passing the class index and the model output
activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)
with SmoothGradCAMpp(model) as cam_extractor:
# Preprocess your data and feed it to the model
out = model(input_tensor.unsqueeze(0))
# Retrieve the CAM by passing the class index and the model output
activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)
```

If you want to visualize your heatmap, you only need to cast the CAM to a numpy ndarray:
Expand Down
25 changes: 12 additions & 13 deletions scripts/eval_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,21 @@ def main(args):
with torch.no_grad():
_ = model(img_tensor)

# Hook the corresponding layer in the model
cam_extractor = methods.__dict__[args.method](model)
timings = []

# Evaluation runs
for _ in range(args.it):
model.zero_grad()
scores = model(img_tensor)

# Select the class index
class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx

# Use the hooked data to compute activation map
start_ts = time.perf_counter()
_ = cam_extractor(class_idx, scores)
timings.append(time.perf_counter() - start_ts)
with methods.__dict__[args.method](model) as cam_extractor:
for _ in range(args.it):
model.zero_grad()
scores = model(img_tensor)

# Select the class index
class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx

# Use the hooked data to compute activation map
start_ts = time.perf_counter()
_ = cam_extractor(class_idx, scores)
timings.append(time.perf_counter() - start_ts)

_timings = np.array(timings)
print(f"{args.method} w/ {args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)")
Expand Down
16 changes: 8 additions & 8 deletions scripts/eval_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def main(args):
)

# Hook the corresponding layer in the model
cam_extractor = methods.__dict__[args.method](model, args.target.split(","))
metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1))

# Evaluation runs
for x, _ in loader:
model.zero_grad()
x = x.to(device=device)
metric.update(x)
with methods.__dict__[args.method](model, args.target.split(",")) as cam_extractor:
metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1))

# Evaluation runs
for x, _ in loader:
model.zero_grad()
x = x.to(device=device)
metric.update(x)

print(f"{args.method} w/ {args.arch} (validation set of Imagenette on ({args.size}, {args.size}) inputs)")
metrics_dict = metric.summary()
Expand Down
39 changes: 19 additions & 20 deletions tests/test_methods_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ def test_base_cam_constructor(mock_img_model):
model = mobilenet_v2(pretrained=False).eval()
# Check that multiple target layers is disabled for base CAM
with pytest.raises(ValueError):
_ = activation.CAM(model, ["classifier.1", "classifier.2"])
activation.CAM(model, ["classifier.1", "classifier.2"])

# FC layer checks
with pytest.raises(TypeError):
_ = activation.CAM(model, fc_layer=3)
activation.CAM(model, fc_layer=3)


def _verify_cam(activation_map, output_size):
Expand Down Expand Up @@ -48,22 +48,21 @@ def test_img_cams(cam_name, target_layer, fc_layer, num_samples, output_size, ba

target_layer = target_layer(model) if callable(target_layer) else target_layer
# Hook the corresponding layer in the model
extractor = activation.__dict__[cam_name](model, target_layer, **kwargs)

with torch.no_grad():
scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1)))
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], (batch_size, *output_size))
# Multiple class indices
_verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size))
with activation.__dict__[cam_name](model, target_layer, **kwargs) as extractor:
with torch.no_grad():
scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1)))
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], (batch_size, *output_size))
# Multiple class indices
_verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size))


def test_cam_conv1x1(mock_fullyconv_model):
extractor = activation.CAM(mock_fullyconv_model, fc_layer="1")
with torch.no_grad():
scores = mock_fullyconv_model(torch.rand((1, 3, 32, 32)))
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 32, 32))
with activation.CAM(mock_fullyconv_model, fc_layer="1") as extractor:
with torch.no_grad():
scores = mock_fullyconv_model(torch.rand((1, 3, 32, 32)))
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 32, 32))


@pytest.mark.parametrize(
Expand All @@ -83,9 +82,9 @@ def test_video_cams(cam_name, target_layer, num_samples, output_size, mock_video
kwargs["num_samples"] = num_samples

# Hook the corresponding layer in the model
extractor = activation.__dict__[cam_name](model, target_layer, **kwargs)
with activation.__dict__[cam_name](model, target_layer, **kwargs) as extractor:

with torch.no_grad():
scores = model(mock_video_tensor)
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size)
with torch.no_grad():
scores = model(mock_video_tensor)
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size)
89 changes: 48 additions & 41 deletions tests/test_methods_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,51 @@ def test_cam_constructor(mock_img_model):
model = mock_img_model.eval()
# Check that wrong target_layer raises an error
with pytest.raises(ValueError):
_ = core._CAM(model, "3")
core._CAM(model, "3")

# Wrong types
with pytest.raises(TypeError):
_ = core._CAM(model, 3)
core._CAM(model, 3)
with pytest.raises(TypeError):
_ = core._CAM(model, [3])
core._CAM(model, [3])

# Unrelated module
with pytest.raises(ValueError):
_ = core._CAM(model, torch.nn.ReLU())
core._CAM(model, torch.nn.ReLU())


def test_cam_precheck(mock_img_model, mock_img_tensor):
def test_cam_context_manager(mock_img_model):
model = mock_img_model.eval()
extractor = core._CAM(model, "0.3")
with torch.no_grad():
# Check missing forward raises Error
with pytest.raises(AssertionError):
extractor(0)
with core._CAM(model):
# Model is hooked
assert sum(len(mod._forward_hooks) for mod in model.modules()) == 1
# Exit should remove hooks
assert all(len(mod._forward_hooks) == 0 for mod in model.modules())

# Correct forward
_ = model(mock_img_tensor)

# Check incorrect class index
with pytest.raises(ValueError):
extractor(-1)
def test_cam_precheck(mock_img_model, mock_img_tensor):
model = mock_img_model.eval()
with core._CAM(model, "0.3") as extractor:
with torch.no_grad():
# Check missing forward raises Error
with pytest.raises(AssertionError):
extractor(0)

# Check incorrect class index
with pytest.raises(ValueError):
extractor([-1])
# Correct forward
model(mock_img_tensor)

# Check missing score
if extractor._score_used:
# Check incorrect class index
with pytest.raises(ValueError):
extractor(0)
extractor(-1)

# Check incorrect class index
with pytest.raises(ValueError):
extractor([-1])

# Check missing score
if extractor._score_used:
with pytest.raises(ValueError):
extractor(0)


@pytest.mark.parametrize(
Expand All @@ -68,30 +77,28 @@ def test_cam_normalize(input_shape, spatial_dims):

def test_cam_remove_hooks(mock_img_model):
model = mock_img_model.eval()
extractor = core._CAM(model, "0.3")

assert len(extractor.hook_handles) == 1
# Check that there is only one hook on the model
assert all(act is None for act in extractor.hook_a)
with torch.no_grad():
_ = model(torch.rand((1, 3, 32, 32)))
assert all(isinstance(act, torch.Tensor) for act in extractor.hook_a)

# Remove it
extractor.remove_hooks()
assert len(extractor.hook_handles) == 0
# Reset the hooked values
extractor.reset_hooks()
with torch.no_grad():
_ = model(torch.rand((1, 3, 32, 32)))
assert all(act is None for act in extractor.hook_a)
with core._CAM(model, "0.3") as extractor:
assert len(extractor.hook_handles) == 1
# Check that there is only one hook on the model
assert all(act is None for act in extractor.hook_a)
with torch.no_grad():
_ = model(torch.rand((1, 3, 32, 32)))
assert all(isinstance(act, torch.Tensor) for act in extractor.hook_a)

# Remove it
extractor.remove_hooks()
assert len(extractor.hook_handles) == 0
# Reset the hooked values
extractor.reset_hooks()
with torch.no_grad():
_ = model(torch.rand((1, 3, 32, 32)))
assert all(act is None for act in extractor.hook_a)


def test_cam_repr(mock_img_model):
model = mock_img_model.eval()
extractor = core._CAM(model, "0.3")

assert repr(extractor) == "_CAM(target_layer=['0.3'])"
with core._CAM(model, "0.3") as extractor:
assert repr(extractor) == "_CAM(target_layer=['0.3'])"


def test_fuse_cams():
Expand Down
58 changes: 28 additions & 30 deletions tests/test_methods_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,30 @@ def test_img_cams(cam_name, target_layer, output_size, batch_size, mock_img_tens

target_layer = target_layer(model) if callable(target_layer) else target_layer
# Hook the corresponding layer in the model
extractor = gradient.__dict__[cam_name](model, target_layer)

scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1)))
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores, retain_graph=True)[0], (batch_size, *output_size))
# Multiple class indices
_verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size))

# Inplace model
model = nn.Sequential(
nn.Conv2d(3, 8, 3, padding=1),
nn.ReLU(),
nn.Conv2d(8, 8, 3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(1),
nn.Linear(8, 10),
)
with gradient.__dict__[cam_name](model, target_layer) as extractor:

scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1)))
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores, retain_graph=True)[0], (batch_size, *output_size))
# Multiple class indices
_verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size))

# Inplace model
model = nn.Sequential(
nn.Conv2d(3, 8, 3, padding=1),
nn.ReLU(),
nn.Conv2d(8, 8, 3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(1),
nn.Linear(8, 10),
)

# Hook before the inplace ops
extractor = gradient.__dict__[cam_name](model, "2")
scores = model(mock_img_tensor)
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 224, 224))
with gradient.__dict__[cam_name](model, "2") as extractor:
scores = model(mock_img_tensor)
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 224, 224))


@pytest.mark.parametrize(
Expand All @@ -69,20 +69,18 @@ def test_img_cams(cam_name, target_layer, output_size, batch_size, mock_img_tens
def test_video_cams(cam_name, target_layer, output_size, mock_video_model, mock_video_tensor):
model = mock_video_model.eval()
# Hook the corresponding layer in the model
extractor = gradient.__dict__[cam_name](model, target_layer)

scores = model(mock_video_tensor)
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size)
with gradient.__dict__[cam_name](model, target_layer) as extractor:
scores = model(mock_video_tensor)
# Use the hooked data to compute activation map
_verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size)


def test_smoothgradcampp_repr():
model = mobilenet_v2(pretrained=False).eval()

# Hook the corresponding layer in the model
extractor = gradient.SmoothGradCAMpp(model, "features.18.0")

assert repr(extractor) == "SmoothGradCAMpp(target_layer=['features.18.0'], num_samples=4, std=0.3)"
with gradient.SmoothGradCAMpp(model, "features.18.0") as extractor:
assert repr(extractor) == "SmoothGradCAMpp(target_layer=['features.18.0'], num_samples=4, std=0.3)"


def test_layercam_fuse_cams(mock_img_model):
Expand Down
12 changes: 6 additions & 6 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
def test_classification_metric():

model = mobilenet_v3_small(pretrained=False)
extractor = LayerCAM(model, "features.12")
metric = metrics.ClassificationMetric(extractor, partial(torch.softmax, dim=-1))
with LayerCAM(model, "features.12") as extractor:
metric = metrics.ClassificationMetric(extractor, partial(torch.softmax, dim=-1))

# Fixed class
metric.update(torch.rand((2, 3, 224, 224), dtype=torch.float32), class_idx=0)
# Top predicted class
metric.update(torch.rand((2, 3, 224, 224), dtype=torch.float32))
# Fixed class
metric.update(torch.rand((2, 3, 224, 224), dtype=torch.float32), class_idx=0)
# Top predicted class
metric.update(torch.rand((2, 3, 224, 224), dtype=torch.float32))
out = metric.summary()

assert len(out) == 2
Expand Down
7 changes: 7 additions & 0 deletions torchcam/methods/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def __init__(
# Model output is used by the extractor
self._score_used = False

def __enter__(self):
return self

def __exit__(self, exct_type, exce_value, traceback):
self.remove_hooks()
self.reset_hooks()

def _resolve_layer_name(self, target_layer: nn.Module) -> str:
"""Resolves the name of a given layer inside the hooked model."""
_found = False
Expand Down

0 comments on commit e7d4644

Please sign in to comment.