- Rethink Why Flax NNX?
  - Inspection: doesn't work for us.
  - Running computation: worth rethinking whether we should support it at all.
  - State handling: doesn't work for us. The state is considered explicitly: a dedicated new class was created before `02_mnist/v5.py`, but it seems a dict-like object is already enough (a dict-based sketch follows below).
  - Model surgery: unclear what the real benefit is here. Keeping params & state in sync might be an issue, but the operation itself should be easy since model/param/state are mirrored trees.
  - Transforms: needs revisiting in the future.
- It seems Lux.jl is mainly inspired by the Linen style in Flax, while the NNX style is closer to PyTorch. The existing implementation in this repo is closer to Lux and axlearn.
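To make the "a dict-like object is already enough" point concrete, here is a minimal sketch, not this repo's actual API: params live in a plain dict that mirrors the model structure, mutable state would be threaded the same way, and "model surgery" is just ordinary dict/pytree manipulation. The `init_mlp` / `mlp_apply` names and the dict layout are made up for illustration.

```python
# Hypothetical sketch: params as a plain dict that mirrors the model tree.
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Build a params dict whose nesting mirrors the (implicit) model tree."""
    params = {}
    for i, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        k_w, key = jax.random.split(key)
        params[f"layer_{i}"] = {
            "w": jax.random.normal(k_w, (d_in, d_out)) / jnp.sqrt(d_in),
            "b": jnp.zeros((d_out,)),
        }
    return params


def mlp_apply(params, x):
    """Pure function of (params, x); mutable state would be threaded the same way."""
    n_layers = len(params)
    for i in range(n_layers):
        layer = params[f"layer_{i}"]
        x = x @ layer["w"] + layer["b"]
        if i < n_layers - 1:
            x = jax.nn.relu(x)
    return x


params = init_mlp(jax.random.PRNGKey(0), (4, 8, 2))

# "Model surgery" on the mirrored tree needs no dedicated API:
params["layer_1"]["w"] = jnp.zeros_like(params["layer_1"]["w"])            # swap one leaf
params = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)  # whole-tree op

y = jax.jit(mlp_apply)(params, jnp.ones((3, 4)))
print(y.shape)  # (3, 2)
```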
- DiLoCo: Distributed Low-Communication Training of Language Models
- Photon: Federated LLM Pre-Training
- Convergence of Distributed Adaptive Optimization with Local Updates
- DES-LOC: Desynced Low Communication Adaptive Optimizers for Training Foundation Models
- ArcticTraining
  - Looks like our config system is similar to this.
- penzai v2 background
  - "Parameters and state variables becoming mutable, shareable variable objects"
    - This seems aligned with the current design. Currently a general `dict` is used. Maybe I should also introduce a dedicated class for params and states (see the sketch after these notes).
  - "all variable objects must have a unique label, which can either be specified manually or generated automatically."
    - Hmm, I find it difficult to search for a specific Param/State. MAYBE a unique label would help here?
  - "Eager parameter initialization"
    - In the current design, params & states are separated from models, so it is closer to lazy initialization?
  - "The built-in Transformer implementation also supports loading Llama, Mistral, and GPT-NeoX / Pythia models."
    - TODO: this is a good feature to have.
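Following the two notes above (a dedicated class instead of a bare `dict`, plus unique labels so a specific Param/State can be found), here is a hypothetical sketch of what that could look like, assuming JAX pytree registration; `Param`, `Param.create`, and `find_by_label` are invented names, not penzai's or this repo's API.

```python
# Hypothetical Param class: value + unique label, registered as a JAX pytree.
import dataclasses
import itertools

import jax
import jax.numpy as jnp

_auto_label = itertools.count()


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Param:
    value: jax.Array
    label: str  # unique label; specified manually or auto-generated

    @classmethod
    def create(cls, value, label=None):
        # Eager-style creation: the value exists as soon as the Param does.
        return cls(value=value, label=label if label is not None else f"param_{next(_auto_label)}")

    # Only `value` is a traced leaf; the label rides along as static aux data.
    def tree_flatten(self):
        return (self.value,), self.label

    @classmethod
    def tree_unflatten(cls, label, children):
        return cls(value=children[0], label=label)


def find_by_label(tree, label):
    """Locate a Param by its unique label instead of by tree path."""
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: isinstance(x, Param))
    return [p for p in leaves if isinstance(p, Param) and p.label == label]


params = {
    "encoder": {"w": Param.create(jnp.ones((4, 4)), label="encoder/w")},
    "decoder": {"w": Param.create(jnp.zeros((4, 2)))},  # auto-generated label
}

print(find_by_label(params, "encoder/w")[0].value.shape)  # (4, 4)

# Still behaves as an ordinary pytree; labels survive tree_map / jit.
scaled = jax.tree_util.tree_map(lambda v: v * 2.0, params)
```

Because the label is static aux data, transforms only ever see the value leaves, and lookups no longer depend on the tree path; whether this actually beats a plain dict for this repo is still an open question.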
- Penzai + Treescope: A Toolkit for Interpreting, Visualizing, and Editing Models As Data
- ml_dtypes
- Controllable Video Generation: A Survey
- Understand Jax Array and shard_map (a minimal example is sketched after this list).
- HighPerfLLMs2024
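For the jax.Array / shard_map item above, a minimal, self-contained example of the API in question; the "data" axis name and the toy global-mean reduction are arbitrary choices for illustration (newer JAX releases also expose the same function as `jax.shard_map`).

```python
# Minimal jax.Array + shard_map example: a global array sharded across devices,
# with a per-shard function that uses a collective to produce a global result.
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# A global jax.Array sharded along its leading axis across the "data" mesh axis.
n_dev = jax.device_count()
x = jax.device_put(
    jnp.arange(n_dev * 2 * 3, dtype=jnp.float32).reshape(n_dev * 2, 3),
    NamedSharding(mesh, P("data", None)),
)


@functools.partial(shard_map, mesh=mesh, in_specs=P("data", None), out_specs=P())
def global_mean(block):
    # Each device sees only its local shard; psum combines the partial results.
    total = jax.lax.psum(jnp.sum(block), axis_name="data")
    count = jax.lax.psum(block.shape[0] * block.shape[1], axis_name="data")
    return total / count


print(global_mean(x), jnp.mean(x))  # both print the same value
```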