
jax leak problems #710

Open
@Laohusong

Description

I followed the official tutorial to use the BPTT trainer and ran into a JAX tracer leak.

import jax

# trainer and train_data are set up as in the tutorial
with jax.checking_leaks():
    trainer.fit(train_data, num_epoch=30)

This raised:

UnexpectedTracerError                     Traceback (most recent call last)
Cell In[28], line 4
      2 import jax
      3 with jax.checking_leaks():
----> 4     trainer.fit(train_data, num_epoch=30)

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\train\back_propagation.py:285, in BPTrainer.fit(self, train_data, test_data, num_epoch, num_report, reset_state, shared_args, fun_after_report, batch_size)
    282   self.reset_state()
    284 # training
--> 285 res = self.f_train(shared_args, x, y)
    287 # loss
    288 fit_epoch_metric['loss'].append(res[0])

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\math\object_transform\jit.py:213, in JITTransform.__call__(self, *args, **kwargs)
    210     return rets
    212 # call the transformed function
--> 213 return _jit_call_take_care_of_rngs(self._transform, self._dyn_vars, *args, **kwargs)

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\brainpy\_src\math\object_transform\jit.py:94, in _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs)
     91 def _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs):
     92   # call the transformed function
     93   rng_keys = stack.call_on_subset(_is_rng, _rng_split_key)
---> 94   changes, out = transform(stack.dict_data(), *args, **kwargs)
     95   for key, v in changes.items():
     96     stack[key]._value = v

    [... skipping hidden 3 frame]

File c:\Users\laohu\anaconda3\envs\brainpy_env3\lib\site-packages\jax\_src\core.py:924, in check_eval_args(args)
    922 for arg in args:
    923   if isinstance(arg, Tracer):
--> 924     raise escaped_tracer_error(arg)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[128,100] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
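
For anyone reading along, here is a minimal sketch of the kind of side effect the message above describes. It is plain JAX, not BrainPy's trainer code. The names `impure_step`, `pure_step`, and `leaked` are made up for illustration, and I am not claiming this is what happens inside `BPTrainer.fit`.

```python
import jax
import jax.numpy as jnp

leaked = []  # hypothetical global state, for illustration only


def impure_step(x):
    h = jnp.tanh(x)   # intermediate value
    leaked.append(h)  # side effect: h is a tracer while jit is tracing
    return h.sum()


def pure_step(x):
    h = jnp.tanh(x)
    return h.sum(), h  # return the intermediate explicitly instead


x = jnp.ones((4, 3))

# Tracing the impure function stores a tracer in the global list.
jax.jit(impure_step)(x)
stored_is_tracer = isinstance(leaked[0], jax.core.Tracer)

# Using the stored tracer after the transformation has finished is
# expected to raise UnexpectedTracerError.
raised = False
try:
    leaked[0] + 1.0
except jax.errors.UnexpectedTracerError:
    raised = True

# The pure version hands the intermediate back as a concrete array.
total, hidden = jax.jit(pure_step)(x)
```

The point of `pure_step` is that anything needed after the call is returned from the transformed function rather than stored elsewhere, which is what the error message asks for.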

I used python=3.11 and brainpy=2.60, GPU version (I also tried the CPU version and got the same problem), installed by following the latest tutorial.

I tried this on both my Windows and Mac machines and hit the same bug.

Will it be ok to downgrade to brainpy=2.4?

Labels: bug (Something isn't working)
