
Fast JAX compilation when network architecture has per-dataset components. #2316

Open
@sussillo

Description

I have a reasonably common neural network architectural motif: the input and output matrices are specific to each dataset, but the internal core matrices are shared by all datasets. I'm trying to get this to work in jit'd JAX, but I keep getting errors when I execute the jit'd function. A toy example is shown below. As I understand it, the best I can currently do is accept that JAX will compile the network function once per dataset id, which isn't great when you have hundreds of datasets. Ideas?

Thank you!

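from jax import jit
import jax.numpy as jnp

ndatasets = 2
# Per-dataset input/output matrices, keyed by dataset id ('0', '1', ...).
params = {str(k) : {'Win' : jnp.ones((2,2)), 'Wout' : jnp.ones((2,2))} for k in range(ndatasets)}
# Core matrix shared by all datasets.
params['common'] = {'W' : jnp.ones((2,2))}

print(params)


def network(params, id):
  # Computes Wout[id] @ (W_common @ (Win[id] @ x)), with x a vector of ones.
  return jnp.dot(params[id]['Wout'], (jnp.dot(params['common']['W'], jnp.dot(params[id]['Win'], jnp.ones((2,))))))

print(network(params, '1'))


network_jit = jit(network)
# Raises TypeError: jit cannot trace the string argument id.
print(network_jit(params, '1'))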

with output

{'0': {'Win': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32), 'Wout': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)}, '1': {'Win': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32), 'Wout': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)}, 'common': {'W': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)}}
[8. 8.]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-71-88b774892abd> in <module>()
     16 
     17 network_jit = jit(network)
---> 18 print(network_jit(params, '1'))

1 frames
google3/third_party/py/jax/api.py in _check_args(args)
   1411     if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
   1412       raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-> 1413                       .format(arg, type(arg)))
   1414 
   1415 def _valid_jaxtype(arg):

TypeError: Argument '1' of type <class 'str'> is not a valid JAX type
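One possible workaround, sketched below under stated assumptions (not an official answer from the JAX maintainers). The TypeError arises because jit traces its arguments as arrays (or pytrees of arrays), and a Python string is neither. Marking `id` as static with `static_argnums` makes the call work, but jit then traces and compiles a separate version for every distinct id, which is the per-dataset compilation concern raised above. An alternative is to stack the per-dataset `Win`/`Wout` matrices along a leading dataset axis and select one with an integer index, so a single compiled function serves every dataset. The names `network_static`, `network_stacked`, `stacked`, and `trace_counts` are hypothetical; the counter is a plain Python side effect that only runs while jit is tracing, used here to make the number of compilations visible.

```python
from functools import partial

import jax
import jax.numpy as jnp

ndatasets = 2
x = jnp.ones((2,))
# Hypothetical counters: incremented only when jit (re)traces a function.
trace_counts = {'static': 0, 'stacked': 0}

# Option 1: keep the dict-of-dicts layout and mark the string id as static.
params = {str(k): {'Win': jnp.ones((2, 2)), 'Wout': jnp.ones((2, 2))}
          for k in range(ndatasets)}
params['common'] = {'W': jnp.ones((2, 2))}

@partial(jax.jit, static_argnums=1)
def network_static(params, id):
  trace_counts['static'] += 1  # runs once per distinct id value
  h = params[id]['Win'] @ x
  h = params['common']['W'] @ h
  return params[id]['Wout'] @ h

# Option 2: stack per-dataset weights along a leading "dataset" axis.
stacked = {
    'Win': jnp.ones((ndatasets, 2, 2)),   # one (2, 2) input matrix per dataset
    'Wout': jnp.ones((ndatasets, 2, 2)),  # one (2, 2) output matrix per dataset
    'W': jnp.ones((2, 2)),                # shared core matrix
}

@jax.jit
def network_stacked(params, idx):
  trace_counts['stacked'] += 1  # runs once, since idx is a traced integer
  Win = params['Win'][idx]      # dynamic indexing with a traced scalar
  Wout = params['Wout'][idx]
  return Wout @ (params['W'] @ (Win @ x))

for k in range(ndatasets):
  print(network_static(params, str(k)), network_stacked(stacked, k))

print(trace_counts)
```

Both versions compute the same result as the un-jitted `network`, but after looping over the two datasets the counter should show two traces for the static version and one for the stacked version. Stacking assumes every dataset's `Win` has the same shape (and likewise `Wout`); if input/output sizes differ across datasets, some padding or grouping of datasets by shape would likely be needed. Gradients with respect to the stacked arrays are zero for the slices not selected by `idx`.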
