diff --git a/README.md b/README.md index 63e8f2e..8df3d46 100644 --- a/README.md +++ b/README.md @@ -77,16 +77,16 @@ from torchvision.models import resnet18 from torchcam.methods import SmoothGradCAMpp model = resnet18(pretrained=True).eval() -cam_extractor = SmoothGradCAMpp(model) # Get your input img = read_image("path/to/your/image.png") # Preprocess it for your chosen model input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) -# Preprocess your data and feed it to the model -out = model(input_tensor.unsqueeze(0)) -# Retrieve the CAM by passing the class index and the model output -activation_map = cam_extractor(out.squeeze(0).argmax().item(), out) +with SmoothGradCAMpp(model) as cam_extractor: + # Preprocess your data and feed it to the model + out = model(input_tensor.unsqueeze(0)) + # Retrieve the CAM by passing the class index and the model output + activation_map = cam_extractor(out.squeeze(0).argmax().item(), out) ``` If you want to visualize your heatmap, you only need to cast the CAM to a numpy ndarray: diff --git a/scripts/eval_latency.py b/scripts/eval_latency.py index 0cbf3c6..60d97a3 100644 --- a/scripts/eval_latency.py +++ b/scripts/eval_latency.py @@ -35,22 +35,21 @@ def main(args): with torch.no_grad(): _ = model(img_tensor) - # Hook the corresponding layer in the model - cam_extractor = methods.__dict__[args.method](model) timings = [] # Evaluation runs - for _ in range(args.it): - model.zero_grad() - scores = model(img_tensor) - - # Select the class index - class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx - - # Use the hooked data to compute activation map - start_ts = time.perf_counter() - _ = cam_extractor(class_idx, scores) - timings.append(time.perf_counter() - start_ts) + with methods.__dict__[args.method](model) as cam_extractor: + for _ in range(args.it): + model.zero_grad() + scores = model(img_tensor) + + # Select the class index + class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx + + # Use the hooked data to compute activation map + start_ts = time.perf_counter() + _ = cam_extractor(class_idx, scores) + timings.append(time.perf_counter() - start_ts) _timings = np.array(timings) print(f"{args.method} w/ {args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") diff --git a/scripts/eval_perf.py b/scripts/eval_perf.py index b908a4a..c2e393c 100644 --- a/scripts/eval_perf.py +++ b/scripts/eval_perf.py @@ -60,14 +60,14 @@ def main(args): ) # Hook the corresponding layer in the model - cam_extractor = methods.__dict__[args.method](model, args.target.split(",")) - metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1)) - - # Evaluation runs - for x, _ in loader: - model.zero_grad() - x = x.to(device=device) - metric.update(x) + with methods.__dict__[args.method](model, args.target.split(",")) as cam_extractor: + metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1)) + + # Evaluation runs + for x, _ in loader: + model.zero_grad() + x = x.to(device=device) + metric.update(x) print(f"{args.method} w/ {args.arch} (validation set of Imagenette on ({args.size}, {args.size}) inputs)") metrics_dict = metric.summary() diff --git a/tests/test_methods_activation.py b/tests/test_methods_activation.py index 2a41d08..1d8a68e 100644 --- a/tests/test_methods_activation.py +++ b/tests/test_methods_activation.py @@ -9,11 +9,11 @@ def test_base_cam_constructor(mock_img_model): model = mobilenet_v2(pretrained=False).eval() # Check that multiple target layers is disabled for base CAM with pytest.raises(ValueError): - _ = activation.CAM(model, ["classifier.1", "classifier.2"]) + activation.CAM(model, ["classifier.1", "classifier.2"]) # FC layer checks with pytest.raises(TypeError): - _ = activation.CAM(model, fc_layer=3) + activation.CAM(model, fc_layer=3) def _verify_cam(activation_map, output_size): @@ -48,22 +48,21 @@ def test_img_cams(cam_name, target_layer, fc_layer, num_samples, output_size, ba target_layer = target_layer(model) if callable(target_layer) else target_layer # Hook the corresponding layer in the model - extractor = activation.__dict__[cam_name](model, target_layer, **kwargs) - - with torch.no_grad(): - scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1))) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (batch_size, *output_size)) - # Multiple class indices - _verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size)) + with activation.__dict__[cam_name](model, target_layer, **kwargs) as extractor: + with torch.no_grad(): + scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1))) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (batch_size, *output_size)) + # Multiple class indices + _verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size)) def test_cam_conv1x1(mock_fullyconv_model): - extractor = activation.CAM(mock_fullyconv_model, fc_layer="1") - with torch.no_grad(): - scores = mock_fullyconv_model(torch.rand((1, 3, 32, 32))) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 32, 32)) + with activation.CAM(mock_fullyconv_model, fc_layer="1") as extractor: + with torch.no_grad(): + scores = mock_fullyconv_model(torch.rand((1, 3, 32, 32))) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 32, 32)) @pytest.mark.parametrize( @@ -83,9 +82,9 @@ def test_video_cams(cam_name, target_layer, num_samples, output_size, mock_video kwargs["num_samples"] = num_samples # Hook the corresponding layer in the model - extractor = activation.__dict__[cam_name](model, target_layer, **kwargs) + with activation.__dict__[cam_name](model, target_layer, **kwargs) as extractor: - with torch.no_grad(): - scores = model(mock_video_tensor) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size) + with torch.no_grad(): + scores = model(mock_video_tensor) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size) diff --git a/tests/test_methods_core.py b/tests/test_methods_core.py index e1024f9..1674ea1 100644 --- a/tests/test_methods_core.py +++ b/tests/test_methods_core.py @@ -8,42 +8,51 @@ def test_cam_constructor(mock_img_model): model = mock_img_model.eval() # Check that wrong target_layer raises an error with pytest.raises(ValueError): - _ = core._CAM(model, "3") + core._CAM(model, "3") # Wrong types with pytest.raises(TypeError): - _ = core._CAM(model, 3) + core._CAM(model, 3) with pytest.raises(TypeError): - _ = core._CAM(model, [3]) + core._CAM(model, [3]) # Unrelated module with pytest.raises(ValueError): - _ = core._CAM(model, torch.nn.ReLU()) + core._CAM(model, torch.nn.ReLU()) -def test_cam_precheck(mock_img_model, mock_img_tensor): +def test_cam_context_manager(mock_img_model): model = mock_img_model.eval() - extractor = core._CAM(model, "0.3") - with torch.no_grad(): - # Check missing forward raises Error - with pytest.raises(AssertionError): - extractor(0) + with core._CAM(model): + # Model is hooked + assert sum(len(mod._forward_hooks) for mod in model.modules()) == 1 + # Exit should remove hooks + assert all(len(mod._forward_hooks) == 0 for mod in model.modules()) - # Correct forward - _ = model(mock_img_tensor) - # Check incorrect class index - with pytest.raises(ValueError): - extractor(-1) +def test_cam_precheck(mock_img_model, mock_img_tensor): + model = mock_img_model.eval() + with core._CAM(model, "0.3") as extractor: + with torch.no_grad(): + # Check missing forward raises Error + with pytest.raises(AssertionError): + extractor(0) - # Check incorrect class index - with pytest.raises(ValueError): - extractor([-1]) + # Correct forward + model(mock_img_tensor) - # Check missing score - if extractor._score_used: + # Check incorrect class index with pytest.raises(ValueError): - extractor(0) + extractor(-1) + + # Check incorrect class index + with pytest.raises(ValueError): + extractor([-1]) + + # Check missing score + if extractor._score_used: + with pytest.raises(ValueError): + extractor(0) @pytest.mark.parametrize( @@ -68,30 +77,28 @@ def test_cam_normalize(input_shape, spatial_dims): def test_cam_remove_hooks(mock_img_model): model = mock_img_model.eval() - extractor = core._CAM(model, "0.3") - - assert len(extractor.hook_handles) == 1 - # Check that there is only one hook on the model - assert all(act is None for act in extractor.hook_a) - with torch.no_grad(): - _ = model(torch.rand((1, 3, 32, 32))) - assert all(isinstance(act, torch.Tensor) for act in extractor.hook_a) - - # Remove it - extractor.remove_hooks() - assert len(extractor.hook_handles) == 0 - # Reset the hooked values - extractor.reset_hooks() - with torch.no_grad(): - _ = model(torch.rand((1, 3, 32, 32))) - assert all(act is None for act in extractor.hook_a) + with core._CAM(model, "0.3") as extractor: + assert len(extractor.hook_handles) == 1 + # Check that there is only one hook on the model + assert all(act is None for act in extractor.hook_a) + with torch.no_grad(): + _ = model(torch.rand((1, 3, 32, 32))) + assert all(isinstance(act, torch.Tensor) for act in extractor.hook_a) + + # Remove it + extractor.remove_hooks() + assert len(extractor.hook_handles) == 0 + # Reset the hooked values + extractor.reset_hooks() + with torch.no_grad(): + _ = model(torch.rand((1, 3, 32, 32))) + assert all(act is None for act in extractor.hook_a) def test_cam_repr(mock_img_model): model = mock_img_model.eval() - extractor = core._CAM(model, "0.3") - - assert repr(extractor) == "_CAM(target_layer=['0.3'])" + with core._CAM(model, "0.3") as extractor: + assert repr(extractor) == "_CAM(target_layer=['0.3'])" def test_fuse_cams(): diff --git a/tests/test_methods_gradient.py b/tests/test_methods_gradient.py index 6022c1a..ce24b09 100644 --- a/tests/test_methods_gradient.py +++ b/tests/test_methods_gradient.py @@ -30,30 +30,30 @@ def test_img_cams(cam_name, target_layer, output_size, batch_size, mock_img_tens target_layer = target_layer(model) if callable(target_layer) else target_layer # Hook the corresponding layer in the model - extractor = gradient.__dict__[cam_name](model, target_layer) - - scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1))) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores, retain_graph=True)[0], (batch_size, *output_size)) - # Multiple class indices - _verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size)) - - # Inplace model - model = nn.Sequential( - nn.Conv2d(3, 8, 3, padding=1), - nn.ReLU(), - nn.Conv2d(8, 8, 3, padding=1), - nn.ReLU(inplace=True), - nn.AdaptiveAvgPool2d((1, 1)), - nn.Flatten(1), - nn.Linear(8, 10), - ) + with gradient.__dict__[cam_name](model, target_layer) as extractor: + + scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1))) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores, retain_graph=True)[0], (batch_size, *output_size)) + # Multiple class indices + _verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size)) + + # Inplace model + model = nn.Sequential( + nn.Conv2d(3, 8, 3, padding=1), + nn.ReLU(), + nn.Conv2d(8, 8, 3, padding=1), + nn.ReLU(inplace=True), + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(1), + nn.Linear(8, 10), + ) # Hook before the inplace ops - extractor = gradient.__dict__[cam_name](model, "2") - scores = model(mock_img_tensor) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 224, 224)) + with gradient.__dict__[cam_name](model, "2") as extractor: + scores = model(mock_img_tensor) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 224, 224)) @pytest.mark.parametrize( @@ -69,20 +69,18 @@ def test_img_cams(cam_name, target_layer, output_size, batch_size, mock_img_tens def test_video_cams(cam_name, target_layer, output_size, mock_video_model, mock_video_tensor): model = mock_video_model.eval() # Hook the corresponding layer in the model - extractor = gradient.__dict__[cam_name](model, target_layer) - - scores = model(mock_video_tensor) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size) + with gradient.__dict__[cam_name](model, target_layer) as extractor: + scores = model(mock_video_tensor) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size) def test_smoothgradcampp_repr(): model = mobilenet_v2(pretrained=False).eval() # Hook the corresponding layer in the model - extractor = gradient.SmoothGradCAMpp(model, "features.18.0") - - assert repr(extractor) == "SmoothGradCAMpp(target_layer=['features.18.0'], num_samples=4, std=0.3)" + with gradient.SmoothGradCAMpp(model, "features.18.0") as extractor: + assert repr(extractor) == "SmoothGradCAMpp(target_layer=['features.18.0'], num_samples=4, std=0.3)" def test_layercam_fuse_cams(mock_img_model): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 22b4b8f..2d0ba7c 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -10,13 +10,13 @@ def test_classification_metric(): model = mobilenet_v3_small(pretrained=False) - extractor = LayerCAM(model, "features.12") - metric = metrics.ClassificationMetric(extractor, partial(torch.softmax, dim=-1)) + with LayerCAM(model, "features.12") as extractor: + metric = metrics.ClassificationMetric(extractor, partial(torch.softmax, dim=-1)) - # Fixed class - metric.update(torch.rand((2, 3, 224, 224), dtype=torch.float32), class_idx=0) - # Top predicted class - metric.update(torch.rand((2, 3, 224, 224), dtype=torch.float32)) + # Fixed class + metric.update(torch.rand((2, 3, 224, 224), dtype=torch.float32), class_idx=0) + # Top predicted class + metric.update(torch.rand((2, 3, 224, 224), dtype=torch.float32)) out = metric.summary() assert len(out) == 2 diff --git a/torchcam/methods/core.py b/torchcam/methods/core.py index 857b0e7..020faa7 100644 --- a/torchcam/methods/core.py +++ b/torchcam/methods/core.py @@ -77,6 +77,13 @@ def __init__( # Model output is used by the extractor self._score_used = False + def __enter__(self): + return self + + def __exit__(self, exct_type, exce_value, traceback): + self.remove_hooks() + self.reset_hooks() + def _resolve_layer_name(self, target_layer: nn.Module) -> str: """Resolves the name of a given layer inside the hooked model.""" _found = False