activations_and_gradients.py (forked from jacobgil/pytorch-grad-cam)
class ActivationsAndGradients:
    """Class for extracting activations and
    registering gradients from targeted intermediate layers."""

    def __init__(self, model, target_layer, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        target_layer.register_forward_hook(self.save_activation)
        # Note: register_backward_hook is deprecated in recent PyTorch
        # releases in favour of register_full_backward_hook.
        target_layer.register_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu())

    def save_gradient(self, module, grad_input, grad_output):
        # Gradients arrive in reverse order during the backward pass,
        # so prepend to keep the list aligned with the activations.
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu()] + self.gradients

    def __call__(self, x):
        # Reset stored tensors before each forward pass so the hooks
        # only capture the current input.
        self.gradients = []
        self.activations = []
        return self.model(x)
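

# --- Minimal usage sketch, not part of the original file ---
# Assumes a torchvision ResNet-18 with its last residual block
# (`model.layer4`) as the target layer; the model, layer choice, and
# input size are illustrative assumptions only.
if __name__ == "__main__":
    import torch
    from torchvision.models import resnet18

    model = resnet18().eval()
    wrapper = ActivationsAndGradients(model, model.layer4, reshape_transform=None)

    x = torch.randn(1, 3, 224, 224)
    output = wrapper(x)                    # forward pass fills wrapper.activations

    score = output[0, output.argmax()]     # score of the top predicted class
    score.backward()                       # backward pass fills wrapper.gradients

    print(wrapper.activations[0].shape)    # e.g. torch.Size([1, 512, 7, 7])
    print(wrapper.gradients[0].shape)      # e.g. torch.Size([1, 512, 7, 7])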