@@ -521,12 +521,14 @@ def _cholesky(mat):
521
521
522
522
523
523
@layer
524
- def ReluNTKFeatures (
525
- num_layers : int ,
526
- poly_degree : int = 16 ,
527
- poly_sketch_dim : int = 1024 ,
528
- W_std : float = 1. ,
529
- ):
524
+ def ReluNTKFeatures (num_layers : int ,
525
+ poly_degree : int = 16 ,
526
+ poly_sketch_dim : int = 1024 ,
527
+ batch_axis : int = 0 ,
528
+ channel_axis : int = - 1 ):
529
+
530
+ if batch_axis != 0 or channel_axis != - 1 :
531
+ raise NotImplementedError (f'Not supported axes.' )
530
532
531
533
def init_fn (rng , input_shape ):
532
534
input_dim = input_shape [0 ][- 1 ]
@@ -541,14 +543,18 @@ def init_fn(rng, input_shape):
541
543
542
544
return (), (polysketch , nngp_coeffs , ntk_coeffs )
543
545
544
- def feature_fn (f , input = None , ** kwargs ):
546
+ @requires (batch_axis = batch_axis , channel_axis = channel_axis )
547
+ def feature_fn (f : Features , input = None , ** kwargs ):
545
548
input_shape = f .nngp_feat .shape [:- 1 ]
546
549
547
550
polysketch : PolyTensorSketch = input [0 ]
548
551
nngp_coeffs : np .ndarray = input [1 ]
549
552
ntk_coeffs : np .ndarray = input [2 ]
550
553
551
- polysketch_feats = polysketch .sketch (f .nngp_feat )
554
+ norms = np .linalg .norm (f .nngp_feat , axis = channel_axis , keepdims = True )
555
+ nngp_feat = f .nngp_feat / norms
556
+
557
+ polysketch_feats = polysketch .sketch (nngp_feat )
552
558
nngp_feat = polysketch .expand_feats (polysketch_feats , nngp_coeffs )
553
559
ntk_feat = polysketch .expand_feats (polysketch_feats , ntk_coeffs )
554
560
@@ -557,8 +563,11 @@ def feature_fn(f, input=None, **kwargs):
557
563
ntk_feat = polysketch .standardsrht (ntk_feat ).reshape (input_shape + (- 1 ,))
558
564
559
565
# Convert complex features to real ones.
560
- ntk_feat = np .concatenate ((ntk_feat .real , ntk_feat .imag ), axis = - 1 )
561
566
nngp_feat = np .concatenate ((nngp_feat .real , nngp_feat .imag ), axis = - 1 )
567
+ ntk_feat = np .concatenate ((ntk_feat .real , ntk_feat .imag ), axis = - 1 )
568
+
569
+ nngp_feat *= norms / 2 ** (num_layers / 2. )
570
+ ntk_feat *= norms / 2 ** (num_layers / 2. )
562
571
563
572
return f .replace (nngp_feat = nngp_feat , ntk_feat = ntk_feat )
564
573
0 commit comments