This repository was archived by the owner on May 6, 2025. It is now read-only.

Batchnorm #13

Open
wants to merge 12 commits into main
4 changes: 3 additions & 1 deletion README.md
@@ -435,4 +435,6 @@ If you use the code in a publication, please cite our ICLR 2020 paper:

##### [14] [Wide Residual Networks.](https://arxiv.org/abs/1605.07146) *BMVC 2016.* Sergey Zagoruyko, Nikos Komodakis

##### [15] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
##### [15] [Tensor Programs I: Wide Feedforward or Recurrent Neural Networks of Any Architecture are Gaussian Processes.](https://arxiv.org/abs/1910.12478) *NeurIPS 2019.* Greg Yang.

##### [16] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
55 changes: 23 additions & 32 deletions examples/infinite_fcn.py
100644 → 100755
@@ -23,8 +23,7 @@
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
from examples import util
from jax import random


flags.DEFINE_integer('train_size', 1000,
@@ -37,43 +36,35 @@

FLAGS = flags.FLAGS

import pdb
from jax.experimental import callback
from functools import partial

def main(unused_argv):
# Build data pipelines.
print('Loading data.')
x_train, y_train, x_test, y_test = \
datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size)
key = random.PRNGKey(0)
key, split = random.split(key)
x_train = random.normal(key=key, shape=[2, 3, 4, 5])
x_train2 = random.normal(key=split, shape=[1, 3, 4, 5])

# Build the infinite network.
_, _, kernel_fn = stax.serial(
stax.Dense(1, 2., 0.05),
stax.Relu(),
stax.Dense(1, 2., 0.05)
init_fn, apply_fn, kernel_fn = stax.serial(
stax.Conv(256, (3, 3), padding='SAME'),
stax.BatchNormRelu((0, 1, 2)),
stax.GlobalAvgPool(),
stax.Dense(256, 2., 0.05)
)

# Optionally, compute the kernel in batches, in parallel.
kernel_fn = nt.batch(kernel_fn,
device_count=0,
batch_size=FLAGS.batch_size)

start = time.time()
# Bayesian and infinite-time gradient descent inference with infinite network.
fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
x_train,
y_train,
x_test,
get=('nngp', 'ntk'),
diag_reg=1e-3)
fx_test_nngp.block_until_ready()
fx_test_ntk.block_until_ready()

duration = time.time() - start
print('Kernel construction and inference done in %s seconds.' % duration)

# Print out accuracy and loss for infinite network predictions.
loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
# kernel_fn = callback.find_by_value(partial(kernel_fn, get='nngp'), np.nan)
kerobj = kernel_fn(x_train, x_train2, get='nngp')
theory_ker = kerobj
mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, 10000)
diff = theory_ker - mc_kernel_fn(x_train, x_train2, get='nngp')
print(diff)
# print(kerobj.cov1 - kerobj.nngp)
print(np.linalg.norm(diff) / np.linalg.norm(theory_ker))
# 0.0032839081
return


if __name__ == '__main__':
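For reference, the check performed by the updated example can be condensed into the minimal, self-contained sketch below, without the debugging leftovers (`pdb`, `jax.experimental.callback`). It assumes the `stax.BatchNormRelu` layer introduced by this PR (it is not part of the released `neural_tangents.stax` API), and the meaning of its `(0, 1, 2)` axis argument is taken from the diff above; everything else uses the public `neural_tangents` API.

```python
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
from jax import random

# Two small random inputs, matching the shapes used in the updated example.
key = random.PRNGKey(0)
key, split = random.split(key)
x1 = random.normal(key=key, shape=[2, 3, 4, 5])
x2 = random.normal(key=split, shape=[1, 3, 4, 5])

# Network from the updated example. `stax.BatchNormRelu((0, 1, 2))` is the
# layer added by this PR; the axes presumably select the batch and spatial
# dimensions to normalize over.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(256, (3, 3), padding='SAME'),
    stax.BatchNormRelu((0, 1, 2)),
    stax.GlobalAvgPool(),
    stax.Dense(256, 2., 0.05))

# Closed-form NNGP kernel vs. an empirical estimate over 10,000 random networks.
analytic = kernel_fn(x1, x2, get='nngp')
mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, 10000)
empirical = mc_kernel_fn(x1, x2, get='nngp')

# Relative error between the analytic and Monte Carlo kernels.
print(np.linalg.norm(analytic - empirical) / np.linalg.norm(analytic))
```

The relative error printed at the end corresponds to the `0.0032839081` figure recorded as a comment in the diff.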