diff --git a/probe_lens/experiments/experiments.py b/probe_lens/experiments/experiments.py new file mode 100644 index 0000000..bf6ed5c --- /dev/null +++ b/probe_lens/experiments/experiments.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + +""" +Probe experiments are used to generate data for probing tasks. +""" + + +class ProbeExperiment(ABC): + def __init__(self, task_name: str): + self.task_name = task_name + self.data = None + + def __repr__(self) -> str: + return self.task_name + + @abstractmethod + def get_data(self) -> list[tuple[str, int]]: + pass diff --git a/probe_lens/experiments/spelling.py b/probe_lens/experiments/spelling.py index e69de29..8955c22 100644 --- a/probe_lens/experiments/spelling.py +++ b/probe_lens/experiments/spelling.py @@ -0,0 +1,36 @@ +from typing import Callable + +from probe_lens.experiments.experiments import ProbeExperiment + +LETTERS = "abcdefghijklmnopqrstuvwxyz" +WORDS_DATASET = "https://www.mit.edu/~ecprice/wordlist.10000" + + +def default_spelling_prompt_generator(word: str): + return f"The word '{word}' is spelled:" + + +def first_letter_index(word: str): + return LETTERS.index(word.strip().lower()[0]) + + +class FirstLetterSpelling(ProbeExperiment): + def __init__( + self, + words: list[str], + prompt_fn: Callable[[str], str] = default_spelling_prompt_generator, + class_fn: Callable[[str], int] = first_letter_index, + ): + super().__init__("First Letter Spelling Experiment") + self.words = words + self.prompt_fn = prompt_fn + self.class_fn = class_fn + self.generate_data() + + def generate_data(self): + self.classes = [self.class_fn(word) for word in self.words] + self.prompts = [self.prompt_fn(word) for word in self.words] + self.data = list(zip(self.prompts, self.classes)) + + def get_data(self) -> list[tuple[str, int]]: + return self.data diff --git a/probe_lens/probes.py b/probe_lens/probes.py index ce2b8b7..d50f789 100644 --- a/probe_lens/probes.py +++ b/probe_lens/probes.py @@ -15,7 +15,9 @@ def __init__(self, input_dim, output_dim=1, device="cpu", class_names=None): def forward(self, x): return self.linear(x) - def visualize_performance(self, dataloader: torch.utils.data.DataLoader, test=False): + def visualize_performance( + self, dataloader: torch.utils.data.DataLoader, test=False + ): preds = [] gts = [] for X, y in dataloader: @@ -27,15 +29,28 @@ def visualize_performance(self, dataloader: torch.utils.data.DataLoader, test=Fa gts = torch.cat(gts) accuracy = accuracy_score(gts.cpu(), preds.cpu()) - f2_score = fbeta_score(gts.cpu(), preds.cpu(), beta=2, average='weighted') + f2_score = fbeta_score(gts.cpu(), preds.cpu(), beta=2, average="weighted") cm = confusion_matrix(gts.cpu(), preds.cpu()) plt.figure(figsize=(10, 7)) - _class_names = self.class_names if self.class_names else [str(i) for i in range(cm.shape[0])] - sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=_class_names, yticklabels=_class_names) + _class_names = ( + self.class_names + if self.class_names + else [str(i) for i in range(cm.shape[0])] + ) + sns.heatmap( + cm, + annot=True, + fmt="d", + cmap="Blues", + xticklabels=_class_names, + yticklabels=_class_names, + ) plt.xlabel("Predicted") plt.ylabel("True") - plt.title(f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})") + plt.title( + f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})" + ) if not test: plt.show() return plt @@ -60,7 +75,10 @@ def train_probe( optimizer.step() loss_sum += loss.item() if val_dataloader is not None: - dataset_names, datasets = ["train", "val"], [dataloader, val_dataloader] + dataset_names, datasets = ( + ["train", "val"], + [dataloader, val_dataloader], + ) else: dataset_names, datasets = ["train"], [dataloader] if verbose and (epoch + 1) % 10 == 0: diff --git a/tests/experiments/test_spelling.py b/tests/experiments/test_spelling.py new file mode 100644 index 0000000..c377278 --- /dev/null +++ b/tests/experiments/test_spelling.py @@ -0,0 +1,9 @@ +from probe_lens.experiments.spelling import LETTERS, FirstLetterSpelling + + +def test_first_letter_spelling(): + words = ["example", "words", "to", "spell"] + spelling_task = FirstLetterSpelling(words) + data = spelling_task.data + classes = [c for _, c in data] + assert classes == [LETTERS.index(word.lower()[0]) for word in words] diff --git a/tests/test_probes.py b/tests/test_probes.py index 2235fcb..16adad3 100644 --- a/tests/test_probes.py +++ b/tests/test_probes.py @@ -77,5 +77,10 @@ def test_linear_probe_visualization(): assert plot is not None # assert isinstance(plot, plt.Figure) accuracy = accuracy_score(y.argmax(dim=1).cpu(), model(x).argmax(dim=1).cpu()) - f2_score = fbeta_score(y.argmax(dim=1).cpu(), model(x).argmax(dim=1).cpu(), beta=2, average='weighted') - assert plot.gca().get_title() == f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})" + f2_score = fbeta_score( + y.argmax(dim=1).cpu(), model(x).argmax(dim=1).cpu(), beta=2, average="weighted" + ) + assert ( + plot.gca().get_title() + == f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})" + )