diff --git a/.github/collect_env.py b/.github/collect_env.py index 2cf8777..9948a28 100644 --- a/.github/collect_env.py +++ b/.github/collect_env.py @@ -217,10 +217,7 @@ def get_os(run_lambda): def get_env_info(): run_lambda = run - if TORCHCAM_AVAILABLE: - torchcam_str = torchcam.__version__ - else: - torchcam_str = "N/A" + torchcam_str = torchcam.__version__ if TORCHCAM_AVAILABLE else "N/A" if TORCH_AVAILABLE: torch_str = torch.__version__ @@ -258,14 +255,14 @@ def get_env_info(): def pretty_str(envinfo): def replace_nones(dct, replacement="Could not collect"): - for key in dct.keys(): + for key in dct: if dct[key] is not None: continue dct[key] = replacement return dct def replace_bools(dct, true="Yes", false="No"): - for key in dct.keys(): + for key in dct: if dct[key] is True: dct[key] = true elif dct[key] is False: diff --git a/.github/verify_labels.py b/.github/verify_labels.py index 71264a7..5db876d 100644 --- a/.github/verify_labels.py +++ b/.github/verify_labels.py @@ -71,7 +71,8 @@ def parse_args(): import argparse parser = argparse.ArgumentParser( - description="PR label checker", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="PR label checker", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("pr", type=int, help="PR number") diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 2df19ca..8945f92 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -56,7 +56,7 @@ jobs: steps: - uses: actions/checkout@v2 - name: Miniconda setup - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true python-version: 3.9 diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 32067bb..30c7a21 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -15,14 +15,13 @@ jobs: python: [3.9] steps: - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} architecture: x64 - name: Run ruff run: | - pip install ruff==0.1.0 + pip install ruff==0.1.9 ruff --version ruff check --diff . @@ -34,8 +33,7 @@ jobs: python: [3.9] steps: - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} architecture: x64 @@ -53,7 +51,7 @@ jobs: mypy --version mypy - black: + ruff-format: runs-on: ${{ matrix.os }} strategy: matrix: @@ -61,32 +59,12 @@ jobs: python: [3.9] steps: - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} architecture: x64 - - name: Run black - run: | - pip install "black==23.3.0" - black --version - black --check --diff . - - bandit: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python: [3.9] - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python }} - architecture: x64 - - name: Run bandit + - name: Run ruff run: | - pip install bandit[toml] - bandit --version - bandit -r . -c pyproject.toml + pip install ruff==0.1.9 + ruff --version + ruff format --check --diff . diff --git a/.gitignore b/.gitignore index e3dd063..7d10148 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,4 @@ torchcam/version.py # Conda distribution conda-dist/ +.vscode/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4b5108..9e5775b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,13 +17,10 @@ repos: args: ['--branch', 'main'] - id: debug-statements language_version: python3 - - repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.0.290' + rev: 'v0.1.9' hooks: - id: ruff args: - --fix + - id: ruff-format diff --git a/Makefile b/Makefile index 3daffc6..5df9a9a 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,12 @@ # this target runs checks on all files quality: + ruff format --check . ruff check . mypy - black --check . - bandit -r . -c pyproject.toml # this target runs checks on all files and potentially modifies some of them style: - black . + ruff format . ruff --fix . # Run tests for the library diff --git a/demo/app.py b/demo/app.py index 7653821..092fcf9 100644 --- a/demo/app.py +++ b/demo/app.py @@ -17,7 +17,17 @@ from torchcam.methods._utils import locate_candidate_layer from torchcam.utils import overlay_mask -CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"] +CAM_METHODS = [ + "CAM", + "GradCAM", + "GradCAMpp", + "SmoothGradCAMpp", + "ScoreCAM", + "SSCAM", + "ISCAM", + "XGradCAM", + "LayerCAM", +] TV_MODELS = [ "resnet18", "resnet50", @@ -87,7 +97,8 @@ def main(): ) if cam_method is not None: cam_extractor = methods.__dict__[cam_method]( - model, target_layer=[s.strip() for s in target_layer.split("+")] if len(target_layer) > 0 else None + model, + target_layer=[s.strip() for s in target_layer.split("+")] if len(target_layer) > 0 else None, ) class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)] @@ -103,7 +114,11 @@ def main(): else: with st.spinner("Analyzing..."): # Preprocess image - img_tensor = normalize(to_tensor(resize(img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + img_tensor = normalize( + to_tensor(resize(img, (224, 224))), + [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225], + ) if torch.cuda.is_available(): img_tensor = img_tensor.cuda() diff --git a/docs/source/conf.py b/docs/source/conf.py index f0cf0f9..d9c8670 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,7 @@ from datetime import datetime from pathlib import Path -sys.path.insert(0, Path().resolve().parent.parent) +sys.path.insert(0, Path().cwd().parent.parent) import torchcam # -- Project information ----------------------------------------------------- @@ -121,9 +121,7 @@ def add_ga_javascript(app, pagename, templatename, context, doctree): gtag('js', new Date()); gtag('config', '{0}'); - """.format( - app.config.googleanalytics_id - ) + """.format(app.config.googleanalytics_id) context["metatags"] = metatags diff --git a/pyproject.toml b/pyproject.toml index 21d2bec..298b318 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,11 +50,9 @@ test = [ "pytest-pretty>=1.0.0,<2.0.0", ] quality = [ - "ruff==0.1.0", - "mypy==1.5.1", - "black==23.3.0", - "bandit[toml]>=1.7.0,<1.8.0", - "pre-commit>=2.17.0,<3.0.0", + "ruff==0.1.9", + "mypy==1.8.0", + "pre-commit>=3.0.0,<4.0.0", ] docs = [ "sphinx>=3.0.0,!=3.5.0", @@ -80,11 +78,9 @@ dev = [ "pytest-xdist>=3.0.0,<4.0.0", "pytest-pretty>=1.0.0,<2.0.0", # style - "ruff==0.1.0", - "mypy==1.5.1", - "black==23.3.0", - "bandit[toml]>=1.7.0,<1.8.0", - "pre-commit>=2.17.0,<3.0.0", + "ruff==0.1.9", + "mypy==1.8.0", + "pre-commit>=3.0.0,<4.0.0", # docs "sphinx>=3.0.0,!=3.5.0", "furo>=2022.3.4", @@ -133,6 +129,16 @@ select = [ "T20", # flake8-print "PT", # flake8-pytest-style "LOG", # flake8-logging + "SIM", # flake8-simplify + "YTT", # flake8-2020 + "ANN", # flake8-annotations + "ASYNC", # flake8-async + "BLE", # flake8-blind-except + "A", # flake8-builtins + "ICN", # flake8-import-conventions + "PIE", # flake8-pie + "ARG", # flake8-unused-arguments + "FURB", # refurb ] ignore = [ "E501", # line too long, handled by black @@ -142,20 +148,31 @@ ignore = [ "F403", # star imports "E731", # lambda assignment "C416", # list comprehension to list() + "ANN101", # missing type annotations on self + "ANN102", # missing type annotations on cls + "ANN002", # missing type annotations on *args + "ANN003", # missing type annotations on **kwargs + "COM812", # trailing comma missing "N812", # lowercase imported as non-lowercase + "ISC001", # implicit string concatenation (handled by format) + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed ] exclude = [".git"] line-length = 120 target-version = "py39" preview = true +[tool.ruff.format] +quote-style = "double" +indent-style = "space" + [tool.ruff.per-file-ignores] "**/__init__.py" = ["I001", "F401", "CPY001"] -"scripts/**.py" = ["D", "T201", "N812"] -".github/**.py" = ["D", "T201", "S602"] -"docs/**.py" = ["E402", "D103"] -"tests/**.py" = ["D103", "CPY001", "S101", "PT011",] -"demo/**.py" = ["D103"] +"scripts/**.py" = ["D", "T201", "N812", "S101", "ANN"] +".github/**.py" = ["D", "T201", "S602", "S101", "ANN"] +"docs/**.py" = ["E402", "D103", "ANN", "A001", "ARG001"] +"tests/**.py" = ["D103", "CPY001", "S101", "PT011", "ANN"] +"demo/**.py" = ["D103", "ANN"] "setup.py" = ["T201"] [tool.ruff.flake8-quotes] @@ -177,6 +194,7 @@ no_implicit_optional = true check_untyped_defs = true implicit_reexport = false disallow_untyped_defs = true +explicit_package_bases = true [[tool.mypy.overrides]] module = [ @@ -184,11 +202,3 @@ module = [ "matplotlib" ] ignore_missing_imports = true - -[tool.black] -line-length = 120 -target-version = ['py39'] - -[tool.bandit] -exclude_dirs = [".github/collect_env.py"] -skips = ["B101"] diff --git a/scripts/cam_example.py b/scripts/cam_example.py index d7440d5..7eb6bcd 100644 --- a/scripts/cam_example.py +++ b/scripts/cam_example.py @@ -35,16 +35,15 @@ def main(args): p.requires_grad_(False) # Image - if args.img.startswith("http"): - img_path = BytesIO(requests.get(args.img, timeout=5).content) - else: - img_path = args.img + img_path = BytesIO(requests.get(args.img, timeout=5).content) if args.img.startswith("http") else args.img pil_img = Image.open(img_path, mode="r").convert("RGB") # Preprocess image - img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to( - device=device - ) + img_tensor = normalize( + to_tensor(resize(pil_img, (224, 224))), + [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225], + ).to(device=device) img_tensor.requires_grad_(True) if isinstance(args.method, str): @@ -119,7 +118,8 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Saliency Map comparison", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="Saliency Map comparison", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--arch", type=str, default="resnet18", help="Name of the architecture") parser.add_argument( @@ -129,13 +129,23 @@ def main(args): help="The image to extract CAM from", ) parser.add_argument("--class-idx", type=int, default=232, help="Index of the class to inspect") - parser.add_argument("--device", type=str, default=None, help="Default device to perform computation on") + parser.add_argument( + "--device", + type=str, + default=None, + help="Default device to perform computation on", + ) parser.add_argument("--savefig", type=str, default=None, help="Path to save figure") parser.add_argument("--method", type=str, default=None, help="CAM method to use") parser.add_argument("--target", type=str, default=None, help="the target layer") parser.add_argument("--alpha", type=float, default=0.5, help="Transparency of the heatmap") parser.add_argument("--rows", type=int, default=1, help="Number of rows for the layout") - parser.add_argument("--noblock", dest="noblock", help="Disables blocking visualization", action="store_true") + parser.add_argument( + "--noblock", + dest="noblock", + help="Disables blocking visualization", + action="store_true", + ) args = parser.parse_args() main(args) diff --git a/scripts/eval_latency.py b/scripts/eval_latency.py index 9c0eae0..9a5f1d4 100644 --- a/scripts/eval_latency.py +++ b/scripts/eval_latency.py @@ -60,13 +60,24 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="CAM method latency benchmark", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="CAM method latency benchmark", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("method", type=str, help="CAM method to use") - parser.add_argument("--arch", type=str, default="resnet18", help="Name of the torchvision architecture") + parser.add_argument( + "--arch", + type=str, + default="resnet18", + help="Name of the torchvision architecture", + ) parser.add_argument("--size", type=int, default=224, help="The image input size") parser.add_argument("--class-idx", type=int, default=232, help="Index of the class to inspect") - parser.add_argument("--device", type=str, default=None, help="Default device to perform computation on") + parser.add_argument( + "--device", + type=str, + default=None, + help="Default device to perform computation on", + ) parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") args = parser.parse_args() diff --git a/scripts/eval_perf.py b/scripts/eval_perf.py index e641ce7..35c5ffb 100644 --- a/scripts/eval_perf.py +++ b/scripts/eval_perf.py @@ -40,14 +40,12 @@ def main(args): scale_size = min(int(math.floor(args.size / crop_pct)), 320) if scale_size < 320: eval_tf.append(T.Resize(scale_size)) - eval_tf.extend( - [ - T.CenterCrop(args.size), - T.PILToTensor(), - T.ConvertImageDtype(torch.float32), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ] - ) + eval_tf.extend([ + T.CenterCrop(args.size), + T.PILToTensor(), + T.ConvertImageDtype(torch.float32), + T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ]) ds = ImageFolder( Path(args.data_path).joinpath("val"), @@ -80,17 +78,32 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="CAM method performance evaluation", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="CAM method performance evaluation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("data_path", type=str, help="path to dataset folder") parser.add_argument("method", type=str, help="CAM method to use") - parser.add_argument("--arch", type=str, default="mobilenet_v3_large", help="Name of the torchvision architecture") + parser.add_argument( + "--arch", + type=str, + default="mobilenet_v3_large", + help="Name of the torchvision architecture", + ) parser.add_argument("--target", type=str, default=None, help="Target layer name") parser.add_argument("--size", type=int, default=224, help="The image input size") parser.add_argument("-b", "--batch-size", default=32, type=int, help="batch size") - parser.add_argument("--device", type=str, default=None, help="Default device to perform computation on") parser.add_argument( - "-j", "--workers", default=min(os.cpu_count(), 16), type=int, help="number of data loading workers" + "--device", + type=str, + default=None, + help="Default device to perform computation on", + ) + parser.add_argument( + "-j", + "--workers", + default=min(os.cpu_count(), 16), + type=int, + help="number of data loading workers", ) args = parser.parse_args() diff --git a/tests/conftest.py b/tests/conftest.py index fe2c3b7..ee08205 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,9 @@ def mock_img_tensor(): # Forward an image pil_img = Image.open(BytesIO(response.content), mode="r").convert("RGB") img_tensor = normalize( - to_tensor(resize(pil_img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] + to_tensor(resize(pil_img, (224, 224))), + [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225], ).unsqueeze(0) except ConnectionError: img_tensor = torch.rand((1, 3, 224, 224)) diff --git a/tests/test_methods_activation.py b/tests/test_methods_activation.py index 76f7f7e..1b1c94d 100644 --- a/tests/test_methods_activation.py +++ b/tests/test_methods_activation.py @@ -5,7 +5,7 @@ from torchcam.methods import activation -def test_base_cam_constructor(mock_img_model): +def test_base_cam_constructor(): model = mobilenet_v2(weights=None).eval() for p in model.parameters(): p.requires_grad_(False) @@ -26,7 +26,14 @@ def _verify_cam(activation_map, output_size): @pytest.mark.parametrize( - ("cam_name", "target_layer", "fc_layer", "num_samples", "output_size", "batch_size"), + ( + "cam_name", + "target_layer", + "fc_layer", + "num_samples", + "output_size", + "batch_size", + ), [ ("CAM", None, None, None, (7, 7), 1), ("CAM", None, None, None, (7, 7), 2), @@ -38,7 +45,15 @@ def _verify_cam(activation_map, output_size): ("ISCAM", "features.16.conv.3", None, 4, (7, 7), 1), ], ) -def test_img_cams(cam_name, target_layer, fc_layer, num_samples, output_size, batch_size, mock_img_tensor): +def test_img_cams( + cam_name, + target_layer, + fc_layer, + num_samples, + output_size, + batch_size, + mock_img_tensor, +): model = mobilenet_v2(weights=None).eval() for p in model.parameters(): p.requires_grad_(False) @@ -52,21 +67,19 @@ def test_img_cams(cam_name, target_layer, fc_layer, num_samples, output_size, ba target_layer = target_layer(model) if callable(target_layer) else target_layer # Hook the corresponding layer in the model - with activation.__dict__[cam_name](model, target_layer, **kwargs) as extractor: - with torch.no_grad(): - scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1))) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (batch_size, *output_size)) - # Multiple class indices - _verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size)) + with activation.__dict__[cam_name](model, target_layer, **kwargs) as extractor, torch.no_grad(): + scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1))) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (batch_size, *output_size)) + # Multiple class indices + _verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size)) def test_cam_conv1x1(mock_fullyconv_model): - with activation.CAM(mock_fullyconv_model, fc_layer="1") as extractor: - with torch.no_grad(): - scores = mock_fullyconv_model(torch.rand((1, 3, 32, 32))) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 32, 32)) + with activation.CAM(mock_fullyconv_model, fc_layer="1") as extractor, torch.no_grad(): + scores = mock_fullyconv_model(torch.rand((1, 3, 32, 32))) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores)[0], (1, 32, 32)) @pytest.mark.parametrize( @@ -78,7 +91,14 @@ def test_cam_conv1x1(mock_fullyconv_model): ("ISCAM", "0.3", 4, (1, 8, 16, 16)), ], ) -def test_video_cams(cam_name, target_layer, num_samples, output_size, mock_video_model, mock_video_tensor): +def test_video_cams( + cam_name, + target_layer, + num_samples, + output_size, + mock_video_model, + mock_video_tensor, +): model = mock_video_model.eval() kwargs = {} # Speed up testing by reducing the number of samples @@ -86,8 +106,7 @@ def test_video_cams(cam_name, target_layer, num_samples, output_size, mock_video kwargs["num_samples"] = num_samples # Hook the corresponding layer in the model - with activation.__dict__[cam_name](model, target_layer, **kwargs) as extractor: - with torch.no_grad(): - scores = model(mock_video_tensor) - # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size) + with activation.__dict__[cam_name](model, target_layer, **kwargs) as extractor, torch.no_grad(): + scores = model(mock_video_tensor) + # Use the hooked data to compute activation map + _verify_cam(extractor(scores[0].argmax().item(), scores)[0], output_size) diff --git a/tests/test_methods_core.py b/tests/test_methods_core.py index c5d4e56..18aea1e 100644 --- a/tests/test_methods_core.py +++ b/tests/test_methods_core.py @@ -32,27 +32,26 @@ def test_cam_context_manager(mock_img_model): def test_cam_precheck(mock_img_model, mock_img_tensor): model = mock_img_model.eval() - with core._CAM(model, "0.3") as extractor: - with torch.no_grad(): - # Check missing forward raises Error - with pytest.raises(AssertionError): - extractor(0) + with core._CAM(model, "0.3") as extractor, torch.no_grad(): + # Check missing forward raises Error + with pytest.raises(AssertionError): + extractor(0) - # Correct forward - model(mock_img_tensor) + # Correct forward + model(mock_img_tensor) - # Check incorrect class index - with pytest.raises(ValueError): - extractor(-1) + # Check incorrect class index + with pytest.raises(ValueError): + extractor(-1) - # Check incorrect class index - with pytest.raises(ValueError): - extractor([-1]) + # Check incorrect class index + with pytest.raises(ValueError): + extractor([-1]) - # Check missing score - if extractor._score_used: - with pytest.raises(ValueError): - extractor(0) + # Check missing score + if extractor._score_used: + with pytest.raises(ValueError): + extractor(0) @pytest.mark.parametrize( diff --git a/tests/test_methods_gradient.py b/tests/test_methods_gradient.py index aa54a2c..d9d2a1a 100644 --- a/tests/test_methods_gradient.py +++ b/tests/test_methods_gradient.py @@ -35,7 +35,10 @@ def test_img_cams(cam_name, target_layer, output_size, batch_size, mock_img_tens with gradient.__dict__[cam_name](model, target_layer) as extractor: scores = model(mock_img_tensor.repeat((batch_size,) + (1,) * (mock_img_tensor.ndim - 1))) # Use the hooked data to compute activation map - _verify_cam(extractor(scores[0].argmax().item(), scores, retain_graph=True)[0], (batch_size, *output_size)) + _verify_cam( + extractor(scores[0].argmax().item(), scores, retain_graph=True)[0], + (batch_size, *output_size), + ) # Multiple class indices _verify_cam(extractor(list(range(batch_size)), scores)[0], (batch_size, *output_size)) @@ -86,7 +89,7 @@ def test_smoothgradcampp_repr(): assert repr(extractor) == "SmoothGradCAMpp(target_layer=['features.18.0'], num_samples=4, std=0.3)" -def test_layercam_fuse_cams(mock_img_model): +def test_layercam_fuse_cams(): with pytest.raises(TypeError): gradient.LayerCAM.fuse_cams(torch.zeros((3, 32, 32))) diff --git a/torchcam/__init__.py b/torchcam/__init__.py index 16c708b..89dcfd7 100644 --- a/torchcam/__init__.py +++ b/torchcam/__init__.py @@ -1,6 +1,5 @@ +from contextlib import suppress from torchcam import methods, metrics, utils -try: +with suppress(ImportError): from .version import __version__ -except ImportError: - pass diff --git a/torchcam/methods/_utils.py b/torchcam/methods/_utils.py index 727bc18..4d575f9 100644 --- a/torchcam/methods/_utils.py +++ b/torchcam/methods/_utils.py @@ -28,7 +28,7 @@ def locate_candidate_layer(mod: nn.Module, input_shape: Tuple[int, ...] = (3, 22 output_shapes: List[Tuple[Optional[str], Tuple[int, ...]]] = [] - def _record_output_shape(module: nn.Module, input: Tensor, output: Tensor, name: Optional[str] = None) -> None: + def _record_output_shape(_: nn.Module, _input: Tensor, output: Tensor, name: Optional[str] = None) -> None: """Activation hook.""" output_shapes.append((name, output.shape)) diff --git a/torchcam/methods/activation.py b/torchcam/methods/activation.py index aabb35f..fd9f2f7 100644 --- a/torchcam/methods/activation.py +++ b/torchcam/methods/activation.py @@ -81,10 +81,10 @@ def __init__( self._fc_weights = self._fc_weights.view(*self._fc_weights.shape[:2]) @torch.no_grad() - def _get_weights( # type: ignore[override] + def _get_weights( self, class_idx: Union[int, List[int]], - *args: Any, + *_: Any, ) -> List[Tensor]: """Computes the weight coefficients of the hooked activation maps.""" # Take the FC weights of the target class @@ -149,10 +149,10 @@ def __init__( # Ensure ReLU is applied to CAM before normalization self._relu = True - def _store_input(self, module: nn.Module, input: Tensor) -> None: + def _store_input(self, _: nn.Module, _input: Tensor) -> None: """Store model input tensor.""" if self._hooks_enabled: - self._input = input[0].data.clone() + self._input = _input[0].data.clone() @torch.no_grad() def _get_score_weights(self, activations: List[Tensor], class_idx: Union[int, List[int]]) -> List[Tensor]: @@ -187,10 +187,10 @@ def _get_score_weights(self, activations: List[Tensor], class_idx: Union[int, Li return [torch.softmax(w.view(b, c), -1) for w in weights] @torch.no_grad() - def _get_weights( # type: ignore[override] + def _get_weights( self, class_idx: Union[int, List[int]], - *args: Any, + *_: Any, ) -> List[Tensor]: """Computes the weight coefficients of the hooked activation maps.""" self.hook_a: List[Tensor] # type: ignore[assignment] @@ -204,7 +204,12 @@ def _get_weights( # type: ignore[override] spatial_dims = self._input.ndim - 2 interpolation_mode = "bilinear" if spatial_dims == 2 else "trilinear" if spatial_dims == 3 else "nearest" upsampled_a = [ - F.interpolate(up_a, self._input.shape[2:], mode=interpolation_mode, align_corners=False) + F.interpolate( + up_a, + self._input.shape[2:], + mode=interpolation_mode, + align_corners=False, + ) for up_a in upsampled_a ] diff --git a/torchcam/methods/core.py b/torchcam/methods/core.py index 107d190..c08f05f 100644 --- a/torchcam/methods/core.py +++ b/torchcam/methods/core.py @@ -7,7 +7,7 @@ from abc import abstractmethod from functools import partial from types import TracebackType -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, List, Optional, Tuple, Type, Union, cast import torch import torch.nn.functional as F @@ -61,7 +61,7 @@ def __init__( else: raise TypeError("invalid argument type for `target_layer`") - if any(name not in self.submodule_dict.keys() for name in target_names): + if any(name not in self.submodule_dict for name in target_names): raise ValueError(f"Unable to find all submodules {target_names} in the model") self.target_names = target_names self.model = model @@ -104,7 +104,7 @@ def _resolve_layer_name(self, target_layer: nn.Module) -> str: return target_name - def _hook_a(self, module: nn.Module, input: Tuple[Tensor, ...], output: Tensor, idx: int = 0) -> None: + def _hook_a(self, _: nn.Module, _input: Tuple[Tensor, ...], output: Tensor, idx: int = 0) -> None: """Activation hook.""" if self._hooks_enabled: self.hook_a[idx] = output.data @@ -132,7 +132,7 @@ def _normalize(cams: Tensor, spatial_dims: Optional[int] = None, eps: float = 1e return cams @abstractmethod - def _get_weights(self, class_idx, scores, **kwargs): # type: ignore[no-untyped-def] + def _get_weights(self, class_idx: Union[int, List[int]], *args: Any, **kwargs: Any) -> List[Tensor]: raise NotImplementedError def _precheck(self, class_idx: Union[int, List[int]], scores: Optional[Tensor] = None) -> None: @@ -243,7 +243,7 @@ def fuse_cams(cls, cams: List[Tensor], target_shape: Optional[Tuple[int, int]] = if isinstance(target_shape, tuple): _shape = target_shape else: - _shape = tuple(map(max, zip(*[tuple(cam.shape[1:]) for cam in cams]))) # type: ignore[assignment] + _shape = tuple(map(max, zip(*[tuple(cam.shape[1:]) for cam in cams]))) # Scale cams scaled_cams = cls._scale_cams(cams) return cls._fuse_cams(scaled_cams, _shape) @@ -257,8 +257,14 @@ def _fuse_cams(cams: List[Tensor], target_shape: Tuple[int, int]) -> Tensor: # Interpolate all CAMs interpolation_mode = "bilinear" if cams[0].ndim == 3 else "trilinear" if cams[0].ndim == 4 else "nearest" scaled_cams = [ - F.interpolate(cam.unsqueeze(1), target_shape, mode=interpolation_mode, align_corners=False) for cam in cams + F.interpolate( + cam.unsqueeze(1), + target_shape, + mode=interpolation_mode, + align_corners=False, + ) + for cam in cams ] # Fuse them - return torch.stack(scaled_cams).max(dim=0).values.squeeze(1) + return cast(Tensor, torch.stack(scaled_cams).max(dim=0).values.squeeze(1)) diff --git a/torchcam/methods/gradient.py b/torchcam/methods/gradient.py index 3fe5d41..c51106f 100644 --- a/torchcam/methods/gradient.py +++ b/torchcam/methods/gradient.py @@ -4,7 +4,7 @@ # See LICENSE or go to for full license details. from functools import partial -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, cast import torch from torch import Tensor, nn @@ -43,12 +43,17 @@ def _store_grad(self, grad: Tensor, idx: int = 0) -> None: if self._hooks_enabled: self.hook_g[idx] = grad.data - def _hook_g(self, module: nn.Module, input: Tuple[Tensor, ...], output: Tensor, idx: int = 0) -> None: + def _hook_g(self, _: nn.Module, _input: Tuple[Tensor, ...], output: Tensor, idx: int = 0) -> None: """Gradient hook""" if self._hooks_enabled: self.hook_handles.append(output.register_hook(partial(self._store_grad, idx=idx))) - def _backprop(self, scores: Tensor, class_idx: Union[int, List[int]], retain_graph: bool = False) -> None: + def _backprop( + self, + scores: Tensor, + class_idx: Union[int, List[int]], + retain_graph: bool = False, + ) -> None: """Backpropagate the loss for a specific output class""" # Backpropagate to get the gradients on the hooked layer if isinstance(class_idx, int): @@ -58,9 +63,6 @@ def _backprop(self, scores: Tensor, class_idx: Union[int, List[int]], retain_gra self.model.zero_grad() loss.backward(retain_graph=retain_graph) - def _get_weights(self, class_idx: Union[int, List[int]], scores: Tensor, **kwargs: Any) -> List[Tensor]: - raise NotImplementedError - class GradCAM(_GradCAM): r"""Implements a class activation map extractor as described in `"Grad-CAM: Visual Explanations from Deep Networks @@ -146,7 +148,11 @@ class GradCAMpp(_GradCAM): """ def _get_weights( - self, class_idx: Union[int, List[int]], scores: Tensor, eps: float = 1e-8, **kwargs: Any + self, + class_idx: Union[int, List[int]], + scores: Tensor, + eps: float = 1e-8, + **kwargs: Any, ) -> List[Tensor]: """Computes the weight coefficients of the hooked activation maps.""" # Backpropagate @@ -245,13 +251,17 @@ def __init__( # Specific input hook updater self._ihook_enabled = True - def _store_input(self, module: nn.Module, input: Tensor) -> None: + def _store_input(self, _: nn.Module, _input: Tensor) -> None: """Store model input tensor.""" if self._ihook_enabled: - self._input = input[0].data.clone() + self._input = _input[0].data.clone() def _get_weights( - self, class_idx: Union[int, List[int]], scores: Optional[Tensor] = None, eps: float = 1e-8, **kwargs: Any + self, + class_idx: Union[int, List[int]], + _: Union[Tensor, None] = None, + eps: float = 1e-8, + **kwargs: Any, ) -> List[Tensor]: """Computes the weight coefficients of the hooked activation maps.""" # Disable input update @@ -332,7 +342,11 @@ class XGradCAM(_GradCAM): """ def _get_weights( - self, class_idx: Union[int, List[int]], scores: Tensor, eps: float = 1e-8, **kwargs: Any + self, + class_idx: Union[int, List[int]], + scores: Tensor, + eps: float = 1e-8, + **kwargs: Any, ) -> List[Tensor]: """Computes the weight coefficients of the hooked activation maps.""" # Backpropagate @@ -390,4 +404,4 @@ def _get_weights(self, class_idx: Union[int, List[int]], scores: Tensor, **kwarg @staticmethod def _scale_cams(cams: List[Tensor], gamma: float = 2.0) -> List[Tensor]: # cf. Equation 9 in the paper - return [torch.tanh(gamma * cam) for cam in cams] + return [torch.tanh(cast(Tensor, gamma * cam)) for cam in cams]