cyclops/monitor/tester.py: 269 changes (178 additions, 91 deletions)
@@ -8,17 +8,23 @@
import pandas as pd
import sklearn
from datasets import Dataset, DatasetDict, concatenate_datasets
from datasets.utils.logging import disable_progress_bar
from scipy.special import expit as sigmoid
from scipy.special import softmax
from sklearn.base import BaseEstimator

from cyclops.data.transforms import Lambdad
from cyclops.data.utils import apply_transforms
from cyclops.models.catalog import wrap_model
from cyclops.models.utils import is_pytorch_model, is_sklearn_model
from cyclops.models.wrappers import PTModel, SKModel
from cyclops.monitor.utils import DetectronModule, DummyCriterion, get_args
from cyclops.utils.optional import import_optional_module


disable_progress_bar()


if TYPE_CHECKING:
import torch
from alibi_detect.cd import (
@@ -705,33 +711,55 @@ def __init__(
self.model = base_model
else:
self.model = model
if isinstance(base_model, nn.Module):
if is_pytorch_model(base_model):
self.base_model = wrap_model(
base_model,
batch_size=batch_size,
)
self.base_model.initialize()
else:
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
if transforms:
self.transforms = partial(apply_transforms, transforms=transforms)
model_transforms = transforms
model_transforms.transforms = model_transforms.transforms + (
Lambdad(
keys=("mask", "labels"),
func=lambda x: np.array(x),
allow_missing_keys=True,
),
)
self.model_transforms = partial(
apply_transforms,
transforms=model_transforms,
)
else:
self.transforms = None
self.model_transforms = None
elif is_sklearn_model(base_model):
self.base_model = wrap_model(base_model)
self.base_model.initialize()
self.feature_column = feature_column
if transforms:
self.transforms = partial(apply_transforms, transforms=transforms)
model_transforms = transforms
model_transforms.transforms = model_transforms.transforms + (
Lambdad(
keys=("mask", "labels"),
func=lambda x: np.array(x),
allow_missing_keys=True,
),
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
self.model_transforms = partial(
apply_transforms,
transforms=model_transforms,
self.transforms = transforms
self.model_transforms = transforms
elif isinstance(base_model, SKModel):
self.base_model = base_model
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
else:
self.transforms = transforms
self.model_transforms = transforms
elif isinstance(base_model, PTModel):
self.base_model = base_model
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
else:
raise ValueError("base_model must be a PyTorch or sklearn model.")

self.feature_column = feature_column
self.splits_mapping = splits_mapping
self.num_runs = num_runs
self.sample_size = sample_size
@@ -741,8 +769,7 @@ def __init__(
self.lr = lr
self.num_workers = num_workers
self.task = task
if save_dir is None:
self.save_dir = "detectron"
self.save_dir = "detectron" if save_dir is None else save_dir

self.fit(X_s)
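For context on the transform handling in the PyTorch branch of `__init__` above: the extra `Lambdad` appended to the model transforms only casts the `mask` and `labels` columns to numpy arrays and skips examples that lack those keys. A minimal sketch, assuming `cyclops.data.transforms.Lambdad` follows the MONAI-style dictionary-transform interface (instances callable on an example dict); the example values are made up:

import numpy as np

from cyclops.data.transforms import Lambdad

to_array = Lambdad(
    keys=("mask", "labels"),
    func=lambda x: np.array(x),
    allow_missing_keys=True,
)

example = {"features": [0.1, 0.2, 0.3], "labels": [1, 0]}  # no "mask" key
out = to_array(example)  # "labels" is cast to np.array([1, 0]); the missing "mask" key is skipped
print(type(out["labels"]))  # expected: <class 'numpy.ndarray'>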

@@ -759,24 +786,35 @@ def fit(self, X_s: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
for seed in range(self.num_runs):
# train ensemble for split 'p*'
for e in range(1, self.ensemble_size + 1):
if is_pytorch_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
elif is_sklearn_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
alpha = 1 / (len(X_s) * self.sample_size + 1)
model = wrap_model(
DetectronModule(
self.model,
feature_column=self.feature_column,
alpha=alpha,
),
batch_size=self.batch_size,
criterion=DummyCriterion,
max_epochs=self.max_epochs_per_model,
lr=self.lr,
num_workers=self.num_workers,
save_dir=self.save_dir,
concatenate_features=False,
)
if is_pytorch_model(self.base_model.model):
model = wrap_model(
DetectronModule(
self.model,
feature_column=self.feature_column,
alpha=alpha,
),
batch_size=self.batch_size,
criterion=DummyCriterion,
max_epochs=self.max_epochs_per_model,
lr=self.lr,
num_workers=self.num_workers,
save_dir=self.save_dir,
concatenate_features=False,
)
model.initialize()
elif is_sklearn_model(self.base_model.model):
model = self.base_model
if isinstance(X_s, (Dataset, DatasetDict)):
# create p/p* splits

p = (
X_s[self.splits_mapping["train"]]
.shuffle()
@@ -808,26 +846,39 @@ def fit(self, X_s: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
np.array(pstar_pseudolabels),
)
pstar = pstar.add_column("labels", pstar_pseudolabels.tolist())
if is_sklearn_model(self.base_model.model):
pstar = pstar.map(
lambda x: {"labels": int(1 - x["labels"])}
)

p_pstar = concatenate_datasets([p, pstar], axis=0)
p_pstar = p_pstar.train_test_split(test_size=0.5, shuffle=True)

train_features = [self.feature_column]
train_features.extend(["labels", "mask"])
model.fit(
X=p_pstar,
feature_columns=train_features,
target_columns="mask", # placeholder, not used in dummycriterion
transforms=self.model_transforms,
splits_mapping={"train": "train", "validation": "test"},
)
if is_pytorch_model(self.base_model.model):
train_features = [self.feature_column]
train_features.extend(["labels", "mask"])
model.fit(
X=p_pstar,
feature_columns=train_features,
target_columns="mask", # placeholder, not used in dummycriterion
transforms=self.model_transforms,
splits_mapping={"train": "train", "validation": "test"},
)
model.load_model(
os.path.join(
self.save_dir,
"saved_models/DetectronModule/best_model.pt",
),
log=False,
)
elif is_sklearn_model(self.base_model.model):
model.fit(
X=p_pstar,
feature_columns=self.feature_column,
target_columns="labels",
transforms=self.model_transforms,
)

model.load_model(
os.path.join(
self.save_dir,
"saved_models/DetectronModule/best_model.pt",
),
)
pstar_logits = model.predict(
X=pstar,
feature_columns=self.feature_column,
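One detail in the sklearn branch above is the pseudolabel flip: each ensemble member is fit on p* with labels inverted relative to the base model's predictions, which pushes the member to disagree with the base model on the held-out sample. Note that `datasets.Dataset.map` only picks up changes that the mapped function returns; mutating the example dict in place and returning None leaves the column untouched, which is why the flip is written to return the new value. A minimal sketch with made-up labels:

from datasets import Dataset

pstar = Dataset.from_dict({"labels": [0, 1, 1, 0]})
# Return the updated column; `lambda x: x.update(...)` would return None and be a no-op.
pstar = pstar.map(lambda x: {"labels": int(1 - x["labels"])})
print(pstar["labels"])  # [1, 0, 0, 1]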
@@ -862,22 +913,33 @@ def predict(self, X_t: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
for seed in range(self.num_runs):
# train ensemble for split 'q'
for e in range(1, self.ensemble_size + 1):
if is_pytorch_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
elif is_sklearn_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
alpha = 1 / (len(X_t) * self.sample_size + 1)
model = wrap_model(
DetectronModule(
self.model,
feature_column=self.feature_column,
alpha=alpha,
),
batch_size=self.batch_size,
criterion=DummyCriterion,
max_epochs=self.max_epochs_per_model,
lr=self.lr,
num_workers=self.num_workers,
save_dir=self.save_dir,
concatenate_features=False,
)
model.initialize()
if is_pytorch_model(self.base_model.model):
model = wrap_model(
DetectronModule(
self.model,
feature_column=self.feature_column,
alpha=alpha,
),
batch_size=self.batch_size,
criterion=DummyCriterion,
max_epochs=self.max_epochs_per_model,
lr=self.lr,
num_workers=self.num_workers,
save_dir=self.save_dir,
concatenate_features=False,
)
model.initialize()
elif is_sklearn_model(self.base_model.model):
model = self.base_model
if isinstance(X_t, (Dataset, DatasetDict)):
# create p/q splits
p = (
@@ -908,24 +970,36 @@ def predict(self, X_t: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
)
q_pseudolabels = self.format_pseudolabels(np.array(q_pseudolabels))
q = q.add_column("labels", q_pseudolabels.tolist())
if is_sklearn_model(self.base_model.model):
q = q.map(lambda x: {"labels": int(1 - x["labels"])})
p_q = concatenate_datasets([p, q], axis=0)
p_q = p_q.train_test_split(test_size=0.5, shuffle=True)
train_features = [self.feature_column]
train_features.extend(["labels", "mask"])
model.fit(
X=p_q,
feature_columns=train_features,
target_columns="mask", # placeholder, not used in dummycriterion
transforms=self.model_transforms,
splits_mapping={"train": "train", "validation": "test"},
)

model.load_model(
os.path.join(
self.save_dir,
"saved_models/DetectronModule/best_model.pt",
),
)
if is_pytorch_model(self.base_model.model):
train_features = [self.feature_column]
train_features.extend(["labels", "mask"])
model.fit(
X=p_q,
feature_columns=train_features,
target_columns="mask", # placeholder, not used in dummycriterion
transforms=self.model_transforms,
splits_mapping={"train": "train", "validation": "test"},
)

model.load_model(
os.path.join(
self.save_dir,
"saved_models/DetectronModule/best_model.pt",
),
log=False,
)
elif is_sklearn_model(self.base_model.model):
model.fit(
X=p_q,
feature_columns=self.feature_column,
target_columns="labels",
transforms=self.model_transforms,
)
q_logits = model.predict(
X=q,
feature_columns=self.feature_column,
@@ -950,18 +1024,21 @@

def format_pseudolabels(self, labels):
"""Format pseudolabels."""
if self.task in ("binary", "multilabel"):
labels = (
(labels > 0.5).astype("float32")
if ((labels <= 1).all() and (labels >= 0).all())
else (sigmoid(labels) > 0.5).astype("float32")
)
elif self.task == "multiclass":
labels = (
labels.argmax(dim=-1)
if np.isclose(labels.sum(axis=-1), 1).all()
else softmax(labels, axis=-1).argmax(axis=-1)
)
if is_pytorch_model(self.base_model.model):
if self.task in ("binary", "multilabel"):
labels = (
(labels > 0.5).astype("float32")
if ((labels <= 1).all() and (labels >= 0).all())
else (sigmoid(labels) > 0.5).astype("float32")
)
elif self.task == "multiclass":
labels = (
labels.argmax(axis=-1)
if np.isclose(labels.sum(axis=-1), 1).all()
else softmax(labels, axis=-1).argmax(axis=-1)
)
elif is_sklearn_model(self.base_model.model):
return labels
else:
raise ValueError(
f"Task must be either 'binary', 'multiclass' or 'multilabel', got {self.task} instead.",
Expand Down Expand Up @@ -1015,15 +1092,25 @@ def get_results(self, max_ensemble_size=None) -> float:
test_count = self.counts("test", max_ensemble_size)[0]
cdf = self.ecdf(cal_counts)
p_value = cdf(test_count)
self.model_health = self.get_model_health(max_ensemble_size)
return {
"data": {
"model_health": self.model_health,
"p_val": p_value,
"distance": test_count,
"cal_record": self.cal_record,
"test_record": self.test_record,
},
}

def get_model_health(self, max_ensemble_size=None) -> float:
"""Get model health."""
self.cal_counts = self.counts("calibration", max_ensemble_size)
self.test_count = self.counts("test", max_ensemble_size)[0]
self.baseline = self.cal_counts.mean()
self.model_health = self.test_count / self.baseline
return min(1, self.model_health)
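As a quick illustration of the statistics computed in `get_results` and `get_model_health`: the p-value is the empirical CDF of the per-run calibration counts evaluated at the test count (so a test count that is low relative to calibration gives a small p-value), and model health is the test count relative to the calibration baseline, capped at 1. The counts below are made up, and a plain numpy ECDF stands in for the class's own `ecdf` helper:

import numpy as np

cal_counts = np.array([38, 41, 35, 44, 40, 37, 42, 39])  # per-run counts on calibration data
test_count = 31                                           # count on the data under test

# Empirical CDF at the test count: fraction of calibration runs with a count <= test_count.
p_value = np.mean(cal_counts <= test_count)   # 0.0 here, since no calibration count is that low

baseline = cal_counts.mean()                  # 39.5
model_health = min(1, test_count / baseline)  # ~0.78
print(p_value, model_health)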

@staticmethod
def split_dataset(X: Union[Dataset, DatasetDict]) -> DatasetDict:
"""Split dataset into train and test splits."""