Commit

Merge pull request #21 from cdpierse/feature/use-word-input-embeddings-only

Embeddings defaults to model.get_input_embeddings() for all models
cdpierse authored Feb 27, 2021
2 parents aa93caa + 45780c0 commit c97e50b
Showing 1 changed file with 3 additions and 4 deletions: transformers_interpret/explainers/sequence_classification.py
@@ -57,12 +57,13 @@ def run(self, text: str = None, index: int = None, class_name: str = None):

     def _forward(self, input_ids):
         preds = self.model(input_ids)[0]
-        self.pred_probs = torch.softmax(preds, dim=1)[0][1]
+        self.pred_probs = torch.softmax(preds, dim=1)[0][self.selected_index]
         return torch.softmax(preds, dim=1)[0][self.selected_index].unsqueeze(-1)

     @property
     def predicted_class_index(self):
         if self.input_ids is not None:
+            # we call this before _forward() so it has to be calculated twice
             preds = self.model(self.input_ids)[0]
             self.pred_class = torch.argmax(torch.softmax(preds, dim=0)[0])
             return torch.argmax(torch.softmax(preds, dim=1)[0]).cpu().detach().numpy()
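Context for the hunk above: self.pred_probs previously hard-coded class index 1, so the stored probability could disagree with the class actually being attributed whenever selected_index was not 1. A minimal sketch of the difference, using made-up logits rather than anything from the repository:

import torch

preds = torch.tensor([[2.0, -1.0, 0.5]])  # (batch=1, num_classes=3) logits, illustrative only
probs = torch.softmax(preds, dim=1)[0]    # class distribution for example 0
selected_index = 2                        # e.g. the class requested via index/class_name
print(probs[1])               # old behaviour: always reported the probability of class 1
print(probs[selected_index])  # new behaviour: probability of the selected class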
@@ -121,9 +122,7 @@ def _calculate_attributions(self, index: int = None, class_name: str = None):
         else:
             self.selected_index = self.predicted_class_index
         if self.attribution_type == "lig":
-            embeddings = getattr(self.model, self.model_prefix).embeddings
-            # embeddings = self.model.get_input_embeddings()
-            # embeddings = getattr(self.model, self.model_prefix).get_input_embeddings()
+            embeddings = self.model.get_input_embeddings()
             reference_tokens = [token.replace("Ġ","") for token in self.decode(self.input_ids)]
             lig = LIGAttributions(
                 self._forward,
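The second hunk is the change named in the commit message: rather than reaching into an architecture-specific submodule via getattr(self.model, self.model_prefix).embeddings, which assumes a BERT-style attribute layout and returns the full embeddings module (word, position, and token-type embeddings) rather than just the word embeddings, the explainer now calls the uniform transformers API model.get_input_embeddings(), which every PreTrainedModel implements. A short illustrative check that this works across architectures (the model names are examples, not taken from the commit):

from transformers import AutoModelForSequenceClassification

for name in ["distilbert-base-uncased", "roberta-base"]:
    model = AutoModelForSequenceClassification.from_pretrained(name)
    emb = model.get_input_embeddings()  # a torch.nn.Embedding holding the word embeddings
    print(name, type(emb).__name__, tuple(emb.weight.shape))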
