Commit

review update

atticusg committed Mar 24, 2022
1 parent a1e7d50 commit b19a83d
Showing 5 changed files with 218 additions and 57 deletions.
5 changes: 2 additions & 3 deletions IIT_01.ipynb
@@ -13,7 +13,6 @@
"from equality_datasets import get_equality_dataset, get_IIT_equality_dataset\n",
"from IIT_torch_shallow_neural_classifier import TorchShallowNeuralClassifierIIT\n",
"from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
"from torch_interventionable_model import InterventionableLayeredModel\n",
"from torch_rnn_classifier import TorchRNNClassifier\n",
"random.seed(42)"
]
@@ -773,7 +772,7 @@
"hash": "933b0a94e0d88ac80a17cb26ca3d8d36930c12815b02a2885c1925c2b1ae3c33"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -787,7 +786,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.7.3"
}
},
"nbformat": 4,
60 changes: 60 additions & 0 deletions equality_datasets.py → iit.py
@@ -3,6 +3,66 @@
import torch
from utils import randvec

class IITModel(torch.nn.Module):
    """Wraps a layered model for interchange intervention training (IIT).

    The wrapped model must expose a `layers` ModuleList; `id_to_coords` and
    `device` are attached by the training wrapper (see
    TorchShallowNeuralClassifierIIT.build_graph)."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.layers = model.layers

    def no_IIT_forward(self, X):
        return self.model(X)

    def forward(self, X):
        # X stacks (base, source, coord_ids) along dim 1; each batch shares
        # a single intervention site, looked up in id_to_coords.
        base, source, coord_ids = X[:, 0, :].squeeze(1), X[:, 1, :].squeeze(1), X[:, 2, :].squeeze(1)
        get = self.id_to_coords[int(coord_ids.flatten()[0])]
        base = base.type(torch.FloatTensor).to(self.device)
        source = source.type(torch.FloatTensor).to(self.device)
        # Run the source input, recording the activation at the target site.
        self.activation = dict()
        handlers = self._get_set(get, None)
        source_logits = self.no_IIT_forward(source)
        for handler in handlers:
            handler.remove()

        base_logits = self.no_IIT_forward(base)
        # Re-run the base input with the source activation patched in.
        set = {k: get[k] for k in get}
        set["intervention"] = self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']
        handlers = self._get_set(get, set)
        counterfactual_logits = self.no_IIT_forward(base)
        for handler in handlers:
            handler.remove()

        return counterfactual_logits, base_logits

    def _get_set(self, get, set=None):
        # Register forward hooks: a get-hook reads a slice of a layer's
        # output into self.activation; a set-hook overwrites a slice with a
        # stored intervention. Three cases: read only, read and write on
        # different layers, or both on the same layer.
        if set is None:
            def gethook(model, input, output):
                self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:, get["start"]: get["end"]]
            set_handler = self.layers[get["layer"]].register_forward_hook(gethook)
            return [set_handler]
        elif set["layer"] != get["layer"]:
            def sethook(model, input, output):
                output[:, set["start"]: set["end"]] = set["intervention"]
            set_handler = self.layers[set["layer"]].register_forward_hook(sethook)
            def gethook(model, input, output):
                self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:, get["start"]: get["end"]]
            get_handler = self.layers[get["layer"]].register_forward_hook(gethook)
            return [set_handler, get_handler]
        else:
            def bothhook(model, input, output):
                output[:, set["start"]: set["end"]] = set["intervention"]
                self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:, get["start"]: get["end"]]
            both_handler = self.layers[set["layer"]].register_forward_hook(bothhook)
            return [both_handler]

    def retrieve_activations(self, input, get, set):
        # Run the wrapped model once with the given hooks installed and
        # return the recorded activation slice.
        input = input.type(torch.FloatTensor).to(self.device)
        self.activation = dict()
        handlers = self._get_set(get, set)
        logits = self.model(input)
        for handler in handlers:
            handler.remove()
        return self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']


def get_IIT_equality_dataset(variable, embed_dim, size):
    class_size = size/2
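A minimal usage sketch (not part of the commit) of the new IITModel wrapper, with a toy two-layer network standing in for the real classifier. IITModel assumes the wrapped model exposes a `layers` ModuleList, and that `id_to_coords` and `device` are attached from outside, as TorchShallowNeuralClassifierIIT.build_graph does; `ToyLayeredModel` and the site coordinates below are illustrative assumptions.

import torch
from iit import IITModel

class ToyLayeredModel(torch.nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(dim, dim), torch.nn.Linear(dim, dim)])

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X

toy = IITModel(ToyLayeredModel())
toy.device = "cpu"
# Intervention site 0: dimensions 0-2 of layer 0's output.
toy.id_to_coords = {0: {"layer": 0, "start": 0, "end": 2}}

base = torch.randn(8, 4)       # inputs whose run gets intervened on
source = torch.randn(8, 4)     # inputs that supply the patched activations
coord_ids = torch.zeros(8, 4)  # every row selects site 0
X = torch.stack((base, source, coord_ids), dim=1)

counterfactual_logits, base_logits = toy(X)
print(counterfactual_logits.shape, base_logits.shape)  # torch.Size([8, 4]) twice

# retrieve_activations reads out a slice directly with the same hook machinery:
acts = toy.retrieve_activations(base, {"layer": 0, "start": 0, "end": 2}, None)
print(acts.shape)  # torch.Size([8, 2])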
152 changes: 152 additions & 0 deletions monli.ipynb
@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertModel, BertTokenizer\n",
"from iit import IITModel\n",
"from torch_model_base import TorchModelBase\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class HfBertClassifierModelIIT(IITModel):\n",
" def __init__(self, n_classes, weights_name='bert-base-cased'):\n",
" super().__init__()\n",
" self.n_classes = n_classes\n",
" self.weights_name = weights_name\n",
" self.bert = BertModel.from_pretrained(self.weights_name)\n",
" self.bert.train()\n",
" self.hidden_dim = self.bert.embeddings.word_embeddings.embedding_dim\n",
" # The only new parameters -- the classifier:\n",
" self.classifier_layer = nn.Linear(\n",
" self.hidden_dim, self.n_classes)\n",
" self.layers = self.bert.model.encoder.layer\n",
" \n",
" def no_IIT_forward(self, indices, attention_mask):\n",
" reps = self.bert(\n",
" indices, attention_mask=mask)\n",
" return self.classifier_layer(reps.pooler_output)\n",
" \n",
" def forward(self, X):\n",
" base_indices, base_mask,source_indices,source_mask, coord_ids = [X[:,0,:].squeeze(1) for j in range(5)]\n",
" get = self.id_to_coords[int(coord_ids.flatten()[0])]\n",
" base = base.type(torch.FloatTensor).to(self.device)\n",
" source = source.type(torch.FloatTensor).to(self.device)\n",
" \n",
" self.activation = dict()\n",
" handlers = self._get_set(get,None)\n",
" source_logits = self.model.no_IIT_forward(source_indices,source_mask)\n",
" for handler in handlers:\n",
" handler.remove()\n",
"\n",
" base_logits = self.model.no_IIT_forward(base_indices, base_mask)\n",
" set = {k:get[k] for k in get}\n",
" set[\"intervention\"] = self.activation[f'{get[\"layer\"]}-{get[\"start\"]}-{get[\"end\"]}']\n",
" handlers = self._get_set(get, set)\n",
" counterfactual_logits = self.model.no_IIT_forward(base_indices, base_mask)\n",
" for handler in handlers:\n",
" handler.remove()\n",
"\n",
" return counterfactual_logits, base_logits\n",
"\n",
" \n",
"\n",
"class HfBertClassifierIIT(TorchModelBase):\n",
" def __init__(self, *args, **kwargs):\n",
" self.weights_name = kwargs[\"weights_name\"]\n",
" self.tokenizer = BertTokenizer.from_pretrained(self.weights_name)\n",
" super().__init__(*args, **kwargs)\n",
" self.params += ['weights_name']\n",
"\n",
" def build_graph(self):\n",
" return HfBertClassifierModelIIT(self.n_classes_, self.weights_name)\n",
"\n",
" def build_dataset(self, base, source, base_y, IIT_y, coord_ids):\n",
" base_data = self.tokenizer.batch_encode_plus(\n",
" base,\n",
" max_length=None,\n",
" add_special_tokens=True,\n",
" padding='longest',\n",
" return_attention_mask=True)\n",
" source_data = self.tokenizer.batch_encode_plus(\n",
" source,\n",
" max_length=None,\n",
" add_special_tokens=True,\n",
" padding='longest',\n",
" return_attention_mask=True)\n",
" base_indices = torch.tensor(base_data['input_ids'])\n",
" base_mask = torch.tensor(base_data['attention_mask'])\n",
" source_indices = torch.tensor(source_data['input_ids'])\n",
" source_mask = torch.tensor(source_data['attention_mask'])\n",
" \n",
" self.classes_ = sorted(set(base_y))\n",
" self.n_classes_ = len(self.classes_)\n",
" class2index = dict(zip(self.classes_, range(self.n_classes_)))\n",
" base_y = [class2index[label] for label in base_y]\n",
" base_y = torch.tensor(base_y)\n",
"\n",
" self.classes_ = sorted(set(IIT_y))\n",
" self.n_classes_ = len(self.classes_)\n",
" class2index = dict(zip(self.classes_, range(self.n_classes_)))\n",
" IIT_y = [class2index[label] for label in base_y]\n",
" IIT_y = torch.tensor(IIT_y)\n",
" \n",
" bigX = torch.stack((base_indices, base_mask,source_indices,source_mask, coord_ids.unsqueeze(1).expand(-1, X.shape[1])), dim=1)\n",
" bigy = torch.stack((IIT_y, base_y), dim=1)\n",
" return dataset\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = HfBertClassifier(\n",
" weights_name='bert-base-cased',\n",
" batch_size=8, # Small batches to avoid memory overload.\n",
" max_iter=1, # We'll search based on 1 iteration for efficiency.\n",
" n_iter_no_change=5, # Early-stopping params are for the\n",
" early_stopping=True) # final evaluation.\n",
"\n",
"param_grid = {\n",
" 'gradient_accumulation_steps': [1, 4, 8],\n",
" 'eta': [0.00005, 0.0001, 0.001],\n",
" 'hidden_dim': [100, 200, 300]}\n",
"\n",
"X_base_test, X_source_test, y_base_test, y_IIT_test, interventions = get_IIT_MoNLI_dataset(os.path.join(\"data\", \"MoNLI\"))\n",
"\n",
"model.fit(X_base_test, X_source_test, y_base_test, y_IIT_test, interventions)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
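For orientation, a sketch (toy tensors, not part of the commit) of the five-way tensor protocol this notebook sets up: build_dataset stacks base indices, base mask, source indices, source mask, and a broadcast intervention-site id along dim 1, and the model's forward unpacks them by position.

import torch

batch, seq_len = 2, 6
base_indices = torch.randint(0, 100, (batch, seq_len))
base_mask = torch.ones(batch, seq_len, dtype=torch.long)
source_indices = torch.randint(0, 100, (batch, seq_len))
source_mask = torch.ones(batch, seq_len, dtype=torch.long)
coord_ids = torch.zeros(batch, dtype=torch.long)

# Coord ids are expanded to the sequence width so all five slices share
# one shape and can be stacked into a single tensor.
bigX = torch.stack(
    (base_indices, base_mask, source_indices, source_mask,
     coord_ids.unsqueeze(1).expand(-1, seq_len)), dim=1)
print(bigX.shape)  # torch.Size([2, 5, 6])

# forward's list comprehension recovers the five pieces by position:
parts = [bigX[:, j, :].squeeze(1) for j in range(5)]
assert torch.equal(parts[0], base_indices)
assert int(parts[4].flatten()[0]) == 0  # intervention-site id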
30 changes: 0 additions & 30 deletions torch_model_base.py
@@ -651,33 +651,3 @@ def __repr__(self):
param_str = ["{}={}".format(a, getattr(self, a)) for a in self.params]
param_str = ",\n\t".join(param_str)
return "{}(\n\t{})".format(self.__class__.__name__, param_str)

def _get_set(self,get, set = None):
if set is None:
def gethook(model,input,output):
self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:,get["start"]: get["end"]]
set_handler = self.layers[get["layer"]].register_forward_hook(gethook)
return [set_handler]
elif set["layer"] != get["layer"]:
def sethook(model,input,output):
output[:,set["start"]: set["end"]] = set["intervention"]
set_handler = self.layers[set["layer"]].register_forward_hook(sethook)
def gethook(model,input,output):
self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:,get["start"]: get["end"] ]
get_handler = self.layers[get["layer"]].register_forward_hook(gethook)
return [set_handler, get_handler]
else:
def bothhook(model, input, output):
output[:,set["start"]: set["end"]] = set["intervention"]
self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:,get["start"]: get["end"] ]
both_handler = self.layers[set["layer"]].register_forward_hook(bothhook)
return [both_handler]

def retrieve_activations(self, input, get, set):
input = input.type(torch.FloatTensor).to(self.device)
self.activation = dict()
handlers = self._get_set(get, set)
logits = self.model(input)
for handler in handlers:
handler.remove()
return self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']
28 changes: 4 additions & 24 deletions IIT_torch_shallow_neural_classifier.py
@@ -4,10 +4,12 @@
import torch.utils.data
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
import utils
+from iit import IITModel

__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Spring 2021"


class TorchShallowNeuralClassifierIIT(TorchShallowNeuralClassifier):
    def __init__(self, id_to_coords, **base_kwargs):
        super().__init__(**base_kwargs)
@@ -18,9 +20,8 @@ def __init__(self,id_to_coords, **base_kwargs):

    def build_graph(self):
        model = super().build_graph()
-        model.no_IIT_forward = model.forward
-        model.forward = self.IIT_forward
-        return model
+        IITmodel = IITModel(model)
+        return IITmodel

    def batched_indices(self, max_len):
        batch_indices = [x for x in range(max_len // self.batch_size)]
@@ -85,26 +86,5 @@ def prep_input(self, base, source, coord_ids):
        bigX = torch.stack((base, source, coord_ids.unsqueeze(1).expand(-1, base.shape[1])), dim=1)
        return bigX

-    def IIT_forward(self, X):
-        base, source, coord_ids = X[:,0,:].squeeze(1), X[:,1,:].squeeze(1),X[:,2,:].squeeze(1)
-        get = self.id_to_coords[int(coord_ids.flatten()[0])]
-        base = base.type(torch.FloatTensor).to(self.device)
-        source = source.type(torch.FloatTensor).to(self.device)
-        self.activation = dict()
-        handlers = self._get_set(get,None)
-        source_logits = self.model.no_IIT_forward(source)
-        for handler in handlers:
-            handler.remove()
-
-        base_logits = self.model.no_IIT_forward(base)
-        set = {k:get[k] for k in get}
-        set["intervention"] = self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']
-        handlers = self._get_set(get, set)
-        counterfactual_logits = self.model.no_IIT_forward(base)
-        for handler in handlers:
-            handler.remove()
-
-        return counterfactual_logits, base_logits

if __name__ == '__main__':
    simple_example()
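Across both classifiers in this commit, forward returns the pair (counterfactual_logits, base_logits) and build_dataset stacks targets as (IIT_y, base_y). A plausible reading of the training objective (an assumption inferred from those pairings, not code shown in this diff) is one cross-entropy term per output:

import torch

loss_fn = torch.nn.CrossEntropyLoss()

# Stand-ins for one batch: 8 examples, 3 classes.
counterfactual_logits = torch.randn(8, 3)
base_logits = torch.randn(8, 3)
bigy = torch.stack(
    (torch.randint(0, 3, (8,)), torch.randint(0, 3, (8,))), dim=1)
IIT_y, base_y = bigy[:, 0], bigy[:, 1]

# Counterfactual predictions score against the intervention labels,
# base predictions against the ordinary labels.
loss = loss_fn(counterfactual_logits, IIT_y) + loss_fn(base_logits, base_y)
print(float(loss))  # scalar training loss for the batch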
