chore: Miscellaneous updates mostly wrt. offline drift exploration (#592)
1 parent 5d1b088, commit 7130d6b
Showing 10 changed files with 229 additions and 135 deletions.
@@ -79,10 +79,10 @@ cmake-build-debug/
clang-tidy-build/
libbuild/

# Data & config files

.data/
.debug/
.env

exploration/
One of the files changed in this commit was deleted; its name and contents are not shown below.
modyn/config/schema/pipeline/trigger/drift/preprocess/alibi_detect.py (30 additions, 0 deletions)
from collections.abc import Callable
from functools import partial

from alibi_detect.cd.pytorch import preprocess_drift
from alibi_detect.models.pytorch import TransformerEmbedding
from pydantic import Field
from transformers import AutoTokenizer

from modyn.config.schema.base_model import ModynBaseModel


class AlibiDetectNLPreprocessor(ModynBaseModel):
    tokenizer_model: str = Field(description="AutoTokenizer pretrained model name. E.g. bert-base-cased")
    n_layers: int = Field(8)
    max_len: int = Field(..., description="Maximum length of input token sequences.")
    batch_size: int = Field(32, description="Batch size for tokenization.")

    def gen_preprocess_fn(self, device: str | None) -> Callable:
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_model)
        emb_type = "hidden_state"
        # Use the hidden states of the last n_layers transformer layers (negative layer indices).
        layers = [-_ for _ in range(1, self.n_layers + 1)]

        embedding = TransformerEmbedding(self.tokenizer_model, emb_type, layers)
        if device:
            embedding = embedding.to(device)
        embedding = embedding.eval()

        return partial(
            preprocess_drift, model=embedding, tokenizer=tokenizer, max_len=self.max_len, batch_size=self.batch_size
        )
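The file above only defines the preprocessor configuration; the detector that consumes the generated function is not part of the diff shown here. As a rough sketch of how `gen_preprocess_fn` could be wired into alibi-detect, the example below feeds it into an `MMDDrift` detector. The tokenizer model, `max_len`, and the example texts are illustrative assumptions, not values taken from this commit.

```python
# Sketch only: the commit does not show the consuming detector. All concrete
# values (model name, max_len, toy texts) are assumptions for illustration.
import numpy as np
from alibi_detect.cd import MMDDrift

from modyn.config.schema.pipeline.trigger.drift.preprocess.alibi_detect import AlibiDetectNLPreprocessor

preprocessor = AlibiDetectNLPreprocessor(tokenizer_model="bert-base-cased", max_len=100)
preprocess_fn = preprocessor.gen_preprocess_fn(None)  # None: let alibi-detect pick the device

x_ref = np.array(["a reference sentence", "more reference text", "yet another reference"])  # reference window
x_cur = np.array(["a possibly drifted sentence", "more current text"])                      # current window

detector = MMDDrift(x_ref, backend="pytorch", p_val=0.05, preprocess_fn=preprocess_fn)
result = detector.predict(x_cur)
print(result["data"]["is_drift"], result["data"]["p_val"])
```

alibi-detect embeds both windows through the preprocessing function and runs the MMD permutation test on the resulting transformer embeddings.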
modyn/supervisor/internal/triggers/drift/classifier_models/__init__.py (5 additions, 0 deletions)
from modyn.supervisor.internal.triggers.drift.classifier_models.ybnet_classifier import YearbookNetDriftDetector

# Registry of drift-classifier models, keyed by short name.
alibi_classifier_models = {
    "ybnet": YearbookNetDriftDetector(3),
}
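`alibi_classifier_models` maps a short name to an instantiated torch module; the trigger code that looks these names up is not included in the diff shown here. Below is a hedged sketch of how such an entry could be handed to alibi-detect's `ClassifierDrift` detector (the detector referenced in the model file that follows); the window sizes and the 3x32x32 input shape are assumptions for illustration.

```python
# Sketch only: the consuming trigger code is not part of the shown diff.
# Window sizes and the (3, 32, 32) input shape are assumptions for illustration.
import numpy as np
from alibi_detect.cd import ClassifierDrift

from modyn.supervisor.internal.triggers.drift.classifier_models import alibi_classifier_models

model = alibi_classifier_models["ybnet"]  # YearbookNetDriftDetector(3)

x_ref = np.random.rand(64, 3, 32, 32).astype(np.float32)  # reference window
x_cur = np.random.rand(64, 3, 32, 32).astype(np.float32)  # current window

# The model emits raw logits (no softmax), hence preds_type="logits".
detector = ClassifierDrift(x_ref, model, backend="pytorch", preds_type="logits", p_val=0.05)
result = detector.predict(x_cur)
print(result["data"]["is_drift"], result["data"]["p_val"])
```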
modyn/supervisor/internal/triggers/drift/classifier_models/ybnet_classifier.py (36 additions, 0 deletions)
import torch
from torch import nn

from modyn.models.coreset_methods_support import CoresetSupportingModule


class YearbookNetDriftDetector(CoresetSupportingModule):
    def __init__(self, num_input_channels: int) -> None:
        super().__init__()
        self.enc = nn.Sequential(
            self.conv_block(num_input_channels, 32),
            self.conv_block(32, 32),
            self.conv_block(32, 32),
            self.conv_block(32, 32),
        )
        self.hid_dim = 32
        # Binary classifier for drift detection
        # see: https://docs.seldon.io/projects/alibi-detect/en/latest/cd/methods/classifierdrift.html
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(32, 2))

    def conv_block(self, in_channels: int, out_channels: int) -> nn.Module:
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        data = self.enc(data)
        data = torch.mean(data, dim=(2, 3))
        data = self.classifier(data)
        return data

    def get_last_layer(self) -> nn.Module:
        return self.classifier
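Each of the four conv blocks halves the spatial resolution, and the forward pass then averages over the remaining spatial grid, so the linear head always receives a 32-dimensional feature vector per sample. A minimal shape check, assuming 32x32 three-channel inputs (the resolution is an assumption, not stated in the commit):

```python
# Shape check for the drift classifier; the 32x32 input size is an assumption.
import torch

from modyn.supervisor.internal.triggers.drift.classifier_models.ybnet_classifier import YearbookNetDriftDetector

model = YearbookNetDriftDetector(num_input_channels=3).eval()

batch = torch.randn(8, 3, 32, 32)  # 8 images, 3 channels, 32x32 pixels
with torch.no_grad():
    logits = model(batch)

# Four MaxPool2d(2) layers reduce 32x32 -> 2x2; the spatial mean yields an
# (8, 32) feature matrix, and the linear head maps it to two drift logits.
print(logits.shape)  # torch.Size([8, 2])
```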