
Port over miscellaneous changes from SIGMOD branch #582

Merged · 16 commits · Jul 30, 2024
1 change: 1 addition & 0 deletions modyn/models/__init__.py
@@ -12,6 +12,7 @@
 from .resnet50.resnet50 import ResNet50  # noqa: F401
 from .resnet152.resnet152 import ResNet152  # noqa: F401
 from .rho_loss_twin_model.rho_loss_twin_model import RHOLOSSTwinModel  # noqa: F401
+from .smallyearbooknet.smallyearbooknet import SmallYearbookNet  # noqa: F401
 from .yearbooknet.yearbooknet import YearbookNet  # noqa: F401
 
 files = os.listdir(os.path.dirname(__file__))
2 changes: 2 additions & 0 deletions modyn/models/resnet152/resnet152.py
@@ -26,6 +26,8 @@ def __init__(self, model_configuration: dict[str, Any]) -> None:
             # We need to initialize the model with the number of classes
             # in the pretrained weights
             model_configuration["num_classes"] = len(weights.meta["categories"])
 
+        if "use_pretrained" in model_configuration:  # no matter if True or False
+            del model_configuration["use_pretrained"]  # don't want to forward this to torchvision
 
         super().__init__(Bottleneck, [3, 8, 36, 3], **model_configuration)  # type: ignore
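For context, the trailing `**model_configuration` is forwarded straight to torchvision's `ResNet.__init__`, which rejects unknown keyword arguments; that is why the Modyn-only `use_pretrained` flag must be deleted first. A minimal standalone sketch of the failure mode (the config values are hypothetical):

```python
from torchvision.models.resnet import Bottleneck, ResNet

config = {"use_pretrained": True, "num_classes": 1000}  # hypothetical pipeline config

try:
    ResNet(Bottleneck, [3, 8, 36, 3], **config)  # flag still present
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'use_pretrained'

config.pop("use_pretrained", None)  # equivalent to the `del` in this diff
model = ResNet(Bottleneck, [3, 8, 36, 3], **config)  # constructs a ResNet-152
```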
2 changes: 2 additions & 0 deletions modyn/models/resnet18/resnet18.py
@@ -26,6 +26,8 @@ def __init__(self, model_configuration: dict[str, Any]) -> None:
             # We need to initialize the model with the number of classes
             # in the pretrained weights
             model_configuration["num_classes"] = len(weights.meta["categories"])
 
+        if "use_pretrained" in model_configuration:  # no matter if True or False
+            del model_configuration["use_pretrained"]  # don't want to forward this to torchvision
 
         super().__init__(BasicBlock, [2, 2, 2, 2], **model_configuration)  # type: ignore
2 changes: 2 additions & 0 deletions modyn/models/resnet50/resnet50.py
@@ -26,6 +26,8 @@ def __init__(self, model_configuration: dict[str, Any]) -> None:
             # We need to initialize the model with the number of classes
             # in the pretrained weights
             model_configuration["num_classes"] = len(weights.meta["categories"])
 
+        if "use_pretrained" in model_configuration:  # no matter if True or False
+            del model_configuration["use_pretrained"]  # don't want to forward this to torchvision
 
         super().__init__(Bottleneck, [3, 4, 6, 3], **model_configuration)  # type: ignore
2 changes: 1 addition & 1 deletion modyn/models/rho_loss_twin_model/rho_loss_twin_model.py
@@ -1,8 +1,8 @@
+import copy
 import logging
 from typing import Any, Optional
 
 import torch
-import copy
 from modyn.utils import dynamic_module_import
 from torch import nn

9 changes: 9 additions & 0 deletions modyn/models/smallyearbooknet/__init__.py
@@ -0,0 +1,9 @@
"""
Small CNN for the Yearbook dataset
"""

import os

files = os.listdir(os.path.dirname(__file__))
files.remove("__init__.py")
__all__ = [f[:-3] for f in files if f.endswith(".py")]
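For illustration, the dynamic `__all__` pattern above resolves as follows for this package (a standalone sketch against a hypothetical directory listing):

```python
# The package directory contains only __init__.py and smallyearbooknet.py.
files = ["__init__.py", "smallyearbooknet.py"]
files.remove("__init__.py")
print([f[:-3] for f in files if f.endswith(".py")])  # ['smallyearbooknet']
```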
45 changes: 45 additions & 0 deletions modyn/models/smallyearbooknet/smallyearbooknet.py
@@ -0,0 +1,45 @@
from typing import Any

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn


class SmallYearbookNet:
    """
    Adapted from Wild-Time. The original implementation can be found at:
    https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/networks/yearbook.py
    Can be used as the irreducible loss (IL) model in RHO-LOSS experiments.
    """

    # pylint: disable-next=unused-argument
    def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
        self.model = SmallYearbookNetModel(**model_configuration)
        self.model.to(device)


class SmallYearbookNetModel(CoresetSupportingModule):
    def __init__(self, num_input_channels: int, num_classes: int) -> None:
        super().__init__()
        self.enc = nn.Sequential(
            self.conv_block(num_input_channels, 16),
            self.conv_block(16, 16),
            self.conv_block(16, 16),
        )
        self.hid_dim = 16
        self.classifier = nn.Linear(16, num_classes)

    def conv_block(self, in_channels: int, out_channels: int) -> nn.Module:
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(), nn.MaxPool2d(2)
        )

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        data = self.enc(data)
        data = torch.mean(data, dim=(2, 3))
        data = self.embedding_recorder(data)
        return self.classifier(data)

    def get_last_layer(self) -> nn.Module:
        return self.classifier
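As a quick shape check, a minimal forward-pass sketch for the new model (the 32×32 grayscale input shape follows the Wild-Time Yearbook setup and is an assumption here; the global average pool in `forward` makes the network input-size agnostic anyway):

```python
import torch

from modyn.models import SmallYearbookNet

# Hypothetical configuration: grayscale portraits, binary decade label.
wrapper = SmallYearbookNet({"num_input_channels": 1, "num_classes": 2}, device="cpu", amp=False)

batch = torch.randn(8, 1, 32, 32)  # 8 grayscale 32x32 images
logits = wrapper.model(batch)      # enc -> global average pool -> classifier
print(logits.shape)                # torch.Size([8, 2])
```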
@@ -111,6 +111,7 @@ def init_from_path(cls, pipeline_logdir: Path) -> "EvaluationExecutor":
             eval_state_config.pipeline,
             grpc_handler,
         )
+        executor.grpc.init_cluster_connection()
         executor.context = context
         return executor

@@ -488,6 +488,8 @@ def _handle_triggers(
             s.triggers.append(trigger_index)
             s.current_sample_index += len(trigger_data)
 
+            self.logs.materialize(s.log_directory, mode="increment")  # materialize after every trigger
+
             if s.maximum_triggers is not None and len(s.triggers) >= s.maximum_triggers:
                 break
 
8 changes: 6 additions & 2 deletions modyn/supervisor/internal/triggers/drift/alibi_detector.py
@@ -41,8 +41,12 @@ def detect_drift(
     ) -> dict[str, MetricResult]:
         assert isinstance(embeddings_ref, (np.ndarray, torch.Tensor))
         assert isinstance(embeddings_cur, (np.ndarray, torch.Tensor))
-        embeddings_ref = embeddings_ref.numpy() if isinstance(embeddings_ref, torch.Tensor) else embeddings_ref
-        embeddings_cur = embeddings_cur.numpy() if isinstance(embeddings_cur, torch.Tensor) else embeddings_cur
+        embeddings_ref = (
+            embeddings_ref.detach().cpu().numpy() if isinstance(embeddings_ref, torch.Tensor) else embeddings_ref
+        )
+        embeddings_cur = (
+            embeddings_cur.detach().cpu().numpy() if isinstance(embeddings_cur, torch.Tensor) else embeddings_cur
+        )
 
         results: dict[str, MetricResult] = {}

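For context on the `.detach().cpu()` change: `Tensor.numpy()` raises on tensors that live on an accelerator or still require grad, which is exactly what model-produced embeddings tend to be. A minimal sketch of the failure the old code could hit:

```python
import torch

emb = torch.randn(4, 8, requires_grad=True)  # stand-in for model embeddings

try:
    emb.numpy()  # old code path
except RuntimeError as err:
    print(err)  # Can't call numpy() on Tensor that requires grad. ...

arr = emb.detach().cpu().numpy()  # new code path: safe on any device
print(arr.shape)  # (4, 8)
```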
@@ -29,7 +29,8 @@ def __init__(
         base_dir: pathlib.Path,
         model_storage_address: str,
     ):
-        self.modyn_config = modyn_config
+        # TODO(MaxiBoether): Update this class to use the model
+        self.modyn_config = modyn_config.model_dump(by_alias=True)
         self.pipeline_id = pipeline_id
         self.base_dir = base_dir
         assert self.base_dir.exists(), f"Temporary Directory {self.base_dir} should have been created."
@@ -0,0 +1,177 @@
# pylint: skip-file

import io
import json
import logging
import os
import pathlib
import threading
from typing import Any, Callable, Generator, Iterator, Optional, Tuple

from modyn.common.benchmark.stopwatch import Stopwatch
from PIL import Image
from torch.utils.data import IterableDataset, get_worker_info
from torchvision import transforms

logger = logging.getLogger(__name__)


class CglmLocalDataset(IterableDataset):
    # pylint: disable=too-many-instance-attributes, abstract-method

    def __init__(
        self,
        pipeline_id: int,
        trigger_id: int,
        dataset_id: str,
        bytes_parser: str,
        serialized_transforms: list[str],
        storage_address: str,
        selector_address: str,
        training_id: int,
        num_prefetched_partitions: int,
        parallel_prefetch_requests: int,
        shuffle: bool,
        tokenizer: Optional[str],
        log_path: Optional[pathlib.Path],
    ):
        self._pipeline_id = pipeline_id
        self._trigger_id = trigger_id
        self._training_id = training_id
        self._dataset_id = dataset_id
        self._first_call = True
        self._num_prefetched_partitions = num_prefetched_partitions
        self._parallel_prefetch_requests = parallel_prefetch_requests

        self._bytes_parser = bytes_parser
        self._serialized_transforms = serialized_transforms
        self._storage_address = storage_address
        self._selector_address = selector_address
        self._transform_list: list[Callable] = []
        self._transform: Optional[Callable] = None
        self._log_path = log_path
        self._log: dict[str, Any] = {"partitions": {}}
        self._log_lock: Optional[threading.Lock] = None
        self._sw = Stopwatch()
        self._cloc_path = "/tmp/cglm"

        if log_path is None:
            logger.warning("Did not provide log path for CglmDataset - logging disabled.")

        logger.debug("Initialized CglmDataset.")

    @staticmethod
    def bytes_parser_function(data: memoryview) -> Image:
        return Image.open(io.BytesIO(data)).convert("RGB")

    def _setup_composed_transform(self) -> None:
        self._transform_list = [
            CglmLocalDataset.bytes_parser_function,
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self._transform = transforms.Compose(self._transform_list)

    def _init_transforms(self) -> None:
        self._setup_composed_transform()

    def _silence_pil(self) -> None:  # pragma: no cover
        pil_logger = logging.getLogger("PIL")
        pil_logger.setLevel(logging.INFO)  # by default, PIL on DEBUG spams the console

    def _info(self, msg: str, worker_id: Optional[int]) -> None:  # pragma: no cover
        logger.info(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}")

    def _debug(self, msg: str, worker_id: Optional[int]) -> None:  # pragma: no cover
        logger.debug(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}")

    def _get_transformed_data_tuple(
        self, key: int, sample: memoryview, label: int, weight: Optional[float]
    ) -> Optional[Tuple]:
        self._sw.start("transform", resume=True)
        # mypy complains here because _transform has unknown type, which is ok
        transformed_sample = self._transform(sample)  # type: ignore
        self._sw.stop("transform")
        return key, transformed_sample, label

    def _persist_log(self, worker_id: int) -> None:
        if self._log_path is None:
            return

        assert self._log_lock is not None

        with self._log_lock:
            if "PYTEST_CURRENT_TEST" in os.environ:
                json.dumps(self._log)  # Enforce serialization to catch issues
                return  # But don't actually store in tests

            log_file = f"{self._log_path / str(worker_id)}.log"
            self._log["transform"] = self._sw.measurements.get("transform", 0)
            self._log["wait_for_later_partitions"] = self._sw.measurements.get("wait_for_later_partitions", 0)
            self._log["wait_for_initial_partition"] = self._sw.measurements.get("wait_for_initial_partition", 0)

            with open(log_file, "w", encoding="utf-8") as logfile:
                json.dump(self._log, logfile)

    def cloc_generator(
        self, worker_id: int, num_workers: int
    ) -> Iterator[tuple[int, memoryview, int, Optional[float]]]:
        self._info("Globbing paths", worker_id)

        pathlist = sorted(pathlib.Path(self._cloc_path).glob("*.jpg"))
        self._info("Paths globbed", worker_id)

        def split(list_to_split: list, split_every: int) -> Any:
            # Partition list_to_split into `split_every` contiguous chunks whose
            # lengths differ by at most one, so each worker gets a near-equal share.
            k, m = divmod(len(list_to_split), split_every)
            return (list_to_split[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(split_every))

        pathgen = split(pathlist, num_workers)
        worker_paths = next(x for i, x in enumerate(pathgen) if i == worker_id)
        self._info(f"Got {len(worker_paths)} paths.", worker_id)

        sample_idx = 0
        for path in worker_paths:
            path = pathlib.Path(path)
            label_path = path.with_suffix(".label")

            with open(path, "rb") as file:
                data = file.read()
            with open(label_path, "rb") as file:
                label = int(file.read().decode("utf-8"))

            yield sample_idx, memoryview(data), label, None
            sample_idx = sample_idx + 1

    def __iter__(self) -> Generator:
        worker_info = get_worker_info()
        if worker_info is None:
            # Non-multithreaded data loading. We use worker_id 0.
            worker_id = 0
            num_workers = 1
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

        if self._first_call:
            self._first_call = False
            self._debug("This is the first run of iter, making gRPC connections.", worker_id)
            # We have to initialize transformations and gRPC connections here to do it per dataloader worker,
            # otherwise the transformations/gRPC connections cannot be pickled for the new processes.
            self._init_transforms()
            self._uses_weights = False
            self._silence_pil()
            self._sw = Stopwatch()
            self._log_lock = threading.Lock()

        assert self._transform is not None

        for data_tuple in self.cloc_generator(worker_id, num_workers):
            if (transformed_tuple := self._get_transformed_data_tuple(*data_tuple)) is not None:
                yield transformed_tuple

        self._persist_log(worker_id)

    def end_of_trigger_cleaning(self) -> None:
        pass
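For reference, a minimal sketch of driving the new local dataset from a PyTorch DataLoader (assumptions: `.jpg` files with sibling `.label` files exist under `/tmp/cglm`, and enough of them for a full batch per worker; the storage/selector addresses are unused by this disk-backed dataset, so placeholders are passed):

```python
from torch.utils.data import DataLoader

dataset = CglmLocalDataset(
    pipeline_id=0,
    trigger_id=0,
    dataset_id="cglm",         # placeholder
    bytes_parser="",           # unused: bytes_parser_function is hardcoded
    serialized_transforms=[],  # unused: the transform pipeline is hardcoded
    storage_address="",        # unused by the local dataset
    selector_address="",       # unused by the local dataset
    training_id=0,
    num_prefetched_partitions=0,
    parallel_prefetch_requests=1,
    shuffle=False,
    tokenizer=None,
    log_path=None,
)

loader = DataLoader(dataset, batch_size=64, num_workers=4)
for keys, samples, labels in loader:  # each worker iterates its own shard of paths
    print(keys.shape, samples.shape, labels.shape)  # [64], [64, 3, 224, 224], [64]
    break
```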
@@ -10,7 +10,7 @@
 
 import torch
 from modyn.common.benchmark.stopwatch import Stopwatch
-from modyn.trainer_server.internal.dataset.binary_file_wrapper import BinaryFileWrapper
+from modyn.trainer_server.internal.dataset.extra_local_eval.binary_file_wrapper import BinaryFileWrapper
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision import transforms

@@ -32,6 +32,7 @@ def __init__(
         training_id: int,
         num_prefetched_partitions: int,
         parallel_prefetch_requests: int,
+        shuffle: bool,
         tokenizer: Optional[str],
         log_path: Optional[pathlib.Path],
     ):
@@ -154,7 +155,7 @@ def __iter__(self) -> Generator:
 
         if self._first_call:
             self._first_call = False
-            self._debug("This is the first run of iter, making gRPC connections.", worker_id)
+            self._debug("This is the first run of iter", worker_id)
             # We have to initialize transformations and gRPC connections here to do it per dataloader worker,
             # otherwise the transformations/gRPC connections cannot be pickled for the new processes.
            self._init_transforms()