Fast JAX compilation when network architecture has per-dataset components. #2316
I have a fairly common neural-network architectural motif: the input and output matrices are specific to each dataset, while the internal core matrices are shared by all datasets. I'm trying to get this to work in jit'd JAX, but I keep getting errors when I execute the jit'd function. A toy example is shown below. As I understand it, the best I can currently do is accept that JAX will compile the network function separately for each dataset id, which isn't great when you have hundreds of datasets. Ideas?
Thank you!
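A minimal sketch of the per-id compilation described above, assuming a recent JAX where `jax.jit` accepts `static_argnums`. Marking the id argument as static makes JAX treat the string as a compile-time constant, so `jit` traces and compiles the function once for each distinct id and reuses the cached version on repeat calls. The `trace_count` counter and the `dataset_id` name are hypothetical, added only to make the retracing visible.

```python
import jax
import jax.numpy as jnp

ndatasets = 3
# Per-dataset input/output matrices keyed by a string id, plus a shared core matrix.
params = {str(k): {"Win": jnp.ones((2, 2)), "Wout": jnp.ones((2, 2))} for k in range(ndatasets)}
params["common"] = {"W": jnp.ones((2, 2))}

trace_count = 0  # hypothetical counter: incremented only while JAX traces the function

def network(params, dataset_id):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    x = jnp.ones((2,))
    h = params[dataset_id]["Win"] @ x
    h = params["common"]["W"] @ h
    return params[dataset_id]["Wout"] @ h

# Argument 1 (the string id) is static: it is hashed and baked into the compiled program.
network_jit = jax.jit(network, static_argnums=1)

out0 = network_jit(params, "0")        # traces and compiles for id "0"
out1 = network_jit(params, "1")        # traces and compiles again for id "1"
out0_again = network_jit(params, "0")  # cache hit: no retrace
print(out1)         # [8. 8.]
print(trace_count)  # 2
```

This avoids the error but keeps one compiled program per dataset, so compile time grows with the number of datasets.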
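```python
import jax.numpy as jnp
from jax import jit

ndatasets = 2
# Per-dataset input/output matrices, keyed by a string dataset id.
params = {str(k): {'Win': jnp.ones((2, 2)), 'Wout': jnp.ones((2, 2))} for k in range(ndatasets)}
# Core matrix shared by all datasets.
params['common'] = {'W': jnp.ones((2, 2))}
print(params)

def network(params, id):
    # Wout[id] @ W_common @ Win[id] @ x, with x = ones(2)
    return jnp.dot(params[id]['Wout'], jnp.dot(params['common']['W'], jnp.dot(params[id]['Win'], jnp.ones((2,)))))

print(network(params, '1'))      # works without jit
network_jit = jit(network)
print(network_jit(params, '1'))  # raises TypeError: the string id is not a valid JAX type
```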
with output:

```
{'0': {'Win': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32), 'Wout': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)}, '1': {'Win': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32), 'Wout': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)}, 'common': {'W': DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)}}
[8. 8.]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-71-88b774892abd> in <module>()
     16 
     17 network_jit = jit(network)
---> 18 print(network_jit(params, '1'))

google3/third_party/py/jax/api.py in _check_args(args)
   1411     if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
   1412       raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-> 1413                       .format(arg, type(arg)))
   1414 
   1415 def _valid_jaxtype(arg):

TypeError: Argument '1' of type <class 'str'> is not a valid JAX type
```
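The error arises because every non-static argument to a jit'd function must be an array-like JAX type, and a Python string is not. A minimal sketch of one way to get a single compilation for all datasets, assuming every dataset's `Win` and `Wout` have the same shape: stack the per-dataset matrices along a leading axis and pass the dataset index as an integer array, so selecting a dataset becomes a traced indexing operation instead of a Python dict lookup. The stacked layout, the `idx` argument, and the `trace_count` counter are hypothetical choices for illustration, not an official JAX pattern.

```python
import jax
import jax.numpy as jnp

ndatasets = 100
# Per-dataset matrices stacked along a leading axis: shape (ndatasets, 2, 2).
# Assumes all datasets share the same matrix shapes.
params = {
    "Win": jnp.ones((ndatasets, 2, 2)),
    "Wout": jnp.ones((ndatasets, 2, 2)),
    "common": {"W": jnp.ones((2, 2))},
}

trace_count = 0  # hypothetical counter: incremented only while JAX traces the function

@jax.jit
def network(params, idx):
    global trace_count
    trace_count += 1  # runs at trace time only
    # idx is a traced int32 scalar; dynamic indexing selects this dataset's slice.
    # Note: out-of-range indices are clamped by JAX rather than raising an error.
    Win = params["Win"][idx]
    Wout = params["Wout"][idx]
    x = jnp.ones((2,))
    return Wout @ (params["common"]["W"] @ (Win @ x))

# Every call has the same argument shapes and dtypes, so one compiled program serves all ids.
outs = [network(params, jnp.int32(i)) for i in range(ndatasets)]
print(outs[1])      # [8. 8.]
print(trace_count)  # 1
```

Because `idx` is an ordinary array argument here, the same function could also be wrapped in `jax.vmap` to evaluate several datasets in one call.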