
Add first word spelling experiment
sharanry committed Oct 28, 2024
1 parent d4aad15 commit 067cde9
Showing 5 changed files with 94 additions and 8 deletions.
18 changes: 18 additions & 0 deletions probe_lens/experiments/experiments.py
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod

"""
Probe experiments are used to generate data for probing tasks.
"""


class ProbeExperiment(ABC):
    def __init__(self, task_name: str):
        self.task_name = task_name
        self.data = None

    def __repr__(self) -> str:
        return self.task_name

    @abstractmethod
    def get_data(self) -> list[tuple[str, int]]:
        pass
36 changes: 36 additions & 0 deletions probe_lens/experiments/spelling.py
@@ -0,0 +1,36 @@
from typing import Callable

from probe_lens.experiments.experiments import ProbeExperiment

LETTERS = "abcdefghijklmnopqrstuvwxyz"
WORDS_DATASET = "https://www.mit.edu/~ecprice/wordlist.10000"


def default_spelling_prompt_generator(word: str):
    return f"The word '{word}' is spelled:"


def first_letter_index(word: str):
    return LETTERS.index(word.strip().lower()[0])


class FirstLetterSpelling(ProbeExperiment):
    def __init__(
        self,
        words: list[str],
        prompt_fn: Callable[[str], str] = default_spelling_prompt_generator,
        class_fn: Callable[[str], int] = first_letter_index,
    ):
        super().__init__("First Letter Spelling Experiment")
        self.words = words
        self.prompt_fn = prompt_fn
        self.class_fn = class_fn
        self.generate_data()

    def generate_data(self):
        self.classes = [self.class_fn(word) for word in self.words]
        self.prompts = [self.prompt_fn(word) for word in self.words]
        self.data = list(zip(self.prompts, self.classes))

    def get_data(self) -> list[tuple[str, int]]:
        return self.data
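
A minimal usage sketch of the new FirstLetterSpelling experiment (the two-word list below is illustrative only, not part of the commit): constructing the experiment builds the prompt/label pairs up front, and get_data() returns them.

from probe_lens.experiments.spelling import FirstLetterSpelling

# Hypothetical word list, for illustration only
task = FirstLetterSpelling(["example", "words"])
for prompt, label in task.get_data():
    print(prompt, label)
# Expected output:
# The word 'example' is spelled: 4    (index of 'e' in LETTERS)
# The word 'words' is spelled: 22     (index of 'w' in LETTERS)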
30 changes: 24 additions & 6 deletions probe_lens/probes.py
@@ -15,7 +15,9 @@ def __init__(self, input_dim, output_dim=1, device="cpu", class_names=None):
    def forward(self, x):
        return self.linear(x)

    def visualize_performance(self, dataloader: torch.utils.data.DataLoader, test=False):
    def visualize_performance(
        self, dataloader: torch.utils.data.DataLoader, test=False
    ):
        preds = []
        gts = []
        for X, y in dataloader:
@@ -27,15 +29,28 @@ def visualize_performance(self, dataloader: torch.utils.data.DataLoader, test=Fa
        gts = torch.cat(gts)

        accuracy = accuracy_score(gts.cpu(), preds.cpu())
        f2_score = fbeta_score(gts.cpu(), preds.cpu(), beta=2, average='weighted')
        f2_score = fbeta_score(gts.cpu(), preds.cpu(), beta=2, average="weighted")

        cm = confusion_matrix(gts.cpu(), preds.cpu())
        plt.figure(figsize=(10, 7))
        _class_names = self.class_names if self.class_names else [str(i) for i in range(cm.shape[0])]
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=_class_names, yticklabels=_class_names)
        _class_names = (
            self.class_names
            if self.class_names
            else [str(i) for i in range(cm.shape[0])]
        )
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=_class_names,
            yticklabels=_class_names,
        )
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.title(f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})")
        plt.title(
            f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})"
        )
        if not test:
            plt.show()
        return plt
@@ -60,7 +75,10 @@ def train_probe(
                optimizer.step()
                loss_sum += loss.item()
            if val_dataloader is not None:
                dataset_names, datasets = ["train", "val"], [dataloader, val_dataloader]
                dataset_names, datasets = (
                    ["train", "val"],
                    [dataloader, val_dataloader],
                )
            else:
                dataset_names, datasets = ["train"], [dataloader]
            if verbose and (epoch + 1) % 10 == 0:
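
For context on the reformatted probes.py methods, a hedged sketch of how visualize_performance might be called; the class name LinearProbe is inferred from tests/test_probes.py and is an assumption, as are the toy dimensions and the one-hot target format.

import torch
from torch.utils.data import DataLoader, TensorDataset

from probe_lens.probes import LinearProbe  # class name inferred from the tests, not shown in this hunk

# Constructor arguments follow the signature shown in the hunk header above.
probe = LinearProbe(input_dim=16, output_dim=26, class_names=list("abcdefghijklmnopqrstuvwxyz"))

# Toy one-hot targets, matching the argmax-based checks in tests/test_probes.py.
X = torch.randn(32, 16)
y = torch.nn.functional.one_hot(torch.randint(0, 26, (32,)), num_classes=26).float()
loader = DataLoader(TensorDataset(X, y), batch_size=8)

plot = probe.visualize_performance(loader, test=True)  # test=True skips plt.show() and returns the plot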
9 changes: 9 additions & 0 deletions tests/experiments/test_spelling.py
@@ -0,0 +1,9 @@
from probe_lens.experiments.spelling import LETTERS, FirstLetterSpelling


def test_first_letter_spelling():
    words = ["example", "words", "to", "spell"]
    spelling_task = FirstLetterSpelling(words)
    data = spelling_task.data
    classes = [c for _, c in data]
    assert classes == [LETTERS.index(word.lower()[0]) for word in words]
9 changes: 7 additions & 2 deletions tests/test_probes.py
@@ -77,5 +77,10 @@ def test_linear_probe_visualization():
    assert plot is not None
    # assert isinstance(plot, plt.Figure)
    accuracy = accuracy_score(y.argmax(dim=1).cpu(), model(x).argmax(dim=1).cpu())
    f2_score = fbeta_score(y.argmax(dim=1).cpu(), model(x).argmax(dim=1).cpu(), beta=2, average='weighted')
    assert plot.gca().get_title() == f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})"
    f2_score = fbeta_score(
        y.argmax(dim=1).cpu(), model(x).argmax(dim=1).cpu(), beta=2, average="weighted"
    )
    assert (
        plot.gca().get_title()
        == f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})"
    )
