
The model is traced for each steps_per_execution step (Jax backend) #20411

@nicolaspi

Description


With the JAX backend, the model is traced once for every step in steps_per_execution. As a result, JIT compilation time and memory usage grow in proportion to steps_per_execution.
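A minimal sketch of how this kind of scaling arises in JAX generally. A Python for-loop inside a jitted function is unrolled at trace time, so the loop body is traced once per iteration and the compiled program grows with the loop length. jax.lax.scan traces its body only once. Whether the Keras JAX trainer builds its multi-step function with a plain Python loop is an assumption here, not something this issue confirms. train_step, multi_step_unrolled and multi_step_scanned are hypothetical stand-ins. The global counter records traces, not executions.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def train_step(state, x):
    # Hypothetical stand-in for a real train step (assumption, not Keras code).
    # This Python side effect runs only while JAX traces the function.
    global trace_count
    trace_count += 1
    return state + jnp.sum(x)

def multi_step_unrolled(state, batches, steps_per_execution):
    # Python loop: unrolled at trace time, so train_step is traced
    # steps_per_execution times and the compiled program grows with it.
    for i in range(steps_per_execution):
        state = train_step(state, batches[i])
    return state

def multi_step_scanned(state, batches):
    # lax.scan traces the body once and compiles a single loop.
    def body(carry, x):
        return train_step(carry, x), None
    state, _ = jax.lax.scan(body, state, batches)
    return state

steps = 10
batches = jnp.ones((steps, 4))  # per-step batches stacked on axis 0

trace_count = 0
jax.jit(multi_step_unrolled, static_argnums=2)(jnp.float32(0.0), batches, steps)
unrolled_traces = trace_count
print("unrolled:", unrolled_traces)  # 10

trace_count = 0
jax.jit(multi_step_scanned)(jnp.float32(0.0), batches)
scanned_traces = trace_count
print("scanned:", scanned_traces)  # 1
```

A scan-based (or lax.fori_loop-based) multi-step function would keep the compiled program size independent of steps_per_execution. The trade-off is that the per-step batches must be stacked along a leading axis.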

This gist demonstrates the issue on MobileNetV2:

steps_per_execution = 1, model retracing count: 1, memory usage: 622MB, compilation overhead 4.11 seconds.
steps_per_execution = 5, model retracing count: 5, memory usage: 1610MB, compilation overhead 8.10 seconds.
steps_per_execution = 10, model retracing count: 10, memory usage: 3185MB, compilation overhead 14.54 seconds.
steps_per_execution = 20, model retracing count: 20, memory usage: 6337MB, compilation overhead 30.39 seconds.
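The per-step increments can be computed directly from the four measurements above. This is only arithmetic on the reported figures, and incremental_costs is a hypothetical helper. Each extra step adds roughly 250 to 315 MB of memory and about 1 to 1.6 seconds of compilation overhead, which is consistent with roughly linear growth.

```python
# Measurements reported in this issue (MobileNetV2, JAX backend):
# steps_per_execution -> (memory usage in MB, compilation overhead in seconds)
measurements = {
    1: (622, 4.11),
    5: (1610, 8.10),
    10: (3185, 14.54),
    20: (6337, 30.39),
}

def incremental_costs(data):
    """Hypothetical helper: average cost of each extra step between
    consecutive measurements, as (from, to, MB/step, s/step)."""
    steps = sorted(data)
    out = []
    for a, b in zip(steps, steps[1:]):
        extra = b - a
        mem_per_step = (data[b][0] - data[a][0]) / extra
        sec_per_step = (data[b][1] - data[a][1]) / extra
        out.append((a, b, mem_per_step, sec_per_step))
    return out

for a, b, mem, sec in incremental_costs(measurements):
    print(f"{a:>2} -> {b:>2} steps: +{mem:.0f} MB/step, +{sec:.2f} s/step")
```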
