forked from cgpotts/cs224u
Showing 5 changed files with 218 additions and 57 deletions.
@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from transformers import BertModel, BertTokenizer\n",
"from iit import IITModel\n",
"from torch_model_base import TorchModelBase\n"
]
},
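{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell wraps a Hugging Face BERT model for interchange intervention training (IIT). `IITModel` is assumed to supply `id_to_coords`, `_get_set` (which registers hooks that cache or overwrite activations at given layer/position coordinates), and `device`. Its `forward` runs the source example to cache activations, then reruns the base example with those activations swapped in, returning both counterfactual and base logits."
]
},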
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class HfBertClassifierModelIIT(IITModel):\n",
"    def __init__(self, n_classes, weights_name='bert-base-cased'):\n",
"        super().__init__()\n",
"        self.n_classes = n_classes\n",
"        self.weights_name = weights_name\n",
"        self.bert = BertModel.from_pretrained(self.weights_name)\n",
"        self.bert.train()\n",
"        self.hidden_dim = self.bert.embeddings.word_embeddings.embedding_dim\n",
"        # The only new parameters -- the classifier:\n",
"        self.classifier_layer = nn.Linear(\n",
"            self.hidden_dim, self.n_classes)\n",
"        # `BertModel` exposes its transformer blocks directly via\n",
"        # `encoder.layer`; there is no intermediate `.model` attribute:\n",
"        self.layers = self.bert.encoder.layer\n",
"\n",
"    def no_IIT_forward(self, indices, attention_mask):\n",
"        reps = self.bert(\n",
"            indices, attention_mask=attention_mask)\n",
"        return self.classifier_layer(reps.pooler_output)\n",
"\n",
"    def forward(self, X):\n",
"        # X packs five aligned tensors along dim 1 (see build_dataset):\n",
"        base_indices, base_mask, source_indices, source_mask, coord_ids = [X[:, j, :] for j in range(5)]\n",
"        get = self.id_to_coords[int(coord_ids.flatten()[0])]\n",
"        # Token indices and masks must stay integer-typed for BERT:\n",
"        base_indices = base_indices.long().to(self.device)\n",
"        base_mask = base_mask.long().to(self.device)\n",
"        source_indices = source_indices.long().to(self.device)\n",
"        source_mask = source_mask.long().to(self.device)\n",
"\n",
"        # Run the source example; the hooks cache its activations at `get`:\n",
"        self.activation = dict()\n",
"        handlers = self._get_set(get, None)\n",
"        source_logits = self.no_IIT_forward(source_indices, source_mask)\n",
"        for handler in handlers:\n",
"            handler.remove()\n",
"\n",
"        base_logits = self.no_IIT_forward(base_indices, base_mask)\n",
"        # Interchange intervention: rerun the base example with the cached\n",
"        # source activations swapped in at the same coordinates:\n",
"        set_coords = {k: get[k] for k in get}\n",
"        set_coords[\"intervention\"] = self.activation[f'{get[\"layer\"]}-{get[\"start\"]}-{get[\"end\"]}']\n",
"        handlers = self._get_set(get, set_coords)\n",
"        counterfactual_logits = self.no_IIT_forward(base_indices, base_mask)\n",
"        for handler in handlers:\n",
"            handler.remove()\n",
"\n",
"        return counterfactual_logits, base_logits\n",
"\n",
"\n",
"class HfBertClassifierIIT(TorchModelBase):\n",
"    def __init__(self, *args, **kwargs):\n",
"        # Pop `weights_name` so it isn't forwarded to `TorchModelBase`,\n",
"        # which doesn't accept it:\n",
"        self.weights_name = kwargs.pop(\"weights_name\", 'bert-base-cased')\n",
"        self.tokenizer = BertTokenizer.from_pretrained(self.weights_name)\n",
"        super().__init__(*args, **kwargs)\n",
"        self.params += ['weights_name']\n",
"\n",
"    def build_graph(self):\n",
"        return HfBertClassifierModelIIT(self.n_classes_, self.weights_name)\n",
"\n",
"    def build_dataset(self, base, source, base_y, IIT_y, coord_ids):\n",
"        base_data = self.tokenizer.batch_encode_plus(\n",
"            base,\n",
"            max_length=None,\n",
"            add_special_tokens=True,\n",
"            padding='longest',\n",
"            return_attention_mask=True)\n",
"        source_data = self.tokenizer.batch_encode_plus(\n",
"            source,\n",
"            max_length=None,\n",
"            add_special_tokens=True,\n",
"            padding='longest',\n",
"            return_attention_mask=True)\n",
"        base_indices = torch.tensor(base_data['input_ids'])\n",
"        base_mask = torch.tensor(base_data['attention_mask'])\n",
"        source_indices = torch.tensor(source_data['input_ids'])\n",
"        source_mask = torch.tensor(source_data['attention_mask'])\n",
"\n",
"        self.classes_ = sorted(set(base_y))\n",
"        self.n_classes_ = len(self.classes_)\n",
"        class2index = dict(zip(self.classes_, range(self.n_classes_)))\n",
"        base_y = [class2index[label] for label in base_y]\n",
"        base_y = torch.tensor(base_y)\n",
"\n",
"        self.classes_ = sorted(set(IIT_y))\n",
"        self.n_classes_ = len(self.classes_)\n",
"        class2index = dict(zip(self.classes_, range(self.n_classes_)))\n",
"        # Index the IIT labels themselves, not `base_y` again:\n",
"        IIT_y = [class2index[label] for label in IIT_y]\n",
"        IIT_y = torch.tensor(IIT_y)\n",
"\n",
"        # Broadcast each scalar coordinate id across the sequence length so\n",
"        # that all five tensors stack along dim 1:\n",
"        coord_ids = coord_ids.unsqueeze(1).expand(-1, base_indices.shape[1])\n",
"        bigX = torch.stack(\n",
"            (base_indices, base_mask, source_indices, source_mask, coord_ids),\n",
"            dim=1)\n",
"        bigy = torch.stack((IIT_y, base_y), dim=1)\n",
"        return torch.utils.data.TensorDataset(bigX, bigy)\n"
]
},
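{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch (the exact format is an assumption, not taken from the course code), `id_to_coords` is expected to map each coordinate id to a dict with `layer`, `start`, and `end` keys identifying the slice of hidden states to interchange:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical illustration of the coordinate bookkeeping that `forward`\n",
"# relies on; the real mapping is supplied by the `IITModel` machinery.\n",
"example_id_to_coords = {\n",
"    0: {\"layer\": 4, \"start\": 0, \"end\": 1},   # [CLS] position at layer 4\n",
"    1: {\"layer\": 8, \"start\": 1, \"end\": 5}}   # token positions 1-4 at layer 8\n"
]
},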
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = HfBertClassifierIIT(\n",
"    weights_name='bert-base-cased',\n",
"    batch_size=8,          # Small batches to avoid memory overload.\n",
"    max_iter=1,            # We'll search based on 1 iteration for efficiency.\n",
"    n_iter_no_change=5,    # Early-stopping params are for the\n",
"    early_stopping=True)   # final evaluation.\n",
"\n",
"# Used by the search sketch below:\n",
"param_grid = {\n",
"    'gradient_accumulation_steps': [1, 4, 8],\n",
"    'eta': [0.00005, 0.0001, 0.001],\n",
"    'hidden_dim': [100, 200, 300]}\n",
"\n",
"# `get_IIT_MoNLI_dataset` is assumed to be supplied by the course code:\n",
"X_base_test, X_source_test, y_base_test, y_IIT_test, interventions = get_IIT_MoNLI_dataset(os.path.join(\"data\", \"MoNLI\"))\n",
"\n",
"model.fit(X_base_test, X_source_test, y_base_test, y_IIT_test, interventions)"
]
}
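,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of how `param_grid` might be put to work, the loop below trains one model per setting using `sklearn.model_selection.ParameterGrid`. It assumes `TorchModelBase` accepts each grid key as a keyword argument, and it omits the held-out-split scoring that a real search would need."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import ParameterGrid\n",
"\n",
"# Illustrative only: train one model per hyperparameter setting. A real\n",
"# search would score each candidate on a held-out split and keep the best.\n",
"for params in ParameterGrid(param_grid):\n",
"    search_model = HfBertClassifierIIT(\n",
"        weights_name='bert-base-cased',\n",
"        batch_size=8,\n",
"        max_iter=1,\n",
"        **params)\n",
"    search_model.fit(\n",
"        X_base_test, X_source_test, y_base_test, y_IIT_test, interventions)\n"
]
}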
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}