Made review changes
atticusg committed Mar 24, 2022
1 parent b19a83d commit babee6c
Showing 4 changed files with 120 additions and 59 deletions.
136 changes: 82 additions & 54 deletions IIT_01.ipynb
@@ -10,8 +10,8 @@
"import random\n",
"import numpy as np\n",
"from sklearn.metrics import classification_report\n",
"from equality_datasets import get_equality_dataset, get_IIT_equality_dataset\n",
"from IIT_torch_shallow_neural_classifier import TorchShallowNeuralClassifierIIT\n",
"from iit import get_equality_dataset, get_IIT_equality_dataset\n",
"from torch_shallow_neural_classifier_iit import TorchShallowNeuralClassifierIIT\n",
"from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
"from torch_rnn_classifier import TorchRNNClassifier\n",
"random.seed(42)"
@@ -203,7 +203,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 553. Training loss did not improve more than tol=1e-05. Final error is 0.0003255244682804914."
"Stopping after epoch 478. Training loss did not improve more than tol=1e-05. Final error is 0.00043814663513330743."
]
}
],
@@ -299,31 +299,31 @@
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.0000, 0.6438, 0.0000, 0.4686, 0.0000, 0.2442, 0.1763, 0.0000, 0.0000,\n",
" 0.0055, 0.4739, 0.0000, 0.0000, 0.0092, 0.0000, 0.1003, 0.0000, 0.0000,\n",
" 0.1359, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[0.0000, 0.0846, 0.0000, 0.9482, 0.0317, 0.3597, 0.0000, 0.3631, 0.0081,\n",
" 0.6297, 0.2109, 0.3294, 0.0000, 0.2897, 0.5303, 0.0000, 0.0000, 0.0560,\n",
" 0.0000, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 1: ActivationLayer(\n",
" (linear): Linear(in_features=20, out_features=20, bias=True)\n",
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.9234, 0.0949, 0.4198, 0.0000, 0.0256, 0.0000, 0.0000, 0.3851, 0.3467,\n",
" 0.0000, 0.3639, 0.0000, 0.0000, 0.2465, 0.2948, 0.0000, 0.4032, 0.2969,\n",
" 0.4415, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[0.0000, 0.8141, 0.0000, 0.0000, 0.6008, 0.4972, 0.0000, 0.4790, 0.6306,\n",
" 0.0000, 0.4974, 1.1702, 0.6278, 0.8363, 0.6990, 0.0000, 1.3104, 0.5918,\n",
" 0.5621, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 2: ActivationLayer(\n",
" (linear): Linear(in_features=20, out_features=20, bias=True)\n",
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.0000, 1.1027, 0.0000, 0.0000, 0.0000, 0.0000, 0.9665, 0.0000, 0.0000,\n",
" 0.0000, 0.0000, 1.1404, 0.0000, 1.2535, 0.0000, 0.0000, 1.4600, 0.4888,\n",
" 0.0000, 1.0620]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[1.7722, 1.6336, 3.5214, 1.8255, 1.6786, 0.4072, 2.4085, 3.8123, 0.5702,\n",
" 0.6714, 2.2578, 1.1029, 0.0000, 0.9687, 0.6084, 0.5882, 0.0000, 0.0000,\n",
" 0.0000, 0.4937]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 3: Linear(in_features=20, out_features=2, bias=True)\n",
"\n",
"Neural Activations: tensor([[ 7.7798, -7.8989]], device='cuda:0', grad_fn=<SliceBackward>)\n"
"Neural Activations: tensor([[ 5.9969, -7.4518]], device='cuda:0', grad_fn=<SliceBackward>)\n"
]
}
],
@@ -366,31 +366,31 @@
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.0000, 0.6438, 0.0000, 0.4686, 0.0000, 0.2442, 0.1763, 0.0000, 0.0000,\n",
" 0.0055, 0.4739, 0.0000, 0.0000, 0.0092, 0.0000, 0.1003, 0.0000, 0.0000,\n",
" 0.1359, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[0.0000, 0.0846, 0.0000, 0.9482, 0.0317, 0.3597, 0.0000, 0.3631, 0.0081,\n",
" 0.6297, 0.2109, 0.3294, 0.0000, 0.2897, 0.5303, 0.0000, 0.0000, 0.0560,\n",
" 0.0000, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 1: ActivationLayer(\n",
" (linear): Linear(in_features=20, out_features=20, bias=True)\n",
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3851, 0.3467,\n",
" 0.0000, 0.3639, 0.0000, 0.0000, 0.2465, 0.2948, 0.0000, 0.4032, 0.2969,\n",
" 0.4415, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4972, 0.0000, 0.4790, 0.6306,\n",
" 0.0000, 0.4974, 1.1702, 0.6278, 0.8363, 0.6990, 0.0000, 1.3104, 0.5918,\n",
" 0.5621, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 2: ActivationLayer(\n",
" (linear): Linear(in_features=20, out_features=20, bias=True)\n",
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.0000, 0.1819, 0.5323, 0.4157, 0.6224, 0.5768, 0.1968, 0.4671, 0.3697,\n",
" 0.0000, 0.4677, 0.0531, 0.3808, 0.1635, 0.5662, 0.0000, 0.3467, 0.4433,\n",
" 0.5137, 0.2045]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[1.5023, 1.6149, 2.8146, 1.8080, 1.4614, 0.2798, 2.1097, 3.0883, 0.4062,\n",
" 0.4194, 2.0156, 1.4818, 0.0000, 0.7672, 0.4956, 0.3941, 0.0000, 0.0000,\n",
" 0.0000, 0.1730]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 3: Linear(in_features=20, out_features=2, bias=True)\n",
"\n",
"Neural Activations: tensor([[-4.7269, 4.8842]], device='cuda:0', grad_fn=<SliceBackward>)\n"
"Neural Activations: tensor([[ 2.3688, -3.5759]], device='cuda:0', grad_fn=<SliceBackward>)\n"
]
}
],
@@ -432,31 +432,31 @@
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.0000, 0.6438, 0.0000, 0.4686, 0.0000, 0.2442, 0.1763, 0.0000, 0.0000,\n",
" 0.0055, 0.4739, 0.0000, 0.0000, 0.0092, 0.0000, 0.1003, 0.0000, 0.0000,\n",
" 0.1359, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[0.0000, 0.0846, 0.0000, 0.9482, 0.0317, 0.3597, 0.0000, 0.3631, 0.0081,\n",
" 0.6297, 0.2109, 0.3294, 0.0000, 0.2897, 0.5303, 0.0000, 0.0000, 0.0560,\n",
" 0.0000, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 1: ActivationLayer(\n",
" (linear): Linear(in_features=20, out_features=20, bias=True)\n",
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.0000, 0.8636, 0.4841, 0.2193, 0.5431, 0.0000, 0.0000, 0.3851, 0.3467,\n",
" 0.0000, 0.3639, 0.0000, 0.0000, 0.2465, 0.2948, 0.0000, 0.4032, 0.2969,\n",
" 0.4415, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[0.0000, 0.3640, 0.1019, 0.4705, 0.0000, 0.4972, 0.0000, 0.4790, 0.6306,\n",
" 0.0000, 0.4974, 1.1702, 0.6278, 0.8363, 0.6990, 0.0000, 1.3104, 0.5918,\n",
" 0.5621, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 2: ActivationLayer(\n",
" (linear): Linear(in_features=20, out_features=20, bias=True)\n",
" (activation): ReLU()\n",
")\n",
"\n",
"Neural Activations: tensor([[0.0000, 0.1969, 0.8776, 0.7200, 0.9971, 0.8565, 0.3353, 0.8209, 0.9587,\n",
" 0.0000, 0.8599, 0.0000, 0.7895, 0.0000, 0.9622, 0.0000, 0.2153, 0.8869,\n",
" 0.8579, 0.0020]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"Neural Activations: tensor([[0.8910, 0.9440, 3.3704, 1.2600, 0.8820, 1.0984, 2.7786, 3.4767, 1.2818,\n",
" 1.1336, 1.3859, 0.7346, 0.0000, 1.9162, 1.3452, 1.1835, 0.0000, 0.0000,\n",
" 0.0000, 1.0403]], device='cuda:0', grad_fn=<SliceBackward>)\n",
"\n",
"Layer 3: Linear(in_features=20, out_features=2, bias=True)\n",
"\n",
"Neural Activations: tensor([[-9.8724, 10.1956]], device='cuda:0', grad_fn=<SliceBackward>)\n"
"Neural Activations: tensor([[ 15.5163, -17.1102]], device='cuda:0', grad_fn=<SliceBackward>)\n"
]
}
],
@@ -546,21 +546,21 @@
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.62 0.49 0.55 2916\n",
" 1 0.65 0.75 0.70 3645\n",
" 0 0.53 0.36 0.43 2916\n",
" 1 0.59 0.74 0.66 3645\n",
"\n",
" accuracy 0.64 6561\n",
" macro avg 0.63 0.62 0.62 6561\n",
"weighted avg 0.64 0.64 0.63 6561\n",
" accuracy 0.57 6561\n",
" macro avg 0.56 0.55 0.54 6561\n",
"weighted avg 0.56 0.57 0.56 6561\n",
"\n",
" precision recall f1-score support\n",
"\n",
" 0 0.53 0.57 0.55 2916\n",
" 1 0.63 0.59 0.61 3645\n",
" 0 0.50 0.31 0.38 2916\n",
" 1 0.58 0.75 0.65 3645\n",
"\n",
" accuracy 0.58 6561\n",
" macro avg 0.58 0.58 0.58 6561\n",
"weighted avg 0.59 0.58 0.58 6561\n",
" accuracy 0.55 6561\n",
" macro avg 0.54 0.53 0.52 6561\n",
"weighted avg 0.54 0.55 0.53 6561\n",
"\n"
]
}
@@ -593,7 +593,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 590. Training loss did not improve more than tol=1e-05. Final error is 0.001175935372884851."
"Stopping after epoch 607. Training loss did not improve more than tol=1e-05. Final error is 0.0012616138847079128."
]
}
],
@@ -637,12 +637,12 @@
"\n",
" precision recall f1-score support\n",
"\n",
" 0 0.65 0.65 0.65 5000\n",
" 1 0.65 0.65 0.65 5000\n",
" 0 0.62 0.65 0.63 5000\n",
" 1 0.63 0.59 0.61 5000\n",
"\n",
" accuracy 0.65 10000\n",
" macro avg 0.65 0.65 0.65 10000\n",
"weighted avg 0.65 0.65 0.65 10000\n",
" accuracy 0.62 10000\n",
" macro avg 0.62 0.62 0.62 10000\n",
"weighted avg 0.62 0.62 0.62 10000\n",
"\n"
]
}
@@ -689,7 +689,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 772. Training loss did not improve more than tol=1e-05. Final error is 0.3725438618566841."
"Stopping after epoch 729. Training loss did not improve more than tol=1e-05. Final error is 0.29536177532281727."
]
}
],
@@ -730,8 +730,8 @@
"\n",
" precision recall f1-score support\n",
"\n",
" 0 0.99 0.99 0.99 5000\n",
" 1 0.99 0.99 0.99 5000\n",
" 0 0.98 0.99 0.99 5000\n",
" 1 0.99 0.98 0.99 5000\n",
"\n",
" accuracy 0.99 10000\n",
" macro avg 0.99 0.99 0.99 10000\n",
@@ -765,14 +765,42 @@
"base_preds = np.array(base_preds.argmax(axis=1).cpu())\n",
"print(classification_report(y_IIT_test, IIT_preds))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "933b0a94e0d88ac80a17cb26ca3d8d36930c12815b02a2885c1925c2b1ae3c33"
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -786,7 +814,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.9.6"
}
},
"nbformat": 4,
6 changes: 4 additions & 2 deletions iit.py
@@ -4,10 +4,12 @@
from utils import randvec

class IITModel(torch.nn.Module):
def __init__(self, model):
def __init__(self, model, layers, id_to_coords,device):
super().__init__()
self.model = model
self.layers = model.layers
self.layers = layers
self.id_to_coords = id_to_coords
self.device = device

def no_IIT_forward(self, X):
return self.model(X)
31 changes: 30 additions & 1 deletion torch_shallow_neural_classifier.py
@@ -191,7 +191,36 @@ def predict(self, X, device=None):
"""
probs = self.predict_proba(X, device=device)
return [self.classes_[i] for i in probs.argmax(axis=1)]


def _get_set(self,get, set = None):
if set is None:
def gethook(model,input,output):
self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:,get["start"]: get["end"]]
set_handler = self.layers[get["layer"]].register_forward_hook(gethook)
return [set_handler]
elif set["layer"] != get["layer"]:
def sethook(model,input,output):
output[:,set["start"]: set["end"]] = set["intervention"]
set_handler = self.layers[set["layer"]].register_forward_hook(sethook)
def gethook(model,input,output):
self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:,get["start"]: get["end"] ]
get_handler = self.layers[get["layer"]].register_forward_hook(gethook)
return [set_handler, get_handler]
else:
def bothhook(model, input, output):
output[:,set["start"]: set["end"]] = set["intervention"]
self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:,get["start"]: get["end"] ]
both_handler = self.layers[set["layer"]].register_forward_hook(bothhook)
return [both_handler]

def retrieve_activations(self, input, get, set):
input = input.type(torch.FloatTensor).to(self.device)
self.activation = dict()
handlers = self._get_set(get, set)
logits = self.model(input)
for handler in handlers:
handler.remove()
return self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']

def simple_example():
"""Assess on the digits dataset."""
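A minimal usage sketch of the new hook API, in the spirit of simple_example below. It assumes the classifier exposes a layers attribute that indexes its hidden ActivationLayer modules (the hooks above require this) and that hidden_dim and max_iter are accepted constructor keywords; the slice coordinates and data are purely illustrative.

import torch
from sklearn.datasets import load_digits
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier

digits = load_digits()
clf = TorchShallowNeuralClassifier(hidden_dim=20, max_iter=5)  # assumed kwargs
clf.fit(digits.data, digits.target)

X = torch.tensor(digits.data[:1])

# Read-only probe: record units 0-4 of hidden layer 0 during a forward pass.
get = {"layer": 0, "start": 0, "end": 4}
probe = clf.retrieve_activations(X, get, None)

# Intervention probe: overwrite that same slice with zeros before reading it,
# which exercises the single-hook branch of _get_set above.
set_coords = {"layer": 0, "start": 0, "end": 4,
              "intervention": torch.zeros(1, 4, device=clf.device)}
patched = clf.retrieve_activations(X, get, set_coords)
print(probe.shape, patched.shape)  # both (1, 4) under these assumptions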
6 changes: 4 additions & 2 deletions torch_shallow_neural_classifier_iit.py
@@ -11,7 +11,7 @@


class TorchShallowNeuralClassifierIIT(TorchShallowNeuralClassifier):
def __init__(self,id_to_coords, **base_kwargs):
def __init__(self,id_to_coords = None, **base_kwargs):
super().__init__(**base_kwargs)
loss_function= nn.CrossEntropyLoss(reduction="mean")
self.loss = lambda preds, labels: loss_function(preds[0],labels[:,0]) + loss_function(preds[1],labels[:,1])
@@ -20,7 +20,7 @@ def __init__(self,id_to_coords, **base_kwargs):

def build_graph(self):
model = super().build_graph()
IITmodel = IITModel(model)
IITmodel = IITModel(model, self.layers, self.id_to_coords, self.device)
return IITmodel

def batched_indices(self, max_len):
@@ -86,5 +86,7 @@ def prep_input(self, base, source, coord_ids):
bigX = torch.stack((base,source, coord_ids.unsqueeze(1).expand(-1, base.shape[1])), dim=1)
return bigX



if __name__ == '__main__':
simple_example()
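
For context, a hedged sketch of how the revised class might be instantiated and of what prep_input produces. The exact structure of id_to_coords is not shown in this diff; the layer/start/end keys are inferred from the _get_set coordinate dicts in torch_shallow_neural_classifier.py, and the input sizes are hypothetical.

import torch
from torch_shallow_neural_classifier_iit import TorchShallowNeuralClassifierIIT

# Assumed layout: each intervention id picks out a slice of one hidden layer
# whose activations are swapped in from the source input.
id_to_coords = {
    0: [{"layer": 1, "start": 0, "end": 10}],
    1: [{"layer": 1, "start": 10, "end": 20}],
}
clf = TorchShallowNeuralClassifierIIT(id_to_coords=id_to_coords)

# prep_input stacks base inputs, source inputs, and per-example intervention
# ids into a single (n_examples, 3, input_dim) tensor, as defined above.
base = torch.randn(8, 20)     # hypothetical 20-dimensional base inputs
source = torch.randn(8, 20)   # counterfactual sources, same shape
coord_ids = torch.zeros(8)    # every example uses intervention id 0
bigX = clf.prep_input(base, source, coord_ids)
print(bigX.shape)  # torch.Size([8, 3, 20])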
