
NCjAx

Neural Cellular Automata as a universal digital computing substrate accelerated with JAX.

Note

This takes lots of inspiration from the papers at the bottom of the readme, but also takes some liberties here and there. Assume nothing; expect everything.

Usage

The following is a basic introduction to the NCjAx API. It should be enough to play around and get a feel for the substrate. For more advanced usage, import directly from the core files.

Initialization

import jax

from NCjAx import NCA, Config

# set up a config
# the only required fields are num_input_nodes and num_output_nodes
num_input, num_output = 4, 2 # e.g. CartPole-v1 has 4 observations and 2 actions
conf = Config(
    num_input_nodes=num_input,
    num_output_nodes=num_output,
    k_default=35,
    grid_size=16,
    hidden_channels=4, # cell state hidden channels
    perception='learned3x3',
    hidden=30, # mlp hidden layer
)

# set up our nice little helper
nca = NCA(conf)

key = jax.random.PRNGKey(0)
key, init_key, pretrain_key = jax.random.split(key, 3)

# initialize parameters
params = nca.init_params(init_key)

# curriculum pretraining - helps to break out of local minima from the get-go
params, pretrain_accuracy = nca.pretrain(params, pretrain_key, steps=3000)

Processing

The process() helper is a shortcut for a full forward pass of the substrate: it sends the input, runs K ticks, and reads the output.

key, processing_key = jax.random.split(key)
# nca_state is the persistent substrate state; pass next_state back in next time
output, next_state = nca.process(nca_state, params, processing_key, inputs)

Note that although NCjAx is functionally pure, NCA is by nature stateful. Preserve the state and pass it back the next time around.
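
For example, a minimal control-loop sketch of that pattern might look like the following. Here env, fresh_state, and the environment's return shapes are hypothetical placeholders; only nca.process() comes from the snippet above.

# carry the substrate state across calls; `env` and `fresh_state` are
# hypothetical placeholders, only nca.process() is part of NCjAx
nca_state = fresh_state            # however you obtain an initial grid state
observation = env.reset()          # placeholder environment

for _ in range(500):
    key, step_key = jax.random.split(key)
    action, nca_state = nca.process(nca_state, params, step_key, observation)
    observation, done = env.step(action)  # placeholder step
    if done:
        break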

Roadmap

If you feel like contributing, be my guest. Here's what's missing to align properly with the papers (and hopefully achieve some sort of learning ability beyond input->output mapping):

  • Fundamental NCA substrate
  • I/O tooling (sending data to, and receiving data from, the substrate)
  • Trainable convolutional filters to replace/extend identity+laplacian
  • "Fire rate" (stochastic per-cell dropout) as a stability measure
  • Pretraining helper (simple identity mapping to help escape local minima)
  • Trainable gain
  • Simple API interface
  • Solving CartPole!

Results

This implementation of Neural Cellular Automata was used as the control policy in CartPole-v1. It was trained using Double DQN with PER and a slight overflow penalty. To avoid clock hacking and statue-like behavior, a fire rate of 0.5 was used, in line with the fantastic paper (3). The reward system was not manipulated - we relied entirely on the standard CartPole reward scheme. No pool sampling was used. We used a grid size of 12x12, a fixed K=35 per environment step, 6 hidden channels, and 20 learned filters. A compacted version of the training regimen can be found in examples/nca_dqn.py.
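
As a rough sketch, and using only the Config fields shown in the Usage section (the fire rate and learned filter count are presumably set in the training script, see examples/nca_dqn.py), that setup maps to something like:

# hedged sketch of the CartPole-v1 config described above; fields beyond those
# shown in the Usage section are omitted rather than guessed
conf = Config(
    num_input_nodes=4,    # CartPole-v1 observation size
    num_output_nodes=2,   # CartPole-v1 action size
    grid_size=12,
    k_default=35,
    hidden_channels=6,
)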

NCA balancing a pole on a cart

But why?

NCA has some very interesting properties; the idea of a universal function approximator capable of developing memory, self-regeneration, and complex emergent behavior is fantastic. It is also extremely computationally demanding and requires lots of stabilizing tricks to solve something as simple as CartPole, so it is clearly not a viable contender for any sort of practical use at the moment.

So, why? Because of the implications.

The papers

This implementation takes a bit from each of the following very nice papers:
