Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit e13d9d9

Browse files
committed
Fix initialization of ndarray
1 parent 9dc3536 commit e13d9d9

File tree

5 files changed

+24
-19
lines changed

5 files changed

+24
-19
lines changed

experimental/__init__.py

Whitespace-only changes.

experimental/features.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,21 @@
66
import neural_tangents
77
from neural_tangents import stax
88

9-
from pkg_resources import parse_version
10-
if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
11-
from neural_tangents._src.utils import utils, dataclasses
12-
from neural_tangents._src.stax.linear import _pool_kernel, Padding
13-
from neural_tangents._src.stax.linear import _Pooling as Pooling
14-
else:
15-
from neural_tangents.utils import utils, dataclasses
16-
from neural_tangents.stax import _pool_kernel, Padding, Pooling
17-
18-
from sketching import TensorSRHT2, PolyTensorSRHT
9+
# from pkg_resources import parse_version
10+
# if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
11+
# from neural_tangents._src.utils import utils, dataclasses
12+
# from neural_tangents._src.stax.linear import _pool_kernel, Padding
13+
# from neural_tangents._src.stax.linear import _Pooling as Pooling
14+
# else:
15+
# from neural_tangents.utils import utils, dataclasses
16+
# from neural_tangents.stax import _pool_kernel, Padding, Pooling
17+
from neural_tangents._src.utils import dataclasses
18+
# from neural_tangents._src.utils.typing import Optional
19+
from typing import Optional
20+
from neural_tangents._src.stax.linear import _pool_kernel, Padding
21+
from neural_tangents._src.stax.linear import _Pooling as Pooling
22+
23+
from experimental.sketching import TensorSRHT2, PolyTensorSRHT
1924
""" Implementation for NTK Sketching and Random Features """
2025

2126

@@ -50,13 +55,13 @@ def kappa1(x):
5055

5156
@dataclasses.dataclass
5257
class Features:
53-
nngp_feat: np.ndarray
54-
ntk_feat: np.ndarray
58+
nngp_feat: Optional[np.ndarray] = None
59+
ntk_feat: Optional[np.ndarray] = None
5560

5661
batch_axis: int = dataclasses.field(pytree_node=False)
5762
channel_axis: int = dataclasses.field(pytree_node=False)
5863

59-
replace = ... # type: Callable[..., 'Features']
64+
replace = ...
6065

6166

6267
def _inputs_to_features(x: np.ndarray,
@@ -69,7 +74,7 @@ def _inputs_to_features(x: np.ndarray,
6974
nngp_feat = x / x.shape[channel_axis]**0.5
7075
ntk_feat = np.empty((), dtype=nngp_feat.dtype)
7176

72-
return Features(nngp_feat=nngp_feat,
77+
return Features.replace(nngp_feat=nngp_feat,
7378
ntk_feat=ntk_feat,
7479
batch_axis=batch_axis,
7580
channel_axis=channel_axis)

experimental/sketching.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def __init__(self, rng, input_dim, sketch_dim, coeffs):
6161
degree = len(coeffs) - 1
6262
self.degree = degree
6363

64-
self.tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())]
65-
self.tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())]
64+
self.tree_rand_signs: Optional[np.ndarray] = [0 for i in range((self.degree - 1).bit_length())]
65+
self.tree_rand_inds: Optional[np.ndarray] = [0 for i in range((self.degree - 1).bit_length())]
6666
rng1, rng2, rng3 = random.split(rng, 3)
6767

6868
ske_dim_ = sketch_dim // 4
@@ -92,7 +92,7 @@ def __init__(self, rng, input_dim, sketch_dim, coeffs):
9292
def sketch(self, x):
9393
n = x.shape[0]
9494
log_degree = len(self.tree_rand_signs)
95-
V = [0 for i in range(log_degree)]
95+
V: Optional[np.ndarray] = [0 for i in range(log_degree)]
9696
E1 = np.concatenate((np.ones(
9797
(n, 1), dtype=x.dtype), np.zeros((n, x.shape[-1] - 1), dtype=x.dtype)),
9898
1)

experimental/test_fc_ntk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
config.update("jax_enable_x64", True)
77
from neural_tangents import stax
88

9-
from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
9+
from experimental.features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
1010

1111
seed = 1
1212
n, d = 6, 4

experimental/test_myrtle_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from jax import random
1313

1414
from neural_tangents import stax
15-
from features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features
15+
from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features
1616

1717
layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
1818
width = 1

0 commit comments

Comments
 (0)