diff --git a/modyn/config/schema/pipeline/sampling/downsampling_config.py b/modyn/config/schema/pipeline/sampling/downsampling_config.py index 8c54e58a4..d088499e5 100644 --- a/modyn/config/schema/pipeline/sampling/downsampling_config.py +++ b/modyn/config/schema/pipeline/sampling/downsampling_config.py @@ -52,6 +52,13 @@ def validate_ratio(self) -> Self: return self +# These are options to approximate the full gradients used in several selection strategies. +# LastLayer: The full gradient is approximated by the gradient of the last layer. +# LastLayerWithEmbedding: The full gradient is approximated by the gradients of the last layer and the embedding layer. +# They are concatenated and used to represent the full gradient. +FullGradApproximation = Literal["LastLayer", "LastLayerWithEmbedding"] + + class UncertaintyDownsamplingConfig(BaseDownsamplingConfig): """Config for the Craig downsampling strategy.""" @@ -74,6 +81,7 @@ class GradMatchDownsamplingConfig(BaseDownsamplingConfig): strategy: Literal["GradMatch"] = "GradMatch" balance: bool = Field(False, description="If True, the samples are balanced.") + full_grad_approximation: FullGradApproximation = Field(default="LastLayer") class CraigDownsamplingConfig(BaseDownsamplingConfig): @@ -85,6 +93,7 @@ class CraigDownsamplingConfig(BaseDownsamplingConfig): greedy: Literal["NaiveGreedy", "LazyGreedy", "StochasticGreedy", "ApproximateLazyGreedy"] = Field( "NaiveGreedy", description="The greedy strategy to use." ) + full_grad_approximation: FullGradApproximation = Field(default="LastLayer") class LossDownsamplingConfig(BaseDownsamplingConfig): @@ -103,6 +112,7 @@ class SubmodularDownsamplingConfig(BaseDownsamplingConfig): ) selection_batch: int = Field(64, description="The batch size for the selection.") balance: bool = Field(False, description="If True, the samples are balanced.") + full_grad_approximation: FullGradApproximation = Field(default="LastLayer") class GradNormDownsamplingConfig(BaseDownsamplingConfig): @@ -144,7 +154,7 @@ class RHOLossDownsamplingConfig(BaseDownsamplingConfig): strategy: Literal["RHOLoss"] = "RHOLoss" holdout_set_ratio: int = Field( - description=("How much of the training set is used as the holdout set."), + description="How much of the training set is used as the holdout set.", min=0, max=100, ) diff --git a/modyn/selector/internal/selector_strategies/downsampling_strategies/craig_downsampling_strategy.py b/modyn/selector/internal/selector_strategies/downsampling_strategies/craig_downsampling_strategy.py index 513aad314..0a2f21fbd 100644 --- a/modyn/selector/internal/selector_strategies/downsampling_strategies/craig_downsampling_strategy.py +++ b/modyn/selector/internal/selector_strategies/downsampling_strategies/craig_downsampling_strategy.py @@ -17,6 +17,7 @@ def __init__( self.selection_batch = downsampling_config.selection_batch self.balance = downsampling_config.balance self.greedy = downsampling_config.greedy + self.full_grad_approximation = downsampling_config.full_grad_approximation self.remote_downsampling_strategy_name = "RemoteCraigDownsamplingStrategy" @cached_property @@ -25,6 +26,7 @@ def downsampling_params(self) -> dict: config["selection_batch"] = self.selection_batch config["balance"] = self.balance config["greedy"] = self.greedy + config["full_grad_approximation"] = self.full_grad_approximation if config["balance"] and self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE: raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.") return config diff 
--git a/modyn/selector/internal/selector_strategies/downsampling_strategies/gradmatch_downsampling_strategy.py b/modyn/selector/internal/selector_strategies/downsampling_strategies/gradmatch_downsampling_strategy.py index 6c211b988..e266c795a 100644 --- a/modyn/selector/internal/selector_strategies/downsampling_strategies/gradmatch_downsampling_strategy.py +++ b/modyn/selector/internal/selector_strategies/downsampling_strategies/gradmatch_downsampling_strategy.py @@ -16,11 +16,13 @@ def __init__( super().__init__(downsampling_config, modyn_config, pipeline_id, maximum_keys_in_memory) self.balance = downsampling_config.balance self.remote_downsampling_strategy_name = "RemoteGradMatchDownsamplingStrategy" + self.full_grad_approximation = downsampling_config.full_grad_approximation @cached_property def downsampling_params(self) -> dict: config = super().downsampling_params config["balance"] = self.balance + config["full_grad_approximation"] = self.full_grad_approximation if config["balance"] and self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE: raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.") return config diff --git a/modyn/selector/internal/selector_strategies/downsampling_strategies/submodular_downsampling_strategy.py b/modyn/selector/internal/selector_strategies/downsampling_strategies/submodular_downsampling_strategy.py index c519f6059..62db742fd 100644 --- a/modyn/selector/internal/selector_strategies/downsampling_strategies/submodular_downsampling_strategy.py +++ b/modyn/selector/internal/selector_strategies/downsampling_strategies/submodular_downsampling_strategy.py @@ -18,6 +18,7 @@ def __init__( self.submodular_optimizer = downsampling_config.submodular_optimizer self.selection_batch = downsampling_config.selection_batch self.balance = downsampling_config.balance + self.full_grad_approximation = downsampling_config.full_grad_approximation self.remote_downsampling_strategy_name = "RemoteSubmodularDownsamplingStrategy" @cached_property @@ -27,6 +28,7 @@ def downsampling_params(self) -> dict: config["submodular_optimizer"] = self.submodular_optimizer config["selection_batch"] = self.selection_batch config["balance"] = self.balance + config["full_grad_approximation"] = self.full_grad_approximation if config["balance"] and self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE: raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.") diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py index 4745b7f09..a7e8d410e 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py @@ -2,6 +2,7 @@ from unittest.mock import patch import numpy as np +import pytest import torch from modyn.config import ModynConfig from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_matrix_downsampling_strategy import ( @@ -10,7 +11,9 @@ ) -def get_sampler_config(dummy_system_config: ModynConfig, balance=False): +def get_sampler_config( + dummy_system_config: ModynConfig, balance=False, matrix_content=MatrixContent.LAST_TWO_LAYERS_GRADIENTS +): downsampling_ratio = 50 per_sample_loss_fct = torch.nn.CrossEntropyLoss(reduction="none") @@ -29,7 
+32,7 @@ def get_sampler_config(dummy_system_config: ModynConfig, balance=False): dummy_system_config.model_dump(by_alias=True), per_sample_loss_fct, "cpu", - MatrixContent.GRADIENTS, + matrix_content, ) @@ -39,7 +42,7 @@ def test_init(dummy_system_config: ModynConfig): assert amds.requires_coreset_supporting_module assert not amds.matrix_elements - assert amds.matrix_content == MatrixContent.GRADIENTS + assert amds.matrix_content == MatrixContent.LAST_TWO_LAYERS_GRADIENTS @patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set()) @@ -106,9 +109,12 @@ def test_collect_embedding_balance(test_amds, dummy_system_config: ModynConfig): assert amds.already_selected_samples == [1, 3, 1000, 1002] +@pytest.mark.parametrize( + "matrix_content", [MatrixContent.LAST_LAYER_GRADIENTS, MatrixContent.LAST_TWO_LAYERS_GRADIENTS] +) @patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set()) -def test_collect_gradients(dummy_system_config: ModynConfig): - amds = AbstractMatrixDownsamplingStrategy(*get_sampler_config(dummy_system_config)) +def test_collect_gradients(matrix_content, dummy_system_config: ModynConfig): + amds = AbstractMatrixDownsamplingStrategy(*get_sampler_config(dummy_system_config, matrix_content=matrix_content)) with torch.inference_mode(mode=(not amds.requires_grad)): forward_input = torch.randn((4, 5)) first_output = torch.randn((4, 2)) @@ -125,9 +131,14 @@ def test_collect_gradients(dummy_system_config: ModynConfig): assert len(amds.matrix_elements) == 2 - # expected shape = (a,b) + # expected shape = (a, gradient_shape) # a = 7 (4 samples in the first batch and 3 samples in the second batch) - # b = 5 * 2 + 2 where 5 is the input dimension of the last layer and 2 is the output one - assert np.concatenate(amds.matrix_elements).shape == (7, 12) + if matrix_content == MatrixContent.LAST_LAYER_GRADIENTS: + # shape same as the last dimension of output + gradient_shape = 2 + else: + # 5 is the input dimension of the last layer and 2 is the output one + gradient_shape = 5 * 2 + 2 + assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape) assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41] diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_remote_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_remote_downsampling_strategy.py index 8a428e8cc..b52c6de70 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_remote_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_remote_downsampling_strategy.py @@ -1,6 +1,7 @@ # pylint: disable=abstract-class-instantiated from unittest.mock import patch +import torch from modyn.config import ModynConfig from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( AbstractRemoteDownsamplingStrategy, @@ -21,3 +22,23 @@ def test_batch_then_sample_general(dummy_system_config: ModynConfig): assert sampler.trigger_id == 128 assert sampler.pipeline_id == 154 assert sampler.batch_size == 64 + + +@patch( + "modyn.trainer_server.internal.trainer.remote_downsamplers" + ".abstract_remote_downsampling_strategy.torch.autograd.grad", + wraps=torch.autograd.grad, +) +def test__compute_last_layer_gradient_wrt_loss_sum(mock_torch_auto_grad): + per_sample_loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + forward_output = torch.randn((4, 2), requires_grad=True) + # random target 
+ target = torch.randint(0, 2, (4,)) + last_layer_gradients = AbstractRemoteDownsamplingStrategy._compute_last_layer_gradient_wrt_loss_sum( + per_sample_loss_fct, forward_output, target + ) + # as we use CrossEntropyLoss, the gradient is computed in a closed form + assert mock_torch_auto_grad.call_count == 0 + # verify that the gradients calculated via the closed form are equal to the ones calculated by autograd + expected_grad = torch.autograd.grad(per_sample_loss_fct(forward_output, target).sum(), forward_output)[0] + assert torch.allclose(last_layer_gradients, expected_grad) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py index 87719cfef..47f4d6e09 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py @@ -1,5 +1,6 @@ # pylint: disable=too-many-locals import numpy as np +import pytest import torch from modyn.config import ModynConfig from modyn.tests.trainer_server.internal.trainer.remote_downsamplers.deepcore_comparison_tests_utils import ( @@ -10,7 +11,7 @@ from torch.nn import BCEWithLogitsLoss -def get_sampler_config(modyn_config, balance=False): +def get_sampler_config(modyn_config, balance=False, grad_approx="LastLayerWithEmbedding"): downsampling_ratio = 50 per_sample_loss_fct = torch.nn.CrossEntropyLoss(reduction="none") @@ -20,6 +21,7 @@ def get_sampler_config(modyn_config, balance=False): "balance": balance, "selection_batch": 64, "greedy": "NaiveGreedy", + "full_grad_approximation": grad_approx, "ratio_max": 100, } return 0, 0, 0, params_from_selector, modyn_config.model_dump(by_alias=True), per_sample_loss_fct, "cpu" @@ -89,8 +91,9 @@ def test_add_to_distance_matrix_large_submatrix(dummy_system_config: ModynConfig assert np.array_equal(sampler.distance_matrix, expected_result) -def test_inform_end_of_current_label_and_select(dummy_system_config: ModynConfig): - sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config)) +@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) +def test_inform_end_of_current_label_and_select(grad_approx: str, dummy_system_config: ModynConfig): + sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx)) with torch.inference_mode(mode=(not sampler.requires_grad)): sample_ids = [1, 2, 3] forward_input = torch.randn(3, 5) # 3 samples, 5 input features @@ -129,8 +132,9 @@ def test_inform_end_of_current_label_and_select(dummy_system_config: ModynConfig assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points) -def test_inform_end_of_current_label_and_select_balanced(dummy_system_config: ModynConfig): - sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, True)) +@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) +def test_inform_end_of_current_label_and_select_balanced(grad_approx: str, dummy_system_config: ModynConfig): + sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, True, grad_approx=grad_approx)) with torch.inference_mode(mode=(not sampler.requires_grad)): sample_ids = [1, 2, 3, 4] forward_input = torch.randn(4, 5) @@ -174,8 +178,9 @@ def test_inform_end_of_current_label_and_select_balanced(dummy_system_config: Mo assert 
sum(id in [10, 11, 12, 13, 14, 15] for id in selected_points) == 3 -def test_bts(dummy_system_config: ModynConfig): - sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config)) +@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) +def test_bts(grad_approx: str, dummy_system_config: ModynConfig): + sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx)) with torch.inference_mode(mode=(not sampler.requires_grad)): sample_ids = [1, 2, 3, 10, 11, 12, 13] forward_input = torch.randn(7, 5) # 7 samples, 5 input features @@ -200,7 +205,8 @@ def test_bts(dummy_system_config: ModynConfig): assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points) -def test_bts_equals_stb(dummy_system_config: ModynConfig): +@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) +def test_bts_equals_stb(grad_approx: str, dummy_system_config: ModynConfig): # data sample_ids = [1, 2, 3, 10, 11, 12, 13] forward_input = torch.randn(7, 5) # 7 samples, 5 input features @@ -210,7 +216,7 @@ def test_bts_equals_stb(dummy_system_config: ModynConfig): embedding = torch.randn(7, 10) # 7 samples, embedding dimension 10 # BTS, all in one call - bts_sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config)) + bts_sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx)) with torch.inference_mode(mode=(not bts_sampler.requires_grad)): bts_sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) @@ -219,7 +225,7 @@ def test_bts_equals_stb(dummy_system_config: ModynConfig): # STB, first class 0 and then class 1 class0 = target == 0 class1 = target == 1 - stb_sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config)) + stb_sampler = RemoteCraigDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx)) stb_sampler.inform_samples( [sample_ids[i] for i, keep in enumerate(class0) if keep], forward_input[class0], @@ -348,7 +354,14 @@ def test_matching_results_with_deepcore(dummy_system_config: ModynConfig): 0, 0, 5, - {"downsampling_ratio": 20, "balance": False, "selection_batch": 64, "greedy": "NaiveGreedy", "ratio_max": 100}, + { + "downsampling_ratio": 20, + "balance": False, + "selection_batch": 64, + "greedy": "NaiveGreedy", + "full_grad_approximation": "LastLayerWithEmbedding", + "ratio_max": 100, + }, dummy_system_config.model_dump(by_alias=True), BCEWithLogitsLoss(reduction="none"), "cpu", @@ -403,7 +416,14 @@ def test_matching_results_with_deepcore_permutation(dummy_system_config: ModynCo 0, 0, 5, - {"downsampling_ratio": 30, "balance": False, "selection_batch": 64, "greedy": "NaiveGreedy", "ratio_max": 100}, + { + "downsampling_ratio": 30, + "balance": False, + "selection_batch": 64, + "greedy": "NaiveGreedy", + "full_grad_approximation": "LastLayerWithEmbedding", + "ratio_max": 100, + }, dummy_system_config.model_dump(by_alias=True), BCEWithLogitsLoss(reduction="none"), "cpu", @@ -462,7 +482,14 @@ def test_matching_results_with_deepcore_permutation_fancy_ids(dummy_system_confi 0, 0, 5, - {"downsampling_ratio": 50, "balance": False, "selection_batch": 64, "greedy": "NaiveGreedy", "ratio_max": 100}, + { + "downsampling_ratio": 50, + "balance": False, + "selection_batch": 64, + "greedy": "NaiveGreedy", + "full_grad_approximation": "LastLayerWithEmbedding", + "ratio_max": 100, + }, dummy_system_config.model_dump(by_alias=True), 
BCEWithLogitsLoss(reduction="none"), "cpu", diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradmatch_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradmatch_downsampling_strategy.py index d6355eee7..f3bc58b08 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradmatch_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradmatch_downsampling_strategy.py @@ -1,6 +1,7 @@ # pylint: disable=too-many-locals import numpy as np +import pytest import torch from modyn.config import ModynConfig from modyn.tests.trainer_server.internal.trainer.remote_downsamplers.deepcore_comparison_tests_utils import DummyModel @@ -10,7 +11,7 @@ from torch.nn import BCEWithLogitsLoss -def get_sampler_config(modyn_config: ModynConfig, balance=False): +def get_sampler_config(modyn_config: ModynConfig, balance=False, grad_approx="LastLayerWithEmbedding"): downsampling_ratio = 50 per_sample_loss_fct = torch.nn.CrossEntropyLoss(reduction="none") @@ -19,13 +20,15 @@ def get_sampler_config(modyn_config: ModynConfig, balance=False): "sample_then_batch": False, "args": {}, "balance": balance, + "full_grad_approximation": grad_approx, "ratio_max": 100, } return 0, 0, 0, params_from_selector, modyn_config.model_dump(by_alias=True), per_sample_loss_fct, "cpu" -def test_select(dummy_system_config: ModynConfig): - sampler = RemoteGradMatchDownsamplingStrategy(*get_sampler_config(dummy_system_config)) +@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) +def test_select(grad_approx, dummy_system_config: ModynConfig): + sampler = RemoteGradMatchDownsamplingStrategy(*get_sampler_config(dummy_system_config, grad_approx=grad_approx)) with torch.inference_mode(mode=(not sampler.requires_grad)): sample_ids = [1, 2, 3] forward_input = torch.randn(3, 5) # 3 samples, 5 input features @@ -37,7 +40,11 @@ def test_select(dummy_system_config: ModynConfig): sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) assert len(sampler.matrix_elements) == 1 - assert sampler.matrix_elements[0].shape == (3, 55) + if grad_approx == "LastLayerWithEmbedding": + grad_feature_size = 55 # dim 5 * 10 + 5 + else: + grad_feature_size = 5 # same as output feature size + assert sampler.matrix_elements[0].shape == (3, grad_feature_size) sample_ids = [10, 11, 12, 13] forward_input = torch.randn(4, 5) # 4 samples, 5 input features @@ -49,8 +56,8 @@ def test_select(dummy_system_config: ModynConfig): sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) assert len(sampler.matrix_elements) == 2 - assert sampler.matrix_elements[0].shape == (3, 55) - assert sampler.matrix_elements[1].shape == (4, 55) + assert sampler.matrix_elements[0].shape == (3, grad_feature_size) + assert sampler.matrix_elements[1].shape == (4, grad_feature_size) assert sampler.index_sampleid_map == [1, 2, 3, 10, 11, 12, 13] selected_points, selected_weights = sampler.select_points() @@ -61,10 +68,15 @@ def test_select(dummy_system_config: ModynConfig): assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points) -def test_select_balanced(dummy_system_config: ModynConfig): - sampler = RemoteGradMatchDownsamplingStrategy(*get_sampler_config(dummy_system_config, True)) +@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) +def test_select_balanced(grad_approx, dummy_system_config: 
ModynConfig): + sampler = RemoteGradMatchDownsamplingStrategy(*get_sampler_config(dummy_system_config, True, grad_approx)) with torch.inference_mode(mode=(not sampler.requires_grad)): + if grad_approx == "LastLayerWithEmbedding": + grad_feature_size = 55 # dim 5 * 10 + 5 + else: + grad_feature_size = 5 # same as output feature size sample_ids = [1, 2, 3] forward_input = torch.randn(3, 5) # 3 samples, 5 input features forward_output = torch.randn(3, 5) # 3 samples, 5 output classes @@ -75,7 +87,7 @@ def test_select_balanced(dummy_system_config: ModynConfig): sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) assert len(sampler.matrix_elements) == 1 - assert sampler.matrix_elements[0].shape == (3, 55) + assert sampler.matrix_elements[0].shape == (3, grad_feature_size) sampler.inform_end_of_current_label() assert len(sampler.matrix_elements) == 0 @@ -92,7 +104,7 @@ def test_select_balanced(dummy_system_config: ModynConfig): sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) assert len(sampler.matrix_elements) == 1 - assert sampler.matrix_elements[0].shape == (4, 55) + assert sampler.matrix_elements[0].shape == (4, grad_feature_size) assert sampler.index_sampleid_map == [10, 11, 12, 13] sampler.inform_end_of_current_label() @@ -186,7 +198,12 @@ def test_matching_results_with_deepcore(dummy_system_config: ModynConfig): 0, 0, 5, - {"downsampling_ratio": 10 * num_of_target_samples, "balance": False, "ratio_max": 100}, + { + "downsampling_ratio": 10 * num_of_target_samples, + "balance": False, + "ratio_max": 100, + "full_grad_approximation": "LastLayerWithEmbedding", + }, dummy_system_config.model_dump(by_alias=True), BCEWithLogitsLoss(reduction="none"), "cpu", @@ -238,7 +255,12 @@ def test_matching_results_with_deepcore_permutation_fancy_ids(dummy_system_confi 0, 0, 5, - {"downsampling_ratio": 50, "balance": False, "ratio_max": 100}, + { + "downsampling_ratio": 50, + "balance": False, + "ratio_max": 100, + "full_grad_approximation": "LastLayerWithEmbedding", + }, dummy_system_config.model_dump(by_alias=True), BCEWithLogitsLoss(reduction="none"), "cpu", diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_submodular_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_submodular_downsampling_strategy.py index 6cc0366cc..f94e2cb0b 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_submodular_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_submodular_downsampling_strategy.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch from modyn.config import ModynConfig from modyn.tests.trainer_server.internal.trainer.remote_downsamplers.deepcore_comparison_tests_utils import DummyModel @@ -8,7 +9,9 @@ from torch.nn import BCEWithLogitsLoss -def get_sampler_config(modyn_config: ModynConfig, submodular: str = "GraphCut", balance=False): +def get_sampler_config( + modyn_config: ModynConfig, submodular: str = "GraphCut", balance=False, grad_approx="LastLayerWithEmbedding" +): downsampling_ratio = 50 per_sample_loss_fct = torch.nn.CrossEntropyLoss(reduction="none") @@ -20,24 +23,25 @@ def get_sampler_config(modyn_config: ModynConfig, submodular: str = "GraphCut", "balance": balance, "selection_batch": 64, "ratio_max": 100, + "full_grad_approximation": grad_approx, } return 0, 0, 0, params_from_selector, modyn_config.model_dump(by_alias=True), 
per_sample_loss_fct, "cpu" -def test_select_different_submodulars(dummy_system_config: ModynConfig): - _test_select_subm(dummy_system_config, "FacilityLocation") - _test_select_subm(dummy_system_config, "GraphCut") - _test_select_subm(dummy_system_config, "LogDeterminant") +@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) +@pytest.mark.parametrize("submodular", ["FacilityLocation", "GraphCut", "LogDeterminant"]) +def test_select_different_submodulars(submodular: str, grad_approx: str, dummy_system_config: ModynConfig): + _test_select_subm(dummy_system_config, submodular, grad_approx) -def test_select_different_submodulars_balanced(dummy_system_config: ModynConfig): - _test_select_subm_balance(dummy_system_config, "FacilityLocation") - _test_select_subm_balance(dummy_system_config, "GraphCut") - _test_select_subm_balance(dummy_system_config, "LogDeterminant") +@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) +@pytest.mark.parametrize("submodular", ["FacilityLocation", "GraphCut", "LogDeterminant"]) +def test_select_different_submodulars_balanced(submodular: str, grad_approx: str, dummy_system_config: ModynConfig): + _test_select_subm_balance(dummy_system_config, submodular, grad_approx) -def _test_select_subm(modyn_config, submodular, balance=False): - sampler = RemoteSubmodularDownsamplingStrategy(*get_sampler_config(modyn_config, submodular, balance)) +def _test_select_subm(modyn_config, submodular, grad_approx): + sampler = RemoteSubmodularDownsamplingStrategy(*get_sampler_config(modyn_config, submodular, False, grad_approx)) with torch.inference_mode(mode=(not sampler.requires_grad)): sample_ids = [1, 2, 3] forward_input = torch.randn(3, 5) # 3 samples, 5 input features @@ -47,8 +51,12 @@ def _test_select_subm(modyn_config, submodular, balance=False): embedding = torch.randn(3, 10) # 3 samples, embedding dimension 10 sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) assert len(sampler.matrix_elements) == 1 - # 3 samples of dim 5 * 10 + 5 - assert sampler.matrix_elements[0].shape == (3, 55) + if grad_approx == "LastLayerWithEmbedding": + grad_feature_size = 55 # dim 5 * 10 + 5 + else: + grad_feature_size = 5 # same as output feature size + # 3 samples + assert sampler.matrix_elements[0].shape == (3, grad_feature_size) sample_ids = [10, 11, 12, 13] forward_input = torch.randn(4, 5) # 4 samples, 5 input features forward_output = torch.randn(4, 5) # 4 samples, 5 output classes @@ -57,8 +65,8 @@ def _test_select_subm(modyn_config, submodular, balance=False): embedding = torch.randn(4, 10) # 4 samples, embedding dimension 10 sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) assert len(sampler.matrix_elements) == 2 - assert sampler.matrix_elements[0].shape == (3, 55) - assert sampler.matrix_elements[1].shape == (4, 55) + assert sampler.matrix_elements[0].shape == (3, grad_feature_size) + assert sampler.matrix_elements[1].shape == (4, grad_feature_size) assert sampler.index_sampleid_map == [1, 2, 3, 10, 11, 12, 13] selected_points, selected_weights = sampler.select_points() assert len(selected_points) == 3 @@ -67,8 +75,8 @@ def _test_select_subm(modyn_config, submodular, balance=False): assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points) -def _test_select_subm_balance(modyn_config, submodular): - sampler = RemoteSubmodularDownsamplingStrategy(*get_sampler_config(modyn_config, submodular, True)) +def _test_select_subm_balance(modyn_config, 
submodular, grad_approx): + sampler = RemoteSubmodularDownsamplingStrategy(*get_sampler_config(modyn_config, submodular, True, grad_approx)) with torch.inference_mode(mode=(not sampler.requires_grad)): sample_ids = [1, 2, 3] forward_input = torch.randn(3, 5) # 3 samples, 5 input features @@ -78,8 +86,12 @@ def _test_select_subm_balance(modyn_config, submodular): embedding = torch.randn(3, 10) # 3 samples, embedding dimension 10 sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) assert len(sampler.matrix_elements) == 1 - # 3 samples of dim 5 * 10 + 5 - assert sampler.matrix_elements[0].shape == (3, 55) + if grad_approx == "LastLayerWithEmbedding": + grad_feature_size = 55 # dim 5 * 10 + 5 + else: + grad_feature_size = 5 # same as output feature size + # 3 samples + assert sampler.matrix_elements[0].shape == (3, grad_feature_size) sampler.inform_end_of_current_label() assert len(sampler.already_selected_weights) == 1 @@ -95,7 +107,7 @@ def _test_select_subm_balance(modyn_config, submodular): embedding = torch.randn(4, 10) # 4 samples, embedding dimension 10 sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) assert len(sampler.matrix_elements) == 1 - assert sampler.matrix_elements[0].shape == (4, 55) + assert sampler.matrix_elements[0].shape == (4, grad_feature_size) assert sampler.index_sampleid_map == [10, 11, 12, 13] sampler.inform_end_of_current_label() @@ -126,6 +138,7 @@ def _get_selected_samples( "balance": False, "selection_batch": 64, "ratio_max": 100, + "full_grad_approximation": "LastLayerWithEmbedding", }, modyn_config.model_dump(by_alias=True), BCEWithLogitsLoss(reduction="none"), diff --git a/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py b/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py index 7291a8c64..a3f2a00f4 100644 --- a/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py +++ b/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py @@ -1067,6 +1067,7 @@ def test_downsample_trigger_training_set_label_by_label( "downsampling_period": 1, "sample_then_batch": True, "balance": True, + "full_grad_approximation": "LastLayer", "ratio_max": 100, }, ), @@ -1128,6 +1129,7 @@ def test_downsample_trigger_training_set( "downsampling_period": 1, "sample_then_batch": True, "balance": False, + "full_grad_approximation": "LastLayer", "ratio_max": 100, }, ), diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py index 51e7a3794..de5e3e4e4 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py @@ -9,7 +9,11 @@ ) from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor -MatrixContent = Enum("MatrixContent", ["EMBEDDINGS", "GRADIENTS"]) + +class MatrixContent(Enum): + EMBEDDINGS = 1 + LAST_LAYER_GRADIENTS = 2 + LAST_TWO_LAYERS_GRADIENTS = 3 class AbstractMatrixDownsamplingStrategy(AbstractPerLabelRemoteDownsamplingStrategy): @@ -33,14 +37,17 @@ def __init__( self.criterion = per_sample_loss - # This class uses the embedding recorder - self.requires_coreset_supporting_module = True self.matrix_elements: list[torch.Tensor] = [] # actual classes must specify which content should be 
stored. Can be either Gradients or Embeddings. Use the # enum defined above to specify what should be stored self.matrix_content = matrix_content + self.requires_coreset_supporting_module = self.matrix_content in [ + MatrixContent.LAST_TWO_LAYERS_GRADIENTS, + MatrixContent.EMBEDDINGS, + ] + # if true, the downsampling is balanced across classes ex class sizes = [10, 50, 30] and 50% downsampling # yields the following downsampled class sizes [5, 25, 15] while without balance something like [0, 45, 0] can # happen @@ -61,12 +68,24 @@ def inform_samples( target: torch.Tensor, embedding: Optional[torch.Tensor] = None, ) -> None: - assert embedding is not None + batch_size = len(sample_ids) assert self.matrix_content is not None - - if self.matrix_content == MatrixContent.GRADIENTS: - new_elements = self._compute_gradients(forward_output, target, embedding) + if self.matrix_content == MatrixContent.LAST_LAYER_GRADIENTS: + grads_wrt_loss_sum = self._compute_last_layer_gradient_wrt_loss_sum(self.criterion, forward_output, target) + grads_wrt_loss_mean = grads_wrt_loss_sum / batch_size + new_elements = grads_wrt_loss_mean.detach().cpu() + elif self.matrix_content == MatrixContent.LAST_TWO_LAYERS_GRADIENTS: + assert embedding is not None + # using the gradients w.r.t. the sum of the loss or the mean of the loss does not make a difference + # since the scaling factor is the same for all samples. We use mean here to pass the unit test + # containing the hard-coded values from deepcore + grads_wrt_loss_sum = self._compute_last_two_layers_gradient_wrt_loss_sum( + self.criterion, forward_output, target, embedding + ) + grads_wrt_loss_mean = grads_wrt_loss_sum / batch_size + new_elements = grads_wrt_loss_mean.detach().cpu() elif self.matrix_content == MatrixContent.EMBEDDINGS: + assert embedding is not None new_elements = embedding.detach().cpu() else: raise AssertionError("The required content does not exits.") @@ -75,22 +94,6 @@ def inform_samples( # keep the mapping index<->sample_id self.index_sampleid_map += sample_ids - def _compute_gradients( - self, forward_output: torch.Tensor, target: torch.Tensor, embedding: torch.Tensor - ) -> torch.Tensor: - loss = self.criterion(forward_output, target).mean() - embedding_dim = embedding.shape[1] - num_classes = forward_output.shape[1] - batch_num = target.shape[0] - # compute the gradient for each element provided - with torch.no_grad(): - bias_parameters_grads = torch.autograd.grad(loss, forward_output)[0] - weight_parameters_grads = embedding.view(batch_num, 1, embedding_dim).repeat( - 1, num_classes, 1 - ) * bias_parameters_grads.view(batch_num, num_classes, 1).repeat(1, 1, embedding_dim) - gradients = torch.cat([bias_parameters_grads, weight_parameters_grads.flatten(1)], dim=1).cpu().numpy() - return gradients - def inform_end_of_current_label(self) -> None: assert self.balance selected_samples, selected_weights = self._select_from_matrix() @@ -127,4 +130,4 @@ def _select_indexes_from_matrix(self, matrix: np.ndarray, target_size: int) -> t @property def requires_grad(self) -> bool: # Default to true if None - return self.matrix_content == MatrixContent.GRADIENTS + return self.matrix_content in [MatrixContent.LAST_LAYER_GRADIENTS, MatrixContent.LAST_TWO_LAYERS_GRADIENTS] diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_per_label_remote_downsample_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_per_label_remote_downsample_strategy.py index c654a537a..3696f4d76 100644 --- 
a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_per_label_remote_downsample_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_per_label_remote_downsample_strategy.py @@ -6,17 +6,6 @@ class AbstractPerLabelRemoteDownsamplingStrategy(AbstractRemoteDownsamplingStrategy): - def __init__( - self, - pipeline_id: int, - trigger_id: int, - batch_size: int, - params_from_selector: dict, - modyn_config: dict, - device: str, - ): - super().__init__(pipeline_id, trigger_id, batch_size, params_from_selector, modyn_config, device) - self.requires_data_label_by_label = True @abstractmethod def inform_end_of_current_label(self) -> None: diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py index e1dc49d28..733a6518c 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod -from typing import Optional, Union +from typing import Any, Optional, Union import torch +FULL_GRAD_APPROXIMATION = ["LastLayer", "LastLayerWithEmbedding"] + def get_tensors_subset( selected_indexes: list[int], data: Union[torch.Tensor, dict], target: torch.Tensor, sample_ids: list @@ -101,3 +103,51 @@ def select_points(self) -> tuple[list[int], torch.Tensor]: @abstractmethod def requires_grad(self) -> bool: raise NotImplementedError + + @staticmethod + def _compute_last_layer_gradient_wrt_loss_sum( + per_sample_loss_fct: Any, forward_output: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute the gradient of the last layer with respect to the sum of the loss. + Note: if the gradient with respect to the mean of the loss is needed, the result of this function should be + divided by the number of samples in the batch. + """ + if isinstance(per_sample_loss_fct, torch.nn.CrossEntropyLoss): + # no need to autograd if cross entropy loss is used since closed form solution exists. + # Because CrossEntropyLoss includes the softmax, we need to apply the + # softmax to the forward output to obtain the probabilities + probs = torch.nn.functional.softmax(forward_output, dim=1) + num_classes = forward_output.shape[-1] + + # Pylint complains torch.nn.functional.one_hot is not callable for whatever reason + one_hot_targets = torch.nn.functional.one_hot( # pylint: disable=not-callable + target, num_classes=num_classes + ) + last_layer_gradients = probs - one_hot_targets + else: + sample_losses = per_sample_loss_fct(forward_output, target) + last_layer_gradients = torch.autograd.grad(sample_losses.sum(), forward_output, retain_graph=False)[0] + return last_layer_gradients + + @staticmethod + def _compute_last_two_layers_gradient_wrt_loss_sum( + per_sample_loss_fct: Any, forward_output: torch.Tensor, target: torch.Tensor, embedding: torch.Tensor + ) -> torch.Tensor: + """ + Compute the gradient of the last two layers with respect to the sum of the loss. + Note: if the gradient with respect to the mean of the loss is needed, the result of this function should be + divided by the number of samples in the batch. 
+ """ + loss = per_sample_loss_fct(forward_output, target).sum() + embedding_dim = embedding.shape[1] + num_classes = forward_output.shape[1] + batch_num = target.shape[0] + + with torch.no_grad(): + bias_parameters_grads = torch.autograd.grad(loss, forward_output)[0] + weight_parameters_grads = embedding.view(batch_num, 1, embedding_dim).repeat( + 1, num_classes, 1 + ) * bias_parameters_grads.view(batch_num, num_classes, 1).repeat(1, 1, embedding_dim) + gradients = torch.cat([bias_parameters_grads, weight_parameters_grads.flatten(1)], dim=1) + return gradients diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py index fee662dad..0d40f3741 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py @@ -6,6 +6,9 @@ from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_per_label_remote_downsample_strategy import ( AbstractPerLabelRemoteDownsamplingStrategy, ) +from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( + FULL_GRAD_APPROXIMATION, +) from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils import submodular_optimizer from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.euclidean import euclidean_dist_pair_np from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor @@ -43,13 +46,16 @@ def __init__( self.selection_batch = params_from_selector["selection_batch"] self.greedy = params_from_selector["greedy"] + self.full_grad_approximation = params_from_selector["full_grad_approximation"] + assert self.full_grad_approximation in FULL_GRAD_APPROXIMATION + if self.greedy not in OPTIMIZER_CHOICES: raise ValueError( f"The required Greedy optimizer is not available. Pick one of the following: {OPTIMIZER_CHOICES}" ) - # This class uses the embedding recorder - self.requires_coreset_supporting_module = True + self.requires_coreset_supporting_module = self.full_grad_approximation == "LastLayerWithEmbedding" + self.requires_data_label_by_label = True # Samples are supplied label by label (this class is instance of AbstractPerLabelRemoteDownsamplingStrategy). # The following list keeps the gradients of the current label. When all the samples belonging to the current @@ -75,8 +81,6 @@ def inform_samples( target: torch.Tensor, embedding: Optional[torch.Tensor] = None, ) -> None: - assert embedding is not None - # Slightly different implementation for BTS and STB since in STB points are supplied class by class while in # BTS are not. 
STB will always use the first branch, BTS will typically (might use the first if all the points # belong to the same class) use the second one @@ -92,32 +96,29 @@ def inform_samples( for current_target in different_targets_in_this_batch: mask = target == current_target this_target_sample_ids = [sample_ids[i] for i, keep in enumerate(mask) if keep] + sub_embedding = embedding[mask] if embedding is not None else None self._inform_samples_single_class( - this_target_sample_ids, forward_output[mask], target[mask], embedding[mask] + this_target_sample_ids, forward_output[mask], target[mask], sub_embedding ) self.inform_end_of_current_label() def _inform_samples_single_class( - self, sample_ids: list[int], forward_output: torch.Tensor, target: torch.Tensor, embedding: torch.Tensor + self, + sample_ids: list[int], + forward_output: torch.Tensor, + target: torch.Tensor, + embedding: Optional[torch.Tensor], ) -> None: - embedding_dim = embedding.shape[1] - num_classes = forward_output.shape[1] - batch_num = target.shape[0] - - loss = self.criterion(forward_output, target).mean() - - # compute the gradient for each element provided - with torch.no_grad(): - bias_parameters_grads = torch.autograd.grad(loss, forward_output)[0] - weight_parameters_grads = embedding.view(batch_num, 1, embedding_dim).repeat( - 1, num_classes, 1 - ) * bias_parameters_grads.view(batch_num, num_classes, 1).repeat(1, 1, embedding_dim) - - # store the computed gradients - self.current_class_gradients.append( - torch.cat([bias_parameters_grads, weight_parameters_grads.flatten(1)], dim=1).cpu().numpy() + if self.full_grad_approximation == "LastLayerWithEmbedding": + assert embedding is not None + grads_wrt_loss_sum = self._compute_last_two_layers_gradient_wrt_loss_sum( + self.criterion, forward_output, target, embedding ) - + else: + grads_wrt_loss_sum = self._compute_last_layer_gradient_wrt_loss_sum(self.criterion, forward_output, target) + batch_num = target.shape[0] + grads_wrt_loss_mean = grads_wrt_loss_sum / batch_num + self.current_class_gradients.append(grads_wrt_loss_mean.detach().cpu().numpy()) # keep the mapping index<->sample_id self.index_sampleid_map += sample_ids diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_grad_match_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_grad_match_downsampling_strategy.py index e725f974e..e1ed508e6 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_grad_match_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_grad_match_downsampling_strategy.py @@ -6,6 +6,9 @@ AbstractMatrixDownsamplingStrategy, MatrixContent, ) +from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( + FULL_GRAD_APPROXIMATION, +) from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.orthogonal_matching_pursuit import ( orthogonal_matching_pursuit, orthogonal_matching_pursuit_np, @@ -36,6 +39,8 @@ def __init__( per_sample_loss: Any, device: str, ): + self.full_grad_approximation = params_from_selector["full_grad_approximation"] + assert self.full_grad_approximation in FULL_GRAD_APPROXIMATION super().__init__( pipeline_id, trigger_id, @@ -44,7 +49,11 @@ def __init__( modyn_config, per_sample_loss, device, - MatrixContent.GRADIENTS, + ( + MatrixContent.LAST_LAYER_GRADIENTS + if self.full_grad_approximation == "LastLayer" + else MatrixContent.LAST_TWO_LAYERS_GRADIENTS + ), ) def 
_select_indexes_from_matrix(self, matrix: np.ndarray, target_size: int) -> tuple[list[int], torch.Tensor]: diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py index ac456f6a6..a4558d8fa 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py @@ -35,26 +35,6 @@ def __init__( self.probabilities: list[torch.Tensor] = [] self.number_of_points_seen = 0 - def get_scores(self, forward_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - if isinstance(self.per_sample_loss_fct, torch.nn.CrossEntropyLoss): - # no need to autograd if cross entropy loss is used since closed form solution exists. - # Because CrossEntropyLoss includes the softmax, we need to apply the - # softmax to the forward output to obtain the probabilities - probs = torch.nn.functional.softmax(forward_output, dim=1) - num_classes = forward_output.shape[-1] - - # Pylint complains torch.nn.functional.one_hot is not callable for whatever reason - one_hot_targets = torch.nn.functional.one_hot( # pylint: disable=not-callable - target, num_classes=num_classes - ) - scores = torch.norm(probs - one_hot_targets, dim=-1) - else: - sample_losses = self.per_sample_loss_fct(forward_output, target) - last_layer_gradients = torch.autograd.grad(sample_losses.sum(), forward_output, retain_graph=False)[0] - scores = torch.norm(last_layer_gradients, dim=-1) - - return scores.cpu() - def init_downsampler(self) -> None: self.probabilities = [] self.index_sampleid_map: list[int] = [] @@ -68,7 +48,10 @@ def inform_samples( target: torch.Tensor, embedding: Optional[torch.Tensor] = None, ) -> None: - scores = self.get_scores(forward_output, target) + last_layer_gradients = self._compute_last_layer_gradient_wrt_loss_sum( + self.per_sample_loss_fct, forward_output, target + ) + scores = torch.norm(last_layer_gradients, dim=-1).cpu() self.probabilities.append(scores) self.number_of_points_seen += forward_output.shape[0] self.index_sampleid_map += sample_ids diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_kcenter_greedy_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_kcenter_greedy_downsampling_strategy.py index 5b721fb18..65f484801 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_kcenter_greedy_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_kcenter_greedy_downsampling_strategy.py @@ -42,7 +42,6 @@ def __init__( device, MatrixContent.EMBEDDINGS, ) - self.metric = euclidean_dist def _select_indexes_from_matrix(self, matrix: np.ndarray, target_size: int) -> tuple[list[int], torch.Tensor]: diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_submodular_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_submodular_downsampling_strategy.py index 939dc2c2a..3ca3cd337 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_submodular_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_submodular_downsampling_strategy.py @@ -7,6 +7,9 @@ AbstractMatrixDownsamplingStrategy, MatrixContent, ) +from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( 
+ FULL_GRAD_APPROXIMATION, +) from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils import ( submodular_function, submodular_optimizer, @@ -43,6 +46,8 @@ def __init__( per_sample_loss: Any, device: str, ): + self.full_grad_approximation = params_from_selector["full_grad_approximation"] + assert self.full_grad_approximation in FULL_GRAD_APPROXIMATION super().__init__( pipeline_id, trigger_id, @@ -51,7 +56,11 @@ def __init__( modyn_config, per_sample_loss, device, - MatrixContent.GRADIENTS, + ( + MatrixContent.LAST_LAYER_GRADIENTS + if self.full_grad_approximation == "LastLayer" + else MatrixContent.LAST_TWO_LAYERS_GRADIENTS + ), ) self.selection_batch = params_from_selector["selection_batch"]
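Note (not part of the patch): the helpers added in abstract_remote_downsampling_strategy.py build per-sample gradients of the loss sum with respect to the last layer. The short standalone sketch below illustrates the two approximations the patch wires through the configs: the closed-form cross-entropy gradient w.r.t. the logits ("LastLayer") and its concatenation with the embedding outer product ("LastLayerWithEmbedding"). It assumes the final layer is a plain torch.nn.Linear and uses arbitrary sizes; the layer, seed, and dimensions are illustrative only, and the check mirrors the new unit test by comparing against autograd.

import torch

torch.manual_seed(0)
num_samples, embedding_dim, num_classes = 4, 5, 3
last_layer = torch.nn.Linear(embedding_dim, num_classes)  # assumed final layer
criterion = torch.nn.CrossEntropyLoss(reduction="none")

embedding = torch.randn(num_samples, embedding_dim)   # recorded input to the last layer
target = torch.randint(0, num_classes, (num_samples,))
forward_output = last_layer(embedding)                # logits, shape (N, C)

# "LastLayer": closed-form per-sample gradient w.r.t. the logits,
# softmax(logits) - one_hot(target), as in _compute_last_layer_gradient_wrt_loss_sum.
probs = torch.nn.functional.softmax(forward_output, dim=1)
one_hot = torch.nn.functional.one_hot(target, num_classes=num_classes)
bias_grads = (probs - one_hot).detach()                           # (N, C)

# "LastLayerWithEmbedding": outer product with the embedding gives the per-sample
# gradient w.r.t. the last layer's weight matrix; concatenating both blocks yields
# the (N, C + C*E) representation used by the matrix-based strategies.
weight_grads = embedding.unsqueeze(1) * bias_grads.unsqueeze(2)   # (N, C, E)
approx = torch.cat([bias_grads, weight_grads.flatten(1)], dim=1)  # (N, C + C*E)

# Reference: per-sample autograd gradients of the last layer's parameters.
for i in range(num_samples):
    loss_i = criterion(last_layer(embedding[i : i + 1]), target[i : i + 1]).sum()
    grad_bias, grad_weight = torch.autograd.grad(loss_i, (last_layer.bias, last_layer.weight))
    reference = torch.cat([grad_bias, grad_weight.flatten()])
    assert torch.allclose(approx[i], reference, atol=1e-6)

The "LastLayer" variant keeps only the (N, C) block, which is why the patch can drop the embedding recorder requirement in that case, while "LastLayerWithEmbedding" still needs the recorded embedding to form the (N, C + C*E) matrix.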