Skip to content

Commit

Permalink
Move ntk computation to analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 31, 2024
1 parent 3b5f5a1 commit 6e0fad9
Show file tree
Hide file tree
Showing 24 changed files with 30 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from flax import linen as nn
from jax import random

from znnl.analysis import JAXNTKComputation
from znnl.models import FlaxModel
from znnl.ntk_computation import JAXNTKComputation


class FlaxTestModule(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from flax import linen as nn
from jax import random

from znnl.analysis import JAXNTKClassWise
from znnl.models import FlaxModel
from znnl.ntk_computation import JAXNTKClassWise


class FlaxTestModule(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
from jax import random
from papyrus.utils.matrix_utils import flatten_rank_4_tensor

from znnl.analysis import JAXNTKCombinations
from znnl.models import FlaxModel
from znnl.ntk_computation import JAXNTKCombinations


class FlaxTestModule(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from flax import linen as nn
from jax import random

from znnl.analysis import JAXNTKSubsampling
from znnl.models import FlaxModel
from znnl.ntk_computation import JAXNTKSubsampling


class FlaxTestModule(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion CI/unit_tests/models/_test_huggingface_flax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from jax import random
from transformers import FlaxResNetForImageClassification, ResNetConfig

from znnl.analysis import JAXNTKComputation
from znnl.models import HuggingFaceFlaxModel
from znnl.ntk_computation import JAXNTKComputation


class TestFlaxHFModule:
Expand Down
2 changes: 1 addition & 1 deletion CI/unit_tests/models/test_flax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
from flax import linen as nn
from jax import random

from znnl.analysis import JAXNTKComputation
from znnl.models import FlaxModel
from znnl.ntk_computation import JAXNTKComputation


class FlaxTestModule(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion CI/unit_tests/models/test_nt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
from jax import random
from neural_tangents import stax

from znnl.analysis import JAXNTKComputation
from znnl.models import NTModel
from znnl.ntk_computation import JAXNTKComputation


class TestNTModule:
Expand Down
2 changes: 1 addition & 1 deletion CI/unit_tests/optimizers/test_trace_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
import jax.numpy as np
from neural_tangents import stax

from znnl.analysis import JAXNTKComputation
from znnl.data import MNISTGenerator
from znnl.models import NTModel
from znnl.ntk_computation import JAXNTKComputation
from znnl.optimizers import TraceOptimizer


Expand Down
2 changes: 1 addition & 1 deletion CI/unit_tests/training_recording/test_jax_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from numpy.testing import assert_raises
from papyrus.measurements import Accuracy, Loss, NTKTrace

from znnl.analysis import JAXNTKComputation
from znnl.models import FlaxModel
from znnl.ntk_computation import JAXNTKComputation
from znnl.training_recording import JaxRecorder


Expand Down
2 changes: 1 addition & 1 deletion examples/Computing-Collective-Variables.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@
"metadata": {},
"outputs": [],
"source": [
"ntk_computation = nl.ntk_computation.JAXNTKComputation(\n",
"ntk_computation = nl.analysis.JAXNTKComputation(\n",
" apply_fn=fuel_model.ntk_apply_fn, \n",
" batch_size=314,\n",
")\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/Contrastive-Loss.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@
" update_rate=1, \n",
" chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n",
")\n",
"ntk_computation = znnl.ntk_computation.JAXNTKComputation(\n",
"ntk_computation = znnl.analysis.JAXNTKComputation(\n",
" apply_fn=model.ntk_apply_fn, \n",
" batch_size=10,\n",
")\n",
Expand Down Expand Up @@ -461,7 +461,7 @@
" update_rate=1, \n",
" chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n",
")\n",
"ntk_computation = znnl.ntk_computation.JAXNTKComputation(\n",
"ntk_computation = znnl.analysis.JAXNTKComputation(\n",
" apply_fn=model.ntk_apply_fn, \n",
" batch_size=10,\n",
")\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/ResNet-Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
" ],\n",
" update_rate=1, \n",
")\n",
"ntk_computation = nl.ntk_computation.JAXNTKComputation(\n",
"ntk_computation = nl.analysis.JAXNTKComputation(\n",
" apply_fn=model.ntk_apply_fn, \n",
" batch_size=10, \n",
")\n",
Expand Down
2 changes: 1 addition & 1 deletion znnl/agents/approximate_maximum_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@

from znnl.agents.agent import Agent
from znnl.analysis.entropy import EntropyAnalysis
from znnl.analysis.jax_ntk import JAXNTKComputation
from znnl.data import DataGenerator
from znnl.models import JaxModel
from znnl.ntk_computation.jax_ntk import JAXNTKComputation
from znnl.utils.prng import PRNGKey


Expand Down
8 changes: 8 additions & 0 deletions znnl/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,18 @@

from znnl.analysis.eigensystem import EigenSpaceAnalysis
from znnl.analysis.entropy import EntropyAnalysis
from znnl.analysis.jax_ntk import JAXNTKComputation
from znnl.analysis.jax_ntk_classwise import JAXNTKClassWise
from znnl.analysis.jax_ntk_combinations import JAXNTKCombinations
from znnl.analysis.jax_ntk_subsampling import JAXNTKSubsampling
from znnl.analysis.loss_fn_derivative import LossDerivative

__all__ = [
EntropyAnalysis.__name__,
EigenSpaceAnalysis.__name__,
LossDerivative.__name__,
JAXNTKComputation.__name__,
JAXNTKClassWise.__name__,
JAXNTKSubsampling.__name__,
JAXNTKCombinations.__name__,
]
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from jax import random, vmap
from jax.tree_util import tree_map as jmap

from znnl.ntk_computation.jax_ntk import JAXNTKComputation
from znnl.analysis.jax_ntk import JAXNTKComputation


class JAXNTKClassWise(JAXNTKComputation):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import neural_tangents as nt
from papyrus.utils.matrix_utils import flatten_rank_4_tensor, unflatten_rank_4_tensor

from znnl.ntk_computation.jax_ntk import JAXNTKComputation
from znnl.analysis.jax_ntk import JAXNTKComputation


class JAXNTKCombinations(JAXNTKComputation):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from jax import random
from jax.tree_util import tree_map as jmap

from znnl.ntk_computation.jax_ntk import JAXNTKComputation
from znnl.analysis.jax_ntk import JAXNTKComputation


class JAXNTKSubsampling(JAXNTKComputation):
Expand Down
38 changes: 0 additions & 38 deletions znnl/ntk_computation/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
from znnl.accuracy_functions.accuracy_function import AccuracyFunction
from znnl.analysis.eigensystem import EigenSpaceAnalysis
from znnl.analysis.entropy import EntropyAnalysis
from znnl.analysis.jax_ntk import JAXNTKComputation
from znnl.analysis.loss_fn_derivative import LossDerivative
from znnl.loss_functions import SimpleLoss
from znnl.models.jax_model import JaxModel
from znnl.ntk_computation.jax_ntk import JAXNTKComputation
from znnl.training_recording.data_storage import DataStorage
from znnl.utils.matrix_utils import (
calculate_trace,
Expand Down
2 changes: 1 addition & 1 deletion znnl/training_recording/papyrus_jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from papyrus.measurements import BaseMeasurement
from papyrus.recorders import BaseRecorder

from znnl.analysis import JAXNTKComputation
from znnl.models import JaxModel
from znnl.ntk_computation import JAXNTKComputation


class JaxRecorder(BaseRecorder):
Expand Down
2 changes: 1 addition & 1 deletion znnl/training_strategies/loss_aware_reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
from tqdm import trange

from znnl.accuracy_functions.accuracy_function import AccuracyFunction
from znnl.analysis.jax_ntk import JAXNTKComputation
from znnl.distance_metrics import DistanceMetric
from znnl.models.jax_model import JaxModel
from znnl.ntk_computation.jax_ntk import JAXNTKComputation
from znnl.optimizers.trace_optimizer import TraceOptimizer
from znnl.training_recording import JaxRecorder
from znnl.training_strategies.recursive_mode import RecursiveMode
Expand Down
2 changes: 1 addition & 1 deletion znnl/training_strategies/partitioned_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
from tqdm import trange

from znnl.accuracy_functions.accuracy_function import AccuracyFunction
from znnl.analysis.jax_ntk import JAXNTKComputation
from znnl.models.jax_model import JaxModel
from znnl.ntk_computation.jax_ntk import JAXNTKComputation
from znnl.optimizers.trace_optimizer import TraceOptimizer
from znnl.training_recording import JaxRecorder
from znnl.training_strategies.recursive_mode import RecursiveMode
Expand Down
2 changes: 1 addition & 1 deletion znnl/training_strategies/simple_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
from tqdm import trange

from znnl.accuracy_functions.accuracy_function import AccuracyFunction
from znnl.analysis.jax_ntk import JAXNTKComputation
from znnl.models.jax_model import JaxModel
from znnl.ntk_computation.jax_ntk import JAXNTKComputation
from znnl.optimizers.trace_optimizer import TraceOptimizer
from znnl.training_recording import JaxRecorder
from znnl.training_strategies.recursive_mode import RecursiveMode
Expand Down

0 comments on commit 6e0fad9

Please sign in to comment.