6
6
import neural_tangents
7
7
from neural_tangents import stax
8
8
9
- from pkg_resources import parse_version
10
- if parse_version (neural_tangents .__version__ ) >= parse_version ('0.5.0' ):
11
- from neural_tangents ._src .utils import utils , dataclasses
12
- from neural_tangents ._src .stax .linear import _pool_kernel , Padding
13
- from neural_tangents ._src .stax .linear import _Pooling as Pooling
14
- else :
15
- from neural_tangents .utils import utils , dataclasses
16
- from neural_tangents .stax import _pool_kernel , Padding , Pooling
17
-
18
- from sketching import TensorSRHT2 , PolyTensorSRHT
9
+ # from pkg_resources import parse_version
10
+ # if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
11
+ # from neural_tangents._src.utils import utils, dataclasses
12
+ # from neural_tangents._src.stax.linear import _pool_kernel, Padding
13
+ # from neural_tangents._src.stax.linear import _Pooling as Pooling
14
+ # else:
15
+ # from neural_tangents.utils import utils, dataclasses
16
+ # from neural_tangents.stax import _pool_kernel, Padding, Pooling
17
+ from neural_tangents ._src .utils import dataclasses
18
+ # from neural_tangents._src.utils.typing import Optional
19
+ from typing import Optional
20
+ from neural_tangents ._src .stax .linear import _pool_kernel , Padding
21
+ from neural_tangents ._src .stax .linear import _Pooling as Pooling
22
+
23
+ from experimental .sketching import TensorSRHT2 , PolyTensorSRHT
19
24
""" Implementation for NTK Sketching and Random Features """
20
25
21
26
@@ -50,13 +55,13 @@ def kappa1(x):
50
55
51
56
@dataclasses .dataclass
52
57
class Features :
53
- nngp_feat : np .ndarray
54
- ntk_feat : np .ndarray
58
+ nngp_feat : Optional [ np .ndarray ] = None
59
+ ntk_feat : Optional [ np .ndarray ] = None
55
60
56
61
batch_axis : int = dataclasses .field (pytree_node = False )
57
62
channel_axis : int = dataclasses .field (pytree_node = False )
58
63
59
- replace = ... # type: Callable[..., 'Features']
64
+ replace = ...
60
65
61
66
62
67
def _inputs_to_features (x : np .ndarray ,
@@ -69,7 +74,7 @@ def _inputs_to_features(x: np.ndarray,
69
74
nngp_feat = x / x .shape [channel_axis ]** 0.5
70
75
ntk_feat = np .empty ((), dtype = nngp_feat .dtype )
71
76
72
- return Features (nngp_feat = nngp_feat ,
77
+ return Features . replace (nngp_feat = nngp_feat ,
73
78
ntk_feat = ntk_feat ,
74
79
batch_axis = batch_axis ,
75
80
channel_axis = channel_axis )
0 commit comments