diff --git a/CI/unit_tests/ntk_computation/test_jax_ntk.py b/CI/unit_tests/analysis/test_jax_ntk.py similarity index 99% rename from CI/unit_tests/ntk_computation/test_jax_ntk.py rename to CI/unit_tests/analysis/test_jax_ntk.py index 8c1c046..1836bf0 100644 --- a/CI/unit_tests/ntk_computation/test_jax_ntk.py +++ b/CI/unit_tests/analysis/test_jax_ntk.py @@ -31,8 +31,8 @@ from flax import linen as nn from jax import random +from znnl.analysis import JAXNTKComputation from znnl.models import FlaxModel -from znnl.ntk_computation import JAXNTKComputation class FlaxTestModule(nn.Module): diff --git a/CI/unit_tests/ntk_computation/test_jax_ntk_classwise.py b/CI/unit_tests/analysis/test_jax_ntk_classwise.py similarity index 99% rename from CI/unit_tests/ntk_computation/test_jax_ntk_classwise.py rename to CI/unit_tests/analysis/test_jax_ntk_classwise.py index 4313af6..f25f0c2 100644 --- a/CI/unit_tests/ntk_computation/test_jax_ntk_classwise.py +++ b/CI/unit_tests/analysis/test_jax_ntk_classwise.py @@ -30,8 +30,8 @@ from flax import linen as nn from jax import random +from znnl.analysis import JAXNTKClassWise from znnl.models import FlaxModel -from znnl.ntk_computation import JAXNTKClassWise class FlaxTestModule(nn.Module): diff --git a/CI/unit_tests/ntk_computation/test_jax_ntk_combinations.py b/CI/unit_tests/analysis/test_jax_ntk_combinations.py similarity index 99% rename from CI/unit_tests/ntk_computation/test_jax_ntk_combinations.py rename to CI/unit_tests/analysis/test_jax_ntk_combinations.py index 6130e08..3258526 100644 --- a/CI/unit_tests/ntk_computation/test_jax_ntk_combinations.py +++ b/CI/unit_tests/analysis/test_jax_ntk_combinations.py @@ -32,8 +32,8 @@ from jax import random from papyrus.utils.matrix_utils import flatten_rank_4_tensor +from znnl.analysis import JAXNTKCombinations from znnl.models import FlaxModel -from znnl.ntk_computation import JAXNTKCombinations class FlaxTestModule(nn.Module): diff --git a/CI/unit_tests/ntk_computation/test_jax_ntk_subsampling.py b/CI/unit_tests/analysis/test_jax_ntk_subsampling.py similarity index 98% rename from CI/unit_tests/ntk_computation/test_jax_ntk_subsampling.py rename to CI/unit_tests/analysis/test_jax_ntk_subsampling.py index 14e7dfa..f425710 100644 --- a/CI/unit_tests/ntk_computation/test_jax_ntk_subsampling.py +++ b/CI/unit_tests/analysis/test_jax_ntk_subsampling.py @@ -30,8 +30,8 @@ from flax import linen as nn from jax import random +from znnl.analysis import JAXNTKSubsampling from znnl.models import FlaxModel -from znnl.ntk_computation import JAXNTKSubsampling class FlaxTestModule(nn.Module): diff --git a/CI/unit_tests/models/_test_huggingface_flax_model.py b/CI/unit_tests/models/_test_huggingface_flax_model.py index 1d66bdd..fc7411f 100644 --- a/CI/unit_tests/models/_test_huggingface_flax_model.py +++ b/CI/unit_tests/models/_test_huggingface_flax_model.py @@ -29,8 +29,8 @@ from jax import random from transformers import FlaxResNetForImageClassification, ResNetConfig +from znnl.analysis import JAXNTKComputation from znnl.models import HuggingFaceFlaxModel -from znnl.ntk_computation import JAXNTKComputation class TestFlaxHFModule: diff --git a/CI/unit_tests/models/test_flax_model.py b/CI/unit_tests/models/test_flax_model.py index 249b39c..05db29b 100644 --- a/CI/unit_tests/models/test_flax_model.py +++ b/CI/unit_tests/models/test_flax_model.py @@ -34,8 +34,8 @@ from flax import linen as nn from jax import random +from znnl.analysis import JAXNTKComputation from znnl.models import FlaxModel -from znnl.ntk_computation import JAXNTKComputation class FlaxTestModule(nn.Module): diff --git a/CI/unit_tests/models/test_nt_model.py b/CI/unit_tests/models/test_nt_model.py index fff995e..b1114da 100644 --- a/CI/unit_tests/models/test_nt_model.py +++ b/CI/unit_tests/models/test_nt_model.py @@ -34,8 +34,8 @@ from jax import random from neural_tangents import stax +from znnl.analysis import JAXNTKComputation from znnl.models import NTModel -from znnl.ntk_computation import JAXNTKComputation class TestNTModule: diff --git a/CI/unit_tests/optimizers/test_trace_optimizer.py b/CI/unit_tests/optimizers/test_trace_optimizer.py index 943cf3b..55b5e08 100644 --- a/CI/unit_tests/optimizers/test_trace_optimizer.py +++ b/CI/unit_tests/optimizers/test_trace_optimizer.py @@ -32,9 +32,9 @@ import jax.numpy as np from neural_tangents import stax +from znnl.analysis import JAXNTKComputation from znnl.data import MNISTGenerator from znnl.models import NTModel -from znnl.ntk_computation import JAXNTKComputation from znnl.optimizers import TraceOptimizer diff --git a/CI/unit_tests/training_recording/test_jax_recorder.py b/CI/unit_tests/training_recording/test_jax_recorder.py index 27aecb6..fc061d7 100644 --- a/CI/unit_tests/training_recording/test_jax_recorder.py +++ b/CI/unit_tests/training_recording/test_jax_recorder.py @@ -33,8 +33,8 @@ from numpy.testing import assert_raises from papyrus.measurements import Accuracy, Loss, NTKTrace +from znnl.analysis import JAXNTKComputation from znnl.models import FlaxModel -from znnl.ntk_computation import JAXNTKComputation from znnl.training_recording import JaxRecorder diff --git a/examples/Computing-Collective-Variables.ipynb b/examples/Computing-Collective-Variables.ipynb index 3309690..23928ea 100644 --- a/examples/Computing-Collective-Variables.ipynb +++ b/examples/Computing-Collective-Variables.ipynb @@ -165,7 +165,7 @@ "metadata": {}, "outputs": [], "source": [ - "ntk_computation = nl.ntk_computation.JAXNTKComputation(\n", + "ntk_computation = nl.analysis.JAXNTKComputation(\n", " apply_fn=fuel_model.ntk_apply_fn, \n", " batch_size=314,\n", ")\n", diff --git a/examples/Contrastive-Loss.ipynb b/examples/Contrastive-Loss.ipynb index 1f625e5..a13117a 100644 --- a/examples/Contrastive-Loss.ipynb +++ b/examples/Contrastive-Loss.ipynb @@ -192,7 +192,7 @@ " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", - "ntk_computation = znnl.ntk_computation.JAXNTKComputation(\n", + "ntk_computation = znnl.analysis.JAXNTKComputation(\n", " apply_fn=model.ntk_apply_fn, \n", " batch_size=10,\n", ")\n", @@ -461,7 +461,7 @@ " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", - "ntk_computation = znnl.ntk_computation.JAXNTKComputation(\n", + "ntk_computation = znnl.analysis.JAXNTKComputation(\n", " apply_fn=model.ntk_apply_fn, \n", " batch_size=10,\n", ")\n", diff --git a/examples/ResNet-Example.ipynb b/examples/ResNet-Example.ipynb index 6e6ae47..d274030 100644 --- a/examples/ResNet-Example.ipynb +++ b/examples/ResNet-Example.ipynb @@ -174,7 +174,7 @@ " ],\n", " update_rate=1, \n", ")\n", - "ntk_computation = nl.ntk_computation.JAXNTKComputation(\n", + "ntk_computation = nl.analysis.JAXNTKComputation(\n", " apply_fn=model.ntk_apply_fn, \n", " batch_size=10, \n", ")\n", diff --git a/znnl/agents/approximate_maximum_entropy.py b/znnl/agents/approximate_maximum_entropy.py index 3a2272f..c44c25b 100644 --- a/znnl/agents/approximate_maximum_entropy.py +++ b/znnl/agents/approximate_maximum_entropy.py @@ -30,9 +30,9 @@ from znnl.agents.agent import Agent from znnl.analysis.entropy import EntropyAnalysis +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.data import DataGenerator from znnl.models import JaxModel -from znnl.ntk_computation.jax_ntk import JAXNTKComputation from znnl.utils.prng import PRNGKey diff --git a/znnl/analysis/__init__.py b/znnl/analysis/__init__.py index 94701c5..d8b6c50 100644 --- a/znnl/analysis/__init__.py +++ b/znnl/analysis/__init__.py @@ -27,10 +27,18 @@ from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis +from znnl.analysis.jax_ntk import JAXNTKComputation +from znnl.analysis.jax_ntk_classwise import JAXNTKClassWise +from znnl.analysis.jax_ntk_combinations import JAXNTKCombinations +from znnl.analysis.jax_ntk_subsampling import JAXNTKSubsampling from znnl.analysis.loss_fn_derivative import LossDerivative __all__ = [ EntropyAnalysis.__name__, EigenSpaceAnalysis.__name__, LossDerivative.__name__, + JAXNTKComputation.__name__, + JAXNTKClassWise.__name__, + JAXNTKSubsampling.__name__, + JAXNTKCombinations.__name__, ] diff --git a/znnl/ntk_computation/jax_ntk.py b/znnl/analysis/jax_ntk.py similarity index 100% rename from znnl/ntk_computation/jax_ntk.py rename to znnl/analysis/jax_ntk.py diff --git a/znnl/ntk_computation/jax_ntk_classwise.py b/znnl/analysis/jax_ntk_classwise.py similarity index 99% rename from znnl/ntk_computation/jax_ntk_classwise.py rename to znnl/analysis/jax_ntk_classwise.py index 6b05c7f..e6f6123 100644 --- a/znnl/ntk_computation/jax_ntk_classwise.py +++ b/znnl/analysis/jax_ntk_classwise.py @@ -32,7 +32,7 @@ from jax import random, vmap from jax.tree_util import tree_map as jmap -from znnl.ntk_computation.jax_ntk import JAXNTKComputation +from znnl.analysis.jax_ntk import JAXNTKComputation class JAXNTKClassWise(JAXNTKComputation): diff --git a/znnl/ntk_computation/jax_ntk_combinations.py b/znnl/analysis/jax_ntk_combinations.py similarity index 99% rename from znnl/ntk_computation/jax_ntk_combinations.py rename to znnl/analysis/jax_ntk_combinations.py index 059be7f..b1568fb 100644 --- a/znnl/ntk_computation/jax_ntk_combinations.py +++ b/znnl/analysis/jax_ntk_combinations.py @@ -32,7 +32,7 @@ import neural_tangents as nt from papyrus.utils.matrix_utils import flatten_rank_4_tensor, unflatten_rank_4_tensor -from znnl.ntk_computation.jax_ntk import JAXNTKComputation +from znnl.analysis.jax_ntk import JAXNTKComputation class JAXNTKCombinations(JAXNTKComputation): diff --git a/znnl/ntk_computation/jax_ntk_subsampling.py b/znnl/analysis/jax_ntk_subsampling.py similarity index 99% rename from znnl/ntk_computation/jax_ntk_subsampling.py rename to znnl/analysis/jax_ntk_subsampling.py index 036b11d..836ef9c 100644 --- a/znnl/ntk_computation/jax_ntk_subsampling.py +++ b/znnl/analysis/jax_ntk_subsampling.py @@ -32,7 +32,7 @@ from jax import random from jax.tree_util import tree_map as jmap -from znnl.ntk_computation.jax_ntk import JAXNTKComputation +from znnl.analysis.jax_ntk import JAXNTKComputation class JAXNTKSubsampling(JAXNTKComputation): diff --git a/znnl/ntk_computation/__init__.py b/znnl/ntk_computation/__init__.py deleted file mode 100644 index 3f9b02b..0000000 --- a/znnl/ntk_computation/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -ZnNL: A Zincwarecode package. - -License -------- -This program and the accompanying materials are made available under the terms -of the Eclipse Public License v2.0 which accompanies this distribution, and is -available at https://www.eclipse.org/legal/epl-v20.html - -SPDX-License-Identifier: EPL-2.0 - -Copyright Contributors to the Zincwarecode Project. - -Contact Information -------------------- -email: zincwarecode@gmail.com -github: https://github.com/zincware -web: https://zincwarecode.com/ - -Citation --------- -If you use this module please cite us with: - -Summary -------- -""" - -from znnl.ntk_computation.jax_ntk import JAXNTKComputation -from znnl.ntk_computation.jax_ntk_classwise import JAXNTKClassWise -from znnl.ntk_computation.jax_ntk_combinations import JAXNTKCombinations -from znnl.ntk_computation.jax_ntk_subsampling import JAXNTKSubsampling - -__all__ = [ - JAXNTKComputation.__name__, - JAXNTKClassWise.__name__, - JAXNTKSubsampling.__name__, - JAXNTKCombinations.__name__, -] diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 72574c0..73457f5 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -37,10 +37,10 @@ from znnl.accuracy_functions.accuracy_function import AccuracyFunction from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.analysis.loss_fn_derivative import LossDerivative from znnl.loss_functions import SimpleLoss from znnl.models.jax_model import JaxModel -from znnl.ntk_computation.jax_ntk import JAXNTKComputation from znnl.training_recording.data_storage import DataStorage from znnl.utils.matrix_utils import ( calculate_trace, diff --git a/znnl/training_recording/papyrus_jax_recording.py b/znnl/training_recording/papyrus_jax_recording.py index 8657135..1503591 100644 --- a/znnl/training_recording/papyrus_jax_recording.py +++ b/znnl/training_recording/papyrus_jax_recording.py @@ -31,8 +31,8 @@ from papyrus.measurements import BaseMeasurement from papyrus.recorders import BaseRecorder +from znnl.analysis import JAXNTKComputation from znnl.models import JaxModel -from znnl.ntk_computation import JAXNTKComputation class JaxRecorder(BaseRecorder): diff --git a/znnl/training_strategies/loss_aware_reservoir.py b/znnl/training_strategies/loss_aware_reservoir.py index 38ead5c..287f12d 100644 --- a/znnl/training_strategies/loss_aware_reservoir.py +++ b/znnl/training_strategies/loss_aware_reservoir.py @@ -35,9 +35,9 @@ from tqdm import trange from znnl.accuracy_functions.accuracy_function import AccuracyFunction +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.distance_metrics import DistanceMetric from znnl.models.jax_model import JaxModel -from znnl.ntk_computation.jax_ntk import JAXNTKComputation from znnl.optimizers.trace_optimizer import TraceOptimizer from znnl.training_recording import JaxRecorder from znnl.training_strategies.recursive_mode import RecursiveMode diff --git a/znnl/training_strategies/partitioned_training.py b/znnl/training_strategies/partitioned_training.py index a4005f0..0b91180 100644 --- a/znnl/training_strategies/partitioned_training.py +++ b/znnl/training_strategies/partitioned_training.py @@ -32,8 +32,8 @@ from tqdm import trange from znnl.accuracy_functions.accuracy_function import AccuracyFunction +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.models.jax_model import JaxModel -from znnl.ntk_computation.jax_ntk import JAXNTKComputation from znnl.optimizers.trace_optimizer import TraceOptimizer from znnl.training_recording import JaxRecorder from znnl.training_strategies.recursive_mode import RecursiveMode diff --git a/znnl/training_strategies/simple_training.py b/znnl/training_strategies/simple_training.py index 9fafb75..a53fc08 100644 --- a/znnl/training_strategies/simple_training.py +++ b/znnl/training_strategies/simple_training.py @@ -36,8 +36,8 @@ from tqdm import trange from znnl.accuracy_functions.accuracy_function import AccuracyFunction +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.models.jax_model import JaxModel -from znnl.ntk_computation.jax_ntk import JAXNTKComputation from znnl.optimizers.trace_optimizer import TraceOptimizer from znnl.training_recording import JaxRecorder from znnl.training_strategies.recursive_mode import RecursiveMode