Adding another way of calculating gradients in downsamplers (#540)
This PR addresses #505.
A flag is introduced to switch between using only the last-layer gradients or the concatenation of the last-layer and penultimate-layer gradients in the selection algorithms.
XianzheMa authored Jun 25, 2024
1 parent 70ce04c commit 72ffdca
Showing 18 changed files with 292 additions and 137 deletions.
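
As a quick illustration of the new option, here is a minimal usage sketch in Python. Only strategy, balance, and full_grad_approximation are visible in the diff below; the ratio field is an assumption about BaseDownsamplingConfig, so treat this as a sketch rather than a verified snippet.

from modyn.config.schema.pipeline.sampling.downsampling_config import GradMatchDownsamplingConfig

# Hypothetical usage sketch; `ratio` is an assumed field of BaseDownsamplingConfig.
config = GradMatchDownsamplingConfig(
    ratio=50,
    balance=False,
    full_grad_approximation="LastLayerWithEmbedding",  # default is "LastLayer"
)
assert config.full_grad_approximation == "LastLayerWithEmbedding"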
12 changes: 11 additions & 1 deletion modyn/config/schema/pipeline/sampling/downsampling_config.py
@@ -52,6 +52,13 @@ def validate_ratio(self) -> Self:
return self


# These are options to approximate the full gradients used in several selection strategies.
# LastLayer: The full gradient is approximated by the gradient of the last layer.
# LastLayerWithEmbedding: The full gradient is approximated by the gradients of the last layer and the embedding layer.
# They are concatenated and used to represent the full gradient.
FullGradApproximation = Literal["LastLayer", "LastLayerWithEmbedding"]


class UncertaintyDownsamplingConfig(BaseDownsamplingConfig):
"""Config for the Craig downsampling strategy."""

@@ -74,6 +81,7 @@ class GradMatchDownsamplingConfig(BaseDownsamplingConfig):

strategy: Literal["GradMatch"] = "GradMatch"
balance: bool = Field(False, description="If True, the samples are balanced.")
full_grad_approximation: FullGradApproximation = Field(default="LastLayer")


class CraigDownsamplingConfig(BaseDownsamplingConfig):
@@ -85,6 +93,7 @@ class CraigDownsamplingConfig(BaseDownsamplingConfig):
greedy: Literal["NaiveGreedy", "LazyGreedy", "StochasticGreedy", "ApproximateLazyGreedy"] = Field(
"NaiveGreedy", description="The greedy strategy to use."
)
full_grad_approximation: FullGradApproximation = Field(default="LastLayer")


class LossDownsamplingConfig(BaseDownsamplingConfig):
@@ -103,6 +112,7 @@ class SubmodularDownsamplingConfig(BaseDownsamplingConfig):
)
selection_batch: int = Field(64, description="The batch size for the selection.")
balance: bool = Field(False, description="If True, the samples are balanced.")
full_grad_approximation: FullGradApproximation = Field(default="LastLayer")


class GradNormDownsamplingConfig(BaseDownsamplingConfig):
@@ -144,7 +154,7 @@ class RHOLossDownsamplingConfig(BaseDownsamplingConfig):

strategy: Literal["RHOLoss"] = "RHOLoss"
holdout_set_ratio: int = Field(
description=("How much of the training set is used as the holdout set."),
description="How much of the training set is used as the holdout set.",
min=0,
max=100,
)
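
The FullGradApproximation comment above describes the two modes. The following standalone sketch (not the repository's implementation) illustrates what each mode computes, assuming the last layer is a linear layer whose input is the embedding and that the weight gradients are concatenated before the output gradients; both assumptions are mine, not taken from this diff.

import torch

def approximate_full_gradients(forward_output, target, embedding, mode="LastLayer"):
    # Per-sample gradient of the summed cross-entropy loss w.r.t. the last layer's output.
    output = forward_output.detach().requires_grad_(True)
    loss = torch.nn.CrossEntropyLoss(reduction="sum")(output, target)
    last_layer_grads = torch.autograd.grad(loss, output)[0]  # shape (N, num_classes)
    if mode == "LastLayer":
        return last_layer_grads
    # "LastLayerWithEmbedding": additionally approximate the last linear layer's weight
    # gradient as a per-sample outer product with the embedding, then concatenate.
    weight_grads = torch.einsum("nc,nd->ncd", last_layer_grads, embedding).flatten(1)
    return torch.cat([weight_grads, last_layer_grads], dim=1)  # (N, num_classes * emb_dim + num_classes)

grads = approximate_full_gradients(
    torch.randn(4, 2), torch.randint(0, 2, (4,)), torch.randn(4, 5), mode="LastLayerWithEmbedding"
)
assert grads.shape == (4, 2 * 5 + 2)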
(next changed file)
@@ -17,6 +17,7 @@ def __init__(
self.selection_batch = downsampling_config.selection_batch
self.balance = downsampling_config.balance
self.greedy = downsampling_config.greedy
self.full_grad_approximation = downsampling_config.full_grad_approximation
self.remote_downsampling_strategy_name = "RemoteCraigDownsamplingStrategy"

@cached_property
@@ -25,6 +26,7 @@ def downsampling_params(self) -> dict:
config["selection_batch"] = self.selection_batch
config["balance"] = self.balance
config["greedy"] = self.greedy
config["full_grad_approximation"] = self.full_grad_approximation
if config["balance"] and self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.")
return config
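
For reference, the dict this property hands to the remote downsampler then looks roughly as follows; the keys match the test fixtures later in this commit, while the concrete values are only illustrative.

craig_downsampling_params = {
    "downsampling_ratio": 50,
    "ratio_max": 100,
    "selection_batch": 64,
    "balance": False,
    "greedy": "NaiveGreedy",
    "full_grad_approximation": "LastLayer",  # the key added in this commit
}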
(next changed file)
@@ -16,11 +16,13 @@ def __init__(
super().__init__(downsampling_config, modyn_config, pipeline_id, maximum_keys_in_memory)
self.balance = downsampling_config.balance
self.remote_downsampling_strategy_name = "RemoteGradMatchDownsamplingStrategy"
self.full_grad_approximation = downsampling_config.full_grad_approximation

@cached_property
def downsampling_params(self) -> dict:
config = super().downsampling_params
config["balance"] = self.balance
config["full_grad_approximation"] = self.full_grad_approximation
if config["balance"] and self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.")
return config
(next changed file)
@@ -18,6 +18,7 @@ def __init__(
self.submodular_optimizer = downsampling_config.submodular_optimizer
self.selection_batch = downsampling_config.selection_batch
self.balance = downsampling_config.balance
self.full_grad_approximation = downsampling_config.full_grad_approximation
self.remote_downsampling_strategy_name = "RemoteSubmodularDownsamplingStrategy"

@cached_property
@@ -27,6 +28,7 @@ def downsampling_params(self) -> dict:
config["submodular_optimizer"] = self.submodular_optimizer
config["selection_batch"] = self.selection_batch
config["balance"] = self.balance
config["full_grad_approximation"] = self.full_grad_approximation
if config["balance"] and self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.")

(next changed file)
@@ -2,6 +2,7 @@
from unittest.mock import patch

import numpy as np
import pytest
import torch
from modyn.config import ModynConfig
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_matrix_downsampling_strategy import (
@@ -10,7 +11,9 @@
)


def get_sampler_config(dummy_system_config: ModynConfig, balance=False):
def get_sampler_config(
dummy_system_config: ModynConfig, balance=False, matrix_content=MatrixContent.LAST_TWO_LAYERS_GRADIENTS
):
downsampling_ratio = 50
per_sample_loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

@@ -29,7 +32,7 @@ def get_sampler_config(dummy_system_config: ModynConfig, balance=False):
dummy_system_config.model_dump(by_alias=True),
per_sample_loss_fct,
"cpu",
MatrixContent.GRADIENTS,
matrix_content,
)


@@ -39,7 +42,7 @@ def test_init(dummy_system_config: ModynConfig):

assert amds.requires_coreset_supporting_module
assert not amds.matrix_elements
assert amds.matrix_content == MatrixContent.GRADIENTS
assert amds.matrix_content == MatrixContent.LAST_TWO_LAYERS_GRADIENTS


@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set())
@@ -106,9 +109,12 @@ def test_collect_embedding_balance(test_amds, dummy_system_config: ModynConfig):
assert amds.already_selected_samples == [1, 3, 1000, 1002]


@pytest.mark.parametrize(
"matrix_content", [MatrixContent.LAST_LAYER_GRADIENTS, MatrixContent.LAST_TWO_LAYERS_GRADIENTS]
)
@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set())
def test_collect_gradients(dummy_system_config: ModynConfig):
amds = AbstractMatrixDownsamplingStrategy(*get_sampler_config(dummy_system_config))
def test_collect_gradients(matrix_content, dummy_system_config: ModynConfig):
amds = AbstractMatrixDownsamplingStrategy(*get_sampler_config(dummy_system_config, matrix_content=matrix_content))
with torch.inference_mode(mode=(not amds.requires_grad)):
forward_input = torch.randn((4, 5))
first_output = torch.randn((4, 2))
@@ -125,9 +131,14 @@ def test_collect_gradients(dummy_system_config: ModynConfig):

assert len(amds.matrix_elements) == 2

# expected shape = (a,b)
# expected shape = (a, gradient_shape)
# a = 7 (4 samples in the first batch and 3 samples in the second batch)
# b = 5 * 2 + 2 where 5 is the input dimension of the last layer and 2 is the output one
assert np.concatenate(amds.matrix_elements).shape == (7, 12)
if matrix_content == MatrixContent.LAST_LAYER_GRADIENTS:
# shape same as the last dimension of output
gradient_shape = 2
else:
# 5 is the input dimension of the last layer and 2 is the output one
gradient_shape = 5 * 2 + 2
assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape)

assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]
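
The expected shapes follow from the dimensions set up in the test (a last layer with input dimension 5 and output dimension 2, and 4 + 3 informed samples); a quick check of the arithmetic:

num_samples = 4 + 3  # two informed batches
last_layer_in, last_layer_out = 5, 2
# LastLayer: one gradient entry per output unit.
assert (num_samples, last_layer_out) == (7, 2)
# LastLayerWithEmbedding: weight gradients (5 * 2) concatenated with output gradients (2).
assert (num_samples, last_layer_in * last_layer_out + last_layer_out) == (7, 12)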
(next changed file)
@@ -1,6 +1,7 @@
# pylint: disable=abstract-class-instantiated
from unittest.mock import patch

import torch
from modyn.config import ModynConfig
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
AbstractRemoteDownsamplingStrategy,
@@ -21,3 +22,23 @@ def test_batch_then_sample_general(dummy_system_config: ModynConfig):
assert sampler.trigger_id == 128
assert sampler.pipeline_id == 154
assert sampler.batch_size == 64


@patch(
"modyn.trainer_server.internal.trainer.remote_downsamplers"
".abstract_remote_downsampling_strategy.torch.autograd.grad",
wraps=torch.autograd.grad,
)
def test__compute_last_layer_gradient_wrt_loss_sum(mock_torch_auto_grad):
per_sample_loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
forward_output = torch.randn((4, 2), requires_grad=True)
# random target
target = torch.randint(0, 2, (4,))
last_layer_gradients = AbstractRemoteDownsamplingStrategy._compute_last_layer_gradient_wrt_loss_sum(
per_sample_loss_fct, forward_output, target
)
# as we use CrossEntropyLoss, the gradient is computed in a closed form
assert mock_torch_auto_grad.call_count == 0
# verify that the gradients calculated via the closed form are equal to the ones calculated by autograd
expected_grad = torch.autograd.grad(per_sample_loss_fct(forward_output, target).sum(), forward_output)[0]
assert torch.allclose(last_layer_gradients, expected_grad)
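
The assertion that torch.autograd.grad is never called relies on the closed-form gradient of cross-entropy: for the summed loss over logits z with labels y, the gradient with respect to z is softmax(z) minus the one-hot encoding of y. A self-contained check of that identity (mirroring, but not copied from, the helper under test):

import torch
import torch.nn.functional as F

forward_output = torch.randn((4, 2), requires_grad=True)
target = torch.randint(0, 2, (4,))

# Closed form: d(sum_i CE(z_i, y_i)) / dz = softmax(z) - one_hot(y).
closed_form = F.softmax(forward_output, dim=1) - F.one_hot(target, num_classes=2).float()

autograd_grad = torch.autograd.grad(
    torch.nn.CrossEntropyLoss(reduction="none")(forward_output, target).sum(), forward_output
)[0]
assert torch.allclose(closed_form, autograd_grad, atol=1e-6)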
(next changed file)
@@ -1,5 +1,6 @@
# pylint: disable=too-many-locals
import numpy as np
import pytest
import torch
from modyn.config import ModynConfig
from modyn.tests.trainer_server.internal.trainer.remote_downsamplers.deepcore_comparison_tests_utils import (
@@ -10,7 +11,7 @@
from torch.nn import BCEWithLogitsLoss


def get_sampler_config(modyn_config, balance=False):
def get_sampler_config(modyn_config, balance=False, grad_approx="LastLayerWithEmbedding"):
downsampling_ratio = 50
per_sample_loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

@@ -20,6 +21,7 @@ def get_sampler_config(modyn_config, balance=False):
"balance": balance,
"selection_batch": 64,
"greedy": "NaiveGreedy",
"full_grad_approximation": grad_approx,
"ratio_max": 100,
}
return 0, 0, 0, params_from_selector, modyn_config.model_dump(by_alias=True), per_sample_loss_fct, "cpu"
@@ -89,8 +91,9 @@ def test_add_to_distance_matrix_large_submatrix(dummy_system_config: ModynConfig
assert np.array_equal(sampler.distance_matrix, expected_result)


def test_inform_end_of_current_label_and_select(dummy_system_config: ModynConfig):
sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config))
@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"])
def test_inform_end_of_current_label_and_select(grad_approx: str, dummy_system_config: ModynConfig):
sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx))
with torch.inference_mode(mode=(not sampler.requires_grad)):
sample_ids = [1, 2, 3]
forward_input = torch.randn(3, 5) # 3 samples, 5 input features
@@ -129,8 +132,9 @@ def test_inform_end_of_current_label_and_select(dummy_system_config: ModynConfig
assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points)


def test_inform_end_of_current_label_and_select_balanced(dummy_system_config: ModynConfig):
sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, True))
@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"])
def test_inform_end_of_current_label_and_select_balanced(grad_approx: str, dummy_system_config: ModynConfig):
sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, True, grad_approx=grad_approx))
with torch.inference_mode(mode=(not sampler.requires_grad)):
sample_ids = [1, 2, 3, 4]
forward_input = torch.randn(4, 5)
@@ -174,8 +178,9 @@ def test_inform_end_of_current_label_and_select_balanced(dummy_system_config: Mo
assert sum(id in [10, 11, 12, 13, 14, 15] for id in selected_points) == 3


def test_bts(dummy_system_config: ModynConfig):
sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config))
@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"])
def test_bts(grad_approx: str, dummy_system_config: ModynConfig):
sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx))
with torch.inference_mode(mode=(not sampler.requires_grad)):
sample_ids = [1, 2, 3, 10, 11, 12, 13]
forward_input = torch.randn(7, 5) # 7 samples, 5 input features
@@ -200,7 +205,8 @@ def test_bts(dummy_system_config: ModynConfig):
assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points)


def test_bts_equals_stb(dummy_system_config: ModynConfig):
@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"])
def test_bts_equals_stb(grad_approx: str, dummy_system_config: ModynConfig):
# data
sample_ids = [1, 2, 3, 10, 11, 12, 13]
forward_input = torch.randn(7, 5) # 7 samples, 5 input features
@@ -210,7 +216,7 @@ def test_bts_equals_stb(dummy_system_config: ModynConfig):
embedding = torch.randn(7, 10) # 7 samples, embedding dimension 10

# BTS, all in one call
bts_sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config))
bts_sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx))
with torch.inference_mode(mode=(not bts_sampler.requires_grad)):
bts_sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding)

@@ -219,7 +225,7 @@ def test_bts_equals_stb(dummy_system_config: ModynConfig):
# STB, first class 0 and then class 1
class0 = target == 0
class1 = target == 1
stb_sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config))
stb_sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx))
stb_sampler.inform_samples(
[sample_ids[i] for i, keep in enumerate(class0) if keep],
forward_input[class0],
@@ -348,7 +354,14 @@ def test_matching_results_with_deepcore(dummy_system_config: ModynConfig):
0,
0,
5,
{"downsampling_ratio": 20, "balance": False, "selection_batch": 64, "greedy": "NaiveGreedy", "ratio_max": 100},
{
"downsampling_ratio": 20,
"balance": False,
"selection_batch": 64,
"greedy": "NaiveGreedy",
"full_grad_approximation": "LastLayerWithEmbedding",
"ratio_max": 100,
},
dummy_system_config.model_dump(by_alias=True),
BCEWithLogitsLoss(reduction="none"),
"cpu",
@@ -403,7 +416,14 @@ def test_matching_results_with_deepcore_permutation(dummy_system_config: ModynCo
0,
0,
5,
{"downsampling_ratio": 30, "balance": False, "selection_batch": 64, "greedy": "NaiveGreedy", "ratio_max": 100},
{
"downsampling_ratio": 30,
"balance": False,
"selection_batch": 64,
"greedy": "NaiveGreedy",
"full_grad_approximation": "LastLayerWithEmbedding",
"ratio_max": 100,
},
dummy_system_config.model_dump(by_alias=True),
BCEWithLogitsLoss(reduction="none"),
"cpu",
@@ -462,7 +482,14 @@ def test_matching_results_with_deepcore_permutation_fancy_ids(dummy_system_confi
0,
0,
5,
{"downsampling_ratio": 50, "balance": False, "selection_batch": 64, "greedy": "NaiveGreedy", "ratio_max": 100},
{
"downsampling_ratio": 50,
"balance": False,
"selection_batch": 64,
"greedy": "NaiveGreedy",
"full_grad_approximation": "LastLayerWithEmbedding",
"ratio_max": 100,
},
dummy_system_config.model_dump(by_alias=True),
BCEWithLogitsLoss(reduction="none"),
"cpu",
(diffs for the remaining changed files not loaded)
