Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit a16098c

Browse files
committed
Add ReluNTKFeatures test
1 parent af524d2 commit a16098c

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

experimental/features.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -521,12 +521,14 @@ def _cholesky(mat):
521521

522522

523523
@layer
524-
def ReluNTKFeatures(
525-
num_layers: int,
526-
poly_degree: int = 16,
527-
poly_sketch_dim: int = 1024,
528-
W_std: float = 1.,
529-
):
524+
def ReluNTKFeatures(num_layers: int,
525+
poly_degree: int = 16,
526+
poly_sketch_dim: int = 1024,
527+
batch_axis: int = 0,
528+
channel_axis: int = -1):
529+
530+
if batch_axis != 0 or channel_axis != -1:
531+
raise NotImplementedError(f'Not supported axes.')
530532

531533
def init_fn(rng, input_shape):
532534
input_dim = input_shape[0][-1]
@@ -541,14 +543,18 @@ def init_fn(rng, input_shape):
541543

542544
return (), (polysketch, nngp_coeffs, ntk_coeffs)
543545

544-
def feature_fn(f, input=None, **kwargs):
546+
@requires(batch_axis=batch_axis, channel_axis=channel_axis)
547+
def feature_fn(f: Features, input=None, **kwargs):
545548
input_shape = f.nngp_feat.shape[:-1]
546549

547550
polysketch: PolyTensorSketch = input[0]
548551
nngp_coeffs: np.ndarray = input[1]
549552
ntk_coeffs: np.ndarray = input[2]
550553

551-
polysketch_feats = polysketch.sketch(f.nngp_feat)
554+
norms = np.linalg.norm(f.nngp_feat, axis=channel_axis, keepdims=True)
555+
nngp_feat = f.nngp_feat / norms
556+
557+
polysketch_feats = polysketch.sketch(nngp_feat)
552558
nngp_feat = polysketch.expand_feats(polysketch_feats, nngp_coeffs)
553559
ntk_feat = polysketch.expand_feats(polysketch_feats, ntk_coeffs)
554560

@@ -557,8 +563,11 @@ def feature_fn(f, input=None, **kwargs):
557563
ntk_feat = polysketch.standardsrht(ntk_feat).reshape(input_shape + (-1,))
558564

559565
# Convert complex features to real ones.
560-
ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1)
561566
nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=-1)
567+
ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1)
568+
569+
nngp_feat *= norms / 2**(num_layers / 2.)
570+
ntk_feat *= norms / 2**(num_layers / 2.)
562571

563572
return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat)
564573

experimental/tests/features_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,37 @@ def test_aggregate_features(self):
534534
self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T)
535535
self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T)
536536

537+
@parameterized.product(n_layers=[1, 2, 3, 4, 5], do_jit=[True, False])
538+
def test_onepass_fc_relu_nngp_ntk(self, n_layers, do_jit):
539+
rng = random.PRNGKey(1)
540+
n, d = 4, 256
541+
x = _get_init_data(rng, (n, d))
542+
543+
kernel_fn = stax.serial(*[stax.Dense(1), stax.Relu()] * n_layers +
544+
[stax.Dense(1)])[2]
545+
546+
poly_degree = 8
547+
poly_sketch_dim = 4096
548+
549+
init_fn, feature_fn = ft.ReluNTKFeatures(n_layers, poly_degree,
550+
poly_sketch_dim)
551+
552+
rng2 = random.PRNGKey(2)
553+
_, feat_fn_inputs = init_fn(rng2, x.shape)
554+
555+
if do_jit:
556+
kernel_fn = jit(kernel_fn)
557+
feature_fn = jit(feature_fn)
558+
559+
k = kernel_fn(x)
560+
f = feature_fn(x, feat_fn_inputs)
561+
562+
k_nngp_approx = f.nngp_feat @ f.nngp_feat.T
563+
k_ntk_approx = f.ntk_feat @ f.ntk_feat.T
564+
565+
test_utils.assert_close_matrices(self, k.nngp, k_nngp_approx, 0.2, 1.)
566+
test_utils.assert_close_matrices(self, k.ntk, k_ntk_approx, 0.2, 1.)
567+
537568

538569
if __name__ == "__main__":
539570
absltest.main()

0 commit comments

Comments
 (0)