diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b346ef8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,151 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# local
+jobs/
+local/
+.vscode/
+
+data/
+*.model
+*.npy
+*.jsonl
+*.pkl
+*.json
+__pycache__/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..cfac740
--- /dev/null
+++ b/README.md
@@ -0,0 +1,140 @@
+# Large World Model (LWM)
+
+[[Project]](https://largeworldmodel.github.io/)
+[[Paper]](https://arxiv.org/abs/2402.08268)
+[[Models]](https://huggingface.co/LargeWorldModel)
+
+**Large World Model (LWM)** is a general-purpose large-context multimodal autoregressive model. It is trained on a large dataset of diverse long videos and books using RingAttention, and can perform language, image, and video understanding and generation.
+
+
+## Approach
+
+
+

+
+
+Current language models fall short in understanding aspects of the world not easily described in words, and struggle with complex, long-form tasks. Video sequences offer valuable temporal information absent in language and static images, making them attractive for joint modeling with language. Such models could develop a understanding of both human textual knowledge and the physical world, enabling broader AI capabilities for assisting humans. However, learning from millions of tokens of video and language sequences poses challenges due to memory constraints, computational complexity, and limited datasets. We address these challenges with RingAttention, a technique for scaling context size arbitrarily without approximations or overheads, enabling scalably training on long sequences. We curate a large dataset of diverse videos and books, and gradually increase context size from 4K to 1M tokens during training to manage computational costs. This paper makes the following contributions: (a) Largest context size neural network: We train one of the largest context size transformers on long video and language sequences, setting new benchmarks in difficult retrieval tasks and long video understanding. (b) Solutions for overcoming vision-language training challenges, including using masked sequence packing for mixing different sequence lengths, loss weighting to balance language and vision, and model-generated QA dataset for long sequence chat. (c) A highly-optimized implementation with RingAttention, masked sequence packing, and other key features for training on millions-length multimodal sequences. (d) Fully open-sourced 7B parameter models capable of processing over 1M vision and language tokens.
+This work paves the way for training on massive datasets of long video and language to develop understanding of both human knowledge and the multimodal world, and broader capabilities.
+
+## LWM Capabilities
+
+
+

+
+ LWM can retrieval facts across 1M context with high accuracy.
+
+
+
+
+
+
+

+
+ LWM can answer questions over 1 hour YouTube video.
+
+
+
+
+
+
+

+
+ LWM can chat with images.
+
+
+
+
+
+
+

+
+ LWM can generate videos and images from text.
+
+
+
+
+## Setup
+Install the requirements with:
+```
+pip install -r requirements.txt
+```
+or set up TPU VM with:
+```
+sh tpu_requirements.sh
+```
+
+
+## Available models
+
+There are language-only and video-language versions, offering context sizes from 32K, to 128K, 256K and 1M tokens. The vision-langauge models are available only in Jax, and the language-only models are available in both PyTorch and Jax. Below are the names of the available models and their corresponding context sizes and capabilities:
+
+| Model Name | Context Size | Language or Vision-Language | Chat or Base | URL |
+|--------------------|--------------|-----------------------------|--------------|----------------------------------------------------------------------------------------------------------------------------------------------|
+| LWM-Text-Chat-128K | 128K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K-Jax)] |
+| LWM-Text-Chat-256K | 256K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K-Jax)] |
+| LWM-Text-Chat-512K | 512K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K-Jax)] |
+| LWM-Text-Chat-1M | 1M | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M-Jax)] |
+| LWM-Text-128K | 128K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-128K-Jax)] |
+| LWM-Text-256K | 256K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-256K-Jax)] |
+| LWM-Text-512K | 512K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-512K-Jax)] |
+| LWM-Text-1M | 1M | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-1M-Jax)] |
+| LWM-Chat-32K | 32K | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-32K-Jax)] |
+| LWM-Chat-128K | 128K | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-128K-Jax)] |
+| LWM-Chat-1M | 1M | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-1M-Jax)] |
+
+
+## Code structure
+Use `scan_query_chunk_size` and `scan_key_chunk_size` to control the block size in blockwise compute of the self-attention. Use `scan_mlp_chunk_size` to control the block size in blockwise compute of the feedforward network. Use `scan_attention=True` and `scan_mlp=True` to enable/disable blockwise compute in the self-attention and feed-forward network. Use `remat_attention` and `remat_mlp` to control the rematerialization policy with `nothing_saveable` recommended.
+
+You can use `mesh_dim=dp, fsdp, tp, sp` to control the degree of parallelism and RingAttention. It is a string of 4 integers separated by commas, representing the number of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism.
+For example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism. `mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.
+
+
+## Command-line usage
+In this section, we provide instructions on how to run each of the provided scripts. For each script, you may need to fill in your own paths and values in the variables described in the beginning of each script.
+
+To run each of the following scripts, use `bash .sh`:
+- Language model training: `bash scripts/run_train_text.sh`
+- Vision-Language model training: `bash scripts/run_train_vision_text.sh`
+- Single Needle Evals (Language Model): `bash scripts/run_eval_needle.sh`
+- Multi Needle Evals (Language Model): `bash scripts/run_eval_needle_multi.sh`
+- Sampling images (Vision-Language Model): `bash scripts/run_sample_image.sh`
+- Sampling videos (Vision-LanguageModel): `bash scripts/run_sample_video.sh`
+- Image / Video understanding (Vision-Language Model): `bash scripts/run_vision_chat.sh`
+
+
+## If you have issues
+
+This is based on the [codebase](https://github.com/lhao499/ring-attention) of BPT and RingAttention, with the necessary features for vision-language training. The training and inference have been tested on both TPUv3 and TPUv4.
+
+If you encounter bugs, please open a GitHub issue!
+
+
+## Citation
+
+If you use this codebase, or otherwise found our work valuable, please cite:
+
+```
+@article{liu2023world,
+ title={World Model on Million-Length Video and Language with RingAttention},
+ author={Liu, Hao and Yan, Wilson and Zaharia, Matei and Abbeel, Pieter},
+ journal={arXiv preprint},
+ year={2024},
+}
+@article{liu2023ring,
+ title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
+ author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},
+ journal={International Conference on Learning Representations},
+ year={2024}
+}
+@article{liu2023blockwise,
+ title={Blockwise Parallel Transformer for Large Context Models},
+ author={Liu, Hao and Abbeel, Pieter},
+ journal={Advances in neural information processing systems},
+ year={2023}
+}
+```
+
+## License
+
+LWM's code and model weights are released under the Apache 2.0 License. See [LICENSE](https://github.com/LargeWorldModel/lwm/blob/main/LICENSE) for further details.
diff --git a/imgs/data.png b/imgs/data.png
new file mode 100644
index 0000000..4a8a608
Binary files /dev/null and b/imgs/data.png differ
diff --git a/imgs/image_chat.png b/imgs/image_chat.png
new file mode 100644
index 0000000..a699e5d
Binary files /dev/null and b/imgs/image_chat.png differ
diff --git a/imgs/image_video_gen.png b/imgs/image_video_gen.png
new file mode 100644
index 0000000..5c7cc9d
Binary files /dev/null and b/imgs/image_video_gen.png differ
diff --git a/imgs/long_video_chat_main.png b/imgs/long_video_chat_main.png
new file mode 100644
index 0000000..f68b8be
Binary files /dev/null and b/imgs/long_video_chat_main.png differ
diff --git a/imgs/single_needle_1M.png b/imgs/single_needle_1M.png
new file mode 100644
index 0000000..4647946
Binary files /dev/null and b/imgs/single_needle_1M.png differ
diff --git a/lwm/__init__.py b/lwm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lwm/data.py b/lwm/data.py
new file mode 100644
index 0000000..a6cb177
--- /dev/null
+++ b/lwm/data.py
@@ -0,0 +1,842 @@
+import time
+import random
+from functools import partial
+import json
+from multiprocessing import Pool
+
+from tux import open_file
+from ml_collections import ConfigDict
+import numpy as np
+import jax
+from jax.experimental.multihost_utils import host_local_array_to_global_array
+from jax.sharding import PartitionSpec as PS
+from datasets import load_dataset
+
+
+class DatasetFactory(object):
+ """ Datset builder class. """
+
+ @staticmethod
+ def get_default_config(updates=None):
+ config = ConfigDict()
+ config.type = 'huggingface'
+ config.text_processor = TextProcessor.get_default_config()
+ config.huggingface_dataset = HuggingfaceDataset.get_default_config()
+ config.json_dataset = JsonDataset.get_default_config()
+
+ config.vision_text_processor = VisionTextProcessor.get_default_config()
+ config.json_vision_dataset = JsonVisionDataset.get_default_config()
+
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+ return config
+
+ @classmethod
+ def load_dataset(cls, config, tokenizer, **kwargs):
+ config = cls.get_default_config(config)
+ if config.type == 'huggingface':
+ text_processor = TextProcessor(config.text_processor, tokenizer)
+ return HuggingfaceDataset(
+ config.huggingface_dataset, tokenizer, text_processor, **kwargs
+ )
+ elif config.type == 'json':
+ text_processor = TextProcessor(config.text_processor, tokenizer)
+ return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
+ elif config.type == 'json_vision':
+ vision_text_processor = VisionTextProcessor(config.vision_text_processor, tokenizer)
+ return JsonVisionDataset(config.json_vision_dataset, tokenizer, vision_text_processor, **kwargs)
+ else:
+ raise ValueError(f'Unknown dataset type: {config.type}')
+
+ def __init__(self):
+ raise ValueError('DatasetFactory is a static class and should not be instantiated.')
+
+
+class TextProcessor(object):
+ """ Example processor that converts a dictionary of texts into tokens. """
+ @staticmethod
+ def get_default_config(updates=None):
+ config = ConfigDict()
+ config.fields_from_example = ''
+ config.fields = ''
+ config.subfield_separator = ' '
+ config.add_bos_token = True
+ config.add_eos_token = True
+ config.prepend_text = ''
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+ return config
+
+ def __init__(self, config, tokenizer):
+ self.config = self.get_default_config(config)
+ assert self.config.fields != '' or self.config.fields_from_example != '', (
+ 'Either fields or fields_from_example must be specified.'
+ )
+ self.tokenizer = tokenizer
+
+ def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):
+ if has_aux:
+ example, *aux = example
+ else:
+ aux = tuple()
+ token_buffer = []
+ loss_mask_buffer = []
+
+ if add_bos_token and self.config.add_bos_token:
+ token_buffer.append(self.tokenizer.bos_token_id)
+ loss_mask_buffer.append(0.0)
+
+ if self.config.fields_from_example != '':
+ fields = example[self.config.fields_from_example].split(',')
+ else:
+ fields = self.config.fields.split(',')
+
+ for i, field in enumerate(fields):
+ if field.startswith('[') and field.endswith(']'):
+ # No loss for this field.
+ field = field[1:-1]
+ mask = 0.0
+ else:
+ mask = 1.0
+
+ if field == '<|bos|>':
+ token_buffer.append(self.tokenizer.bos_token_id)
+ loss_mask_buffer.append(mask)
+ elif field == '<|eos|>':
+ token_buffer.append(self.tokenizer.eos_token_id)
+ loss_mask_buffer.append(mask)
+ else:
+ subfields = field.split('+')
+ text = self.config.subfield_separator.join(
+ [example[subfield] for subfield in subfields]
+ )
+ if i == 0:
+ text = self.config.prepend_text + text
+ tokens = self.tokenizer.encode(text)
+ token_buffer.extend(tokens)
+ loss_mask_buffer.extend([mask for _ in range(len(tokens))])
+
+ if add_eos_token and self.config.add_eos_token:
+ token_buffer.append(self.tokenizer.eos_token_id)
+ loss_mask_buffer.append(1.0)
+
+ return token_buffer, loss_mask_buffer, *aux
+
+
+class VisionTextProcessor(object):
+ @staticmethod
+ def get_default_config(updates=None):
+ config = ConfigDict()
+ config.fields_from_example = ''
+ config.subfield_separator = ' '
+ config.add_bos_token = True
+ config.add_eos_token = True
+ config.prepend_text = ''
+ config.fields_index = -1
+ config.eof_token = 8192 # denotes end of each frame for video generation
+ config.eov_token = 8193 # denotes end of vision generation
+ config.n_tokens_per_frame = 256 # 16 x 16 VQ codes
+ config.max_n_frames = -1
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+ return config
+
+ def __init__(self, config, tokenizer):
+ self.config = self.get_default_config(config)
+ assert self.config.fields_from_example != '', (
+ 'fields_from_example must be specified.'
+ )
+ self.tokenizer = tokenizer
+ self.vision_start = tokenizer.encode('')
+ self.vision_end = tokenizer.encode('')
+
+ def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):
+ if has_aux:
+ example, *aux = example
+ else:
+ aux = tuple()
+ rand_state = random.Random(aux[-1]) # makes augmentations deterministic by line number
+ token_buffer = []
+ loss_mask_buffer = []
+ vision_mask = []
+
+ fields = example[self.config.fields_from_example]
+ if isinstance(fields, (tuple, list)):
+ if self.config.fields_index >= 0:
+ fields = fields[self.config.fields_index]
+ else:
+ # seed based on line number
+ fields = rand_state.choice(fields)
+ fields = fields.split(',')
+
+ if add_bos_token and self.config.add_bos_token:
+ token_buffer.append(self.tokenizer.bos_token_id)
+ loss_mask_buffer.append(0.0)
+ vision_mask.append(False)
+
+ for i, field in enumerate(fields):
+ if field.startswith('[') and field.endswith(']'):
+ # No loss for this field.
+ field = field[1:-1]
+ mask = 0.0
+ else:
+ mask = 1.0
+
+ if field == '<|bos|>':
+ token_buffer.append(self.tokenizer.bos_token_id)
+ loss_mask_buffer.append(mask)
+ vision_mask.append(False)
+ elif field == '<|eos|>':
+ token_buffer.append(self.tokenizer.eos_token_id)
+ loss_mask_buffer.append(mask)
+ vision_mask.append(False)
+ elif 'vision' in field:
+ vision_tokens = example[field]
+ n_frames = int(len(vision_tokens) / self.config.n_tokens_per_frame)
+ if self.config.max_n_frames > 0 and n_frames > self.config.max_n_frames: # uniformly select
+ idxs = np.linspace(0, n_frames - 1, self.config.max_n_frames).astype(int)
+ new_vision_tokens = []
+ for idx in idxs:
+ new_vision_tokens.extend(vision_tokens[idx * self.config.n_tokens_per_frame:(idx + 1) * self.config.n_tokens_per_frame])
+ vision_tokens = new_vision_tokens
+ n_frames = self.config.max_n_frames
+ assert int(len(vision_tokens) / self.config.n_tokens_per_frame) == n_frames, (int(len(vision_tokens) / self.config.n_tokens_per_frame), n_frames)
+
+ assert n_frames > 0, len(vision_tokens)
+ tokens = list(self.vision_start)
+ for j in range(n_frames):
+ tokens.extend(vision_tokens[j*self.config.n_tokens_per_frame:(j+1)*self.config.n_tokens_per_frame])
+ if j == n_frames - 1: # last frame
+ tokens.append(self.config.eov_token)
+ else:
+ tokens.append(self.config.eof_token)
+ tokens.extend(self.vision_end)
+
+ token_buffer.extend(tokens)
+ loss_mask_buffer.extend([mask for _ in range(len(tokens))])
+ vision_mask.extend([False] * len(self.vision_start))
+ vision_mask.extend([True] * (self.config.n_tokens_per_frame * n_frames + n_frames)) # include extra eof/eov token at the end of each frame
+ vision_mask.extend([False] * len(self.vision_end))
+ else:
+ subfields = field.split('+')
+ text = self.config.subfield_separator.join(
+ [example[subfield] for subfield in subfields]
+ )
+ if i == 0:
+ text = self.config.prepend_text + text
+ tokens = self.tokenizer.encode(text)
+ token_buffer.extend(tokens)
+ loss_mask_buffer.extend([mask for _ in range(len(tokens))])
+ vision_mask.extend([False] * len(tokens))
+
+ if add_eos_token and self.config.add_eos_token:
+ token_buffer.append(self.tokenizer.eos_token_id)
+ loss_mask_buffer.append(1.0)
+ vision_mask.append(False)
+
+ assert len(token_buffer) == len(loss_mask_buffer) == len(vision_mask), (len(token_buffer), len(loss_mask_buffer), len(vision_mask))
+ keep = True
+ return token_buffer, loss_mask_buffer, vision_mask, keep, *aux
+
+
+class HuggingfaceDataset(object):
+ """ Huggingface dataset, where the dataset is loaded using the huggingface
+ datasets.load_dataset() function.
+ """
+
+ @staticmethod
+ def get_default_config(updates=None):
+ config = ConfigDict()
+ config.path = 'c4'
+ config.name = 'en'
+ config.split = 'train'
+ config.streaming = False
+ config.seq_length = 1024
+ config.batch_size = 8
+ config.always_start_with_bos = False
+
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+ return config
+
+ def __init__(self, config, tokenizer, text_processor):
+ self.config = self.get_default_config(config)
+ name = self.config.name if self.config.name != '' else None
+ split = self.config.split if self.config.split != '' else None
+ self._tokenizer = tokenizer
+ self._text_processor = text_processor
+ self._dataset = load_dataset(
+ self.config.path, name, split=split, streaming=self.config.streaming
+ )
+
+ def __iter__(self):
+ chunk_size = self.config.batch_size * self.config.seq_length
+ total_tokens = 0
+ while True:
+ token_buffer = []
+ loss_mask_buffer = []
+ for index, example in enumerate(self._dataset):
+ tokens, loss_masks = self.text_processor(example)
+ token_buffer.extend(tokens)
+ loss_mask_buffer.extend(loss_masks)
+ while len(token_buffer) > chunk_size + 1:
+ total_tokens += chunk_size
+ metrics = {
+ 'dataset_example_index': index,
+ 'dataset_total_tokens': total_tokens,
+ }
+ batch = {
+ 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
+ self.config.batch_size, -1
+ ),
+ 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
+ self.config.batch_size, -1
+ ),
+ 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
+ self.config.batch_size, -1
+ ),
+ }
+ if self.config.always_start_with_bos:
+ batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
+ yield batch, metrics
+ token_buffer = token_buffer[chunk_size:]
+ loss_mask_buffer = loss_mask_buffer[chunk_size:]
+
+ def get_state_dict(self):
+ return dict(config=self.config)
+
+ def load_state_dict(self, state_dict):
+ if 'config' in state_dict:
+ self.config.update(ConfigDict(state_dict['config']))
+
+ @property
+ def seq_length(self):
+ return self.config.seq_length
+
+ @property
+ def tokenizer(self):
+ return self._tokenizer
+
+ @property
+ def text_processor(self):
+ return self._text_processor
+
+ @property
+ def dataset(self):
+ return self._dataset
+
+ @property
+ def vocab_size(self):
+ return len(self._tokenizer)
+
+
+class JsonDataset(object):
+ """ JSON dataset, where each line of the data file contains a JSON
+ dictionary with text fields.
+ """
+
+ @staticmethod
+ def get_default_config(updates=None):
+ config = ConfigDict()
+ config.path = ''
+ config.seq_length = 1024
+ config.batch_size = 8
+ config.always_start_with_bos = False
+ config.start_seek_loc = 0
+ config.example_index_at_start = 0
+ config.tokens_count_at_start = 0
+ config.tokenizer_processes = 1
+ config.tokenizer_parallel_chunk_size = 32
+ config.tokenizer_parallel_batch_size = 1024
+ config.throughput_average_window_size = 200
+ config.pad = False
+ config.use_data_sharded_loader = True
+ config.return_local_batch = False
+
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+ return config
+
+ def __init__(self, config, tokenizer, text_processor, node_info):
+ self.config = self.get_default_config(config)
+ assert self.config.path != ''
+ self._tokenizer = tokenizer
+ self._text_processor = text_processor
+ self._node_info = node_info
+ self._index = self.config.example_index_at_start
+ self._file_loc = self.config.start_seek_loc
+ self._total_tokens = self.config.tokens_count_at_start
+
+ def parse_json(self, line):
+ if not line or line == '\n':
+ return None
+ try:
+ data = json.loads(line)
+ except json.decoder.JSONDecodeError:
+ print(f'Error parsing json line:\n{line}')
+ return None
+ return data
+
+ def json_iterator(self):
+ index, file_loc = self._index, self._file_loc
+ with open_file(self.config.path, 'r') as fin:
+ fin.seek(file_loc)
+ while True:
+ line = fin.readline()
+ file_loc = fin.tell()
+ if not line: # Reached EOF
+ index = 0
+ fin.seek(0)
+ continue
+
+ data = self.parse_json(line)
+ if data is not None and (not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']):
+ # JSON parsing succeeded
+ yield data, file_loc, index
+ index += 1
+
+ def batched(self, iterator, batch_size):
+ batch = []
+ for example in iterator:
+ batch.append(example)
+ if len(batch) == batch_size:
+ yield batch
+ batch = []
+ if len(batch) > 0:
+ yield batch
+
+ def parallel_example_iterator(self):
+ if self.config.tokenizer_processes == 1:
+ for example, loc, index in self.json_iterator():
+ self._file_loc = loc
+ self._index = index
+ yield self.text_processor((example, loc, index), has_aux=True)
+ else:
+ process_pool = Pool(self.config.tokenizer_processes)
+ batched_iterator = self.batched(
+ self.json_iterator(), self.config.tokenizer_parallel_batch_size
+ )
+ with process_pool as pool:
+ map_fn = partial(self.text_processor, has_aux=True)
+ next_batch = pool.map_async(
+ map_fn, next(batched_iterator),
+ chunksize=self.config.tokenizer_parallel_chunk_size
+ )
+ while True:
+ current_batch = next_batch
+ next_batch = pool.map_async(
+ map_fn, next(batched_iterator),
+ chunksize=self.config.tokenizer_parallel_chunk_size
+ )
+ for example in current_batch.get():
+ yield example
+
+ def __iter__(self):
+ global_chunk_size = self.config.batch_size * self.config.seq_length
+ if self.config.use_data_sharded_loader:
+ local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
+ else:
+ local_batch_size = self.config.batch_size
+ chunk_size = local_batch_size * self.config.seq_length
+
+ token_buffer = []
+ loss_mask_buffer = []
+
+ last_time = 0.0
+ step_times = []
+ start_time = time.time()
+ start_tokens = self._total_tokens
+ for tokens, loss_masks, loc, index in self.parallel_example_iterator():
+ self._file_loc = loc
+ self._index = index
+ if self.config.pad:
+ tokens = tokens[:self.config.seq_length + 1]
+ tokens.extend([self._tokenizer.bos_token_id] * (self.config.seq_length + 1 - len(tokens)))
+ loss_masks = loss_masks[:self.config.seq_length + 1]
+ loss_masks.extend([0.0] * (self.config.seq_length + 1 - len(loss_masks)))
+ token_buffer.extend(tokens)
+ loss_mask_buffer.extend(loss_masks)
+ while len(token_buffer) > chunk_size + 1:
+ self._total_tokens += global_chunk_size
+ step_times.append(time.time() - last_time)
+ last_time = time.time()
+ if len(step_times) > self.config.throughput_average_window_size:
+ step_times = step_times[-self.config.throughput_average_window_size:]
+ average_throughput = global_chunk_size / np.mean(step_times)
+ accumulated_throughput = (
+ (self._total_tokens - start_tokens) / (time.time() - start_time)
+ )
+ metrics = {
+ 'dataset_file_loc': loc,
+ 'dataset_example_index': index,
+ 'dataset_total_tokens': self._total_tokens,
+ 'dataset_accumulated_tps': accumulated_throughput,
+ 'dataset_average_tps': average_throughput,
+ }
+ batch = {
+ 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
+ local_batch_size, -1
+ ),
+ 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
+ local_batch_size, -1
+ ),
+ 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
+ local_batch_size, -1
+ ),
+ }
+ batch.update({
+ 'input_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool),
+ 'target_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool),
+ })
+ if self.config.always_start_with_bos:
+ batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
+
+ if self.config.use_data_sharded_loader and not self.config.return_local_batch:
+ mesh = self._node_info['mesh']
+ sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
+ sp_nodes_rank = jax.process_index() % sp_nodes_size
+ assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_len, sp_nodes_size)
+ seq_chunk_size = self.config.seq_length // sp_nodes_size
+ batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
+ batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))
+
+ yield batch, metrics
+ if self.config.pad:
+ token_buffer, loss_mask_buffer = [], []
+ else:
+ token_buffer = token_buffer[chunk_size:]
+ loss_mask_buffer = loss_mask_buffer[chunk_size:]
+
+ def _make_callback(self, v):
+ return lambda index: v[index]
+
+ def get_state_dict(self):
+ return dict(
+ config=self.config,
+ index=self._index,
+ file_loc=self._file_loc,
+ total_tokens=self._total_tokens,
+ )
+
+ def load_state_dict(self, state_dict):
+ if 'config' in state_dict:
+ self.config.update(ConfigDict(state_dict['config']))
+ self._index = state_dict.get('index', self.config.example_index_at_start)
+ self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
+ self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
+
+ @property
+ def seq_length(self):
+ return self.config.seq_length
+
+ @property
+ def tokenizer(self):
+ return self._tokenizer
+
+ @property
+ def text_processor(self):
+ return self._text_processor
+
+ @property
+ def vocab_size(self):
+ return len(self.tokenizer)
+
+
+class JsonVisionDataset(object):
+ @staticmethod
+ def get_default_config(updates=None):
+ config = ConfigDict()
+ config.path = ''
+ config.seq_length = 384
+ config.batch_size = 4
+ config.always_start_with_bos = False
+ config.start_seek_loc = 0
+ config.example_index_at_start = 0
+ config.tokens_count_at_start = 0
+ config.tokenizer_processes = 1
+ config.tokenizer_parallel_chunk_size = 32
+ config.tokenizer_parallel_batch_size = 1024
+ config.throughput_average_window_size = 200
+ config.use_data_sharded_loader = True
+ config.return_local_batch = False
+ config.mode = 'pad'
+
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+ return config
+
+ def __init__(self, config, tokenizer, text_processor, node_info):
+ self.config = self.get_default_config(config)
+ assert self.config.path != ''
+ self._node_info = node_info
+ self._tokenizer = tokenizer
+ self._text_processor = text_processor
+ self._index = self.config.example_index_at_start
+ self._file_loc = self.config.start_seek_loc
+ self._total_tokens = 0
+
+ def parse_json(self, line):
+ if not line or line == '\n':
+ return None
+ try:
+ data = json.loads(line)
+ except json.decoder.JSONDecodeError:
+ print(f'Error parsing json line:\n{line}')
+ return None
+ return data
+
+ def json_iterator(self):
+ index, file_loc = self._index, self._file_loc
+ with open_file(self.config.path, 'r', block_size=50 * 2 ** 20) as fin:
+ fin.seek(file_loc)
+ while True:
+ line = fin.readline()
+ file_loc = fin.tell()
+ if not line: # Reached EOF
+ index = 0
+ fin.seek(0)
+ continue
+ if not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']:
+ data = self.parse_json(line)
+ if data is not None:
+ # JSON parsing succeeded
+ yield data, file_loc, index
+ index += 1
+
+ def batched(self, iterator, batch_size):
+ batch = []
+ for example in iterator:
+ batch.append(example)
+ if len(batch) == batch_size:
+ yield batch
+ batch = []
+ if len(batch) > 0:
+ yield batch
+
+ def parallel_example_iterator(self):
+ if self.config.tokenizer_processes == 1:
+ for example, loc, index in self.json_iterator():
+ self._file_loc = loc
+ self._index = index
+ yield self.text_processor((example, loc, index), has_aux=True)
+ else:
+ process_pool = Pool(self.config.tokenizer_processes)
+ batched_iterator = self.batched(
+ self.json_iterator(), self.config.tokenizer_parallel_batch_size
+ )
+ with process_pool as pool:
+ map_fn = partial(self.text_processor, has_aux=True)
+ next_batch = pool.map_async(
+ map_fn, next(batched_iterator),
+ chunksize=self.config.tokenizer_parallel_chunk_size
+ )
+ while True:
+ current_batch = next_batch
+ next_batch = pool.map_async(
+ map_fn, next(batched_iterator),
+ chunksize=self.config.tokenizer_parallel_chunk_size
+ )
+ for example in current_batch.get():
+ yield example
+
+ def __iter__(self):
+ if self.config.mode == 'pad':
+ fn = self._iter_pad
+ elif self.config.mode == 'no_pad':
+ fn = self._iter_no_pad
+ else:
+ raise ValueError(f'Unknown mode: {self.config.mode}')
+ return fn()
+
+ def _iter_pad(self):
+ chunk_size = self.config.batch_size * self.config.seq_length
+ if self.config.use_data_sharded_loader:
+ local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
+ else:
+ local_batch_size = self.config.batch_size
+ last_time = 0.0
+ buffer = []
+ step_times = []
+ start_time = time.time()
+ start_tokens = self._total_tokens
+ for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator():
+ if not keep:
+ continue
+ self._file_loc = loc
+ self._index = index
+ buffer.append((tokens, loss_masks, vision_masks))
+ while len(buffer) >= local_batch_size:
+ self._total_tokens += chunk_size
+ step_times.append(time.time() - last_time)
+ last_time = time.time()
+ if len(step_times) > self.config.throughput_average_window_size:
+ step_times = step_times[-self.config.throughput_average_window_size:]
+ average_throughput = chunk_size / np.mean(step_times)
+ accumulated_throughput = (
+ (self._total_tokens - start_tokens) / (time.time() - start_time)
+ )
+ metrics = {
+ 'dataset_file_loc': loc,
+ 'dataset_example_index': index,
+ 'dataset_total_tokens': self._total_tokens,
+ 'dataset_accumulated_tps': accumulated_throughput,
+ 'dataset_average_tps': average_throughput,
+ }
+
+ batch = {
+ 'input_tokens': np.full(
+ (local_batch_size, self.config.seq_length),
+ self._tokenizer.bos_token_id,
+ dtype=np.int32
+ ),
+ 'target_tokens': np.full(
+ (local_batch_size, self.config.seq_length),
+ self._tokenizer.bos_token_id,
+ dtype=np.int32
+ ),
+ 'loss_masks': np.zeros(
+ (local_batch_size, self.config.seq_length),
+ dtype=np.float32
+ ),
+ 'input_vision_masks': np.zeros(
+ (local_batch_size, self.config.seq_length),
+ dtype=bool
+ ),
+ 'target_vision_masks': np.zeros(
+ (local_batch_size, self.config.seq_length),
+ dtype=bool
+ )
+ }
+ for i in range(local_batch_size):
+ tokens, loss_masks, vision_masks = buffer[i]
+ if len(tokens) > self.config.seq_length:
+ tokens = tokens[:self.config.seq_length + 1]
+ loss_masks = loss_masks[1:self.config.seq_length + 1]
+ vision_masks = vision_masks[:self.config.seq_length + 1]
+ input_tokens, target_tokens = tokens[:-1], tokens[1:]
+ input_vision_masks, target_vision_masks = vision_masks[:-1], vision_masks[1:]
+ loss_masks = loss_masks[1:]
+ batch['input_tokens'][i, :len(input_tokens)] = input_tokens
+ batch['target_tokens'][i, :len(target_tokens)] = target_tokens
+ batch['input_vision_masks'][i, :len(input_vision_masks)] = input_vision_masks
+ batch['target_vision_masks'][i, :len(target_vision_masks)] = target_vision_masks
+ batch['loss_masks'][i, :len(loss_masks)] = loss_masks
+
+ if self.config.use_data_sharded_loader and not self.config.return_local_batch:
+ mesh = self._node_info['mesh']
+ sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
+ sp_nodes_rank = jax.process_index() % sp_nodes_size
+ assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_len, sp_nodes_size)
+ seq_chunk_size = self.config.seq_length // sp_nodes_size
+ batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
+ batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))
+ yield batch, metrics
+ buffer = buffer[local_batch_size:]
+
+ def _iter_no_pad(self):
+ global_chunk_size = self.config.batch_size * self.config.seq_length
+ if self.config.use_data_sharded_loader:
+ local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
+ else:
+ local_batch_size = self.config.batch_size
+ chunk_size = local_batch_size * self.config.seq_length
+
+ token_buffer = []
+ loss_mask_buffer = []
+ vision_mask_buffer = []
+
+ last_time = 0.0
+ step_times = []
+ start_time = time.time()
+ start_tokens = self._total_tokens
+ for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator():
+ if not keep:
+ continue
+ self._file_loc = loc
+ self._index = index
+ token_buffer.extend(tokens)
+ loss_mask_buffer.extend(loss_masks)
+ vision_mask_buffer.extend(vision_masks)
+ while len(token_buffer) > chunk_size + 1:
+ self._total_tokens += global_chunk_size
+ step_times.append(time.time() - last_time)
+ last_time = time.time()
+ if len(step_times) > self.config.throughput_average_window_size:
+ step_times = step_times[-self.config.throughput_average_window_size:]
+ average_throughput = global_chunk_size / np.mean(step_times)
+ accumulated_throughput = (
+ (self._total_tokens - start_tokens) / (time.time() - start_time)
+ )
+ metrics = {
+ 'dataset_file_loc': loc,
+ 'dataset_example_index': index,
+ 'dataset_total_tokens': self._total_tokens,
+ 'dataset_accumulated_tps': accumulated_throughput,
+ 'dataset_average_tps': average_throughput,
+ }
+ batch = {
+ 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
+ local_batch_size, -1
+ ),
+ 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
+ local_batch_size, -1
+ ),
+ 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
+ local_batch_size, -1
+ ),
+ 'input_vision_masks': np.array(vision_mask_buffer[:chunk_size], dtype=bool).reshape(
+ local_batch_size, -1
+ ),
+ 'target_vision_masks': np.array(vision_mask_buffer[1:chunk_size + 1], dtype=bool).reshape(
+ local_batch_size, -1
+ ),
+ }
+
+ if self.config.use_data_sharded_loader and not self.config.return_local_batch:
+ mesh = self._node_info['mesh']
+ sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
+ sp_nodes_rank = jax.process_index() % sp_nodes_size
+ assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_len, sp_nodes_size)
+ seq_chunk_size = self.config.seq_length // sp_nodes_size
+ batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
+ batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))
+
+ yield batch, metrics
+ token_buffer = token_buffer[chunk_size:]
+ loss_mask_buffer = loss_mask_buffer[chunk_size:]
+ vision_mask_buffer = vision_mask_buffer[chunk_size:]
+
+
+ def _make_callback(self, v):
+ return lambda index: v[index]
+
+ def get_state_dict(self):
+ return dict(
+ config=self.config,
+ index=self._index,
+ file_loc=self._file_loc,
+ total_tokens=self._total_tokens,
+ )
+
+ def load_state_dict(self, state_dict):
+ if 'config' in state_dict:
+ self.config.update(ConfigDict(state_dict['config']))
+ self._index = state_dict.get('index', self.config.example_index_at_start)
+ self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
+ self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
+
+ @property
+ def seq_length(self):
+ return self.config.seq_length
+
+ @property
+ def tokenizer(self):
+ return self._tokenizer
+
+ @property
+ def text_processor(self):
+ return self._text_processor
+
+ @property
+ def vocab_size(self):
+ return len(self._tokenizer)
diff --git a/lwm/llama.py b/lwm/llama.py
new file mode 100644
index 0000000..d6cbc64
--- /dev/null
+++ b/lwm/llama.py
@@ -0,0 +1,1470 @@
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+import json
+import tempfile
+from functools import partial
+
+import numpy as np
+import jax
+from jax.lib import xla_bridge
+import jax.numpy as jnp
+from jax import lax
+from jax.sharding import PartitionSpec as PS
+from jax.experimental.shard_map import shard_map
+import flax.linen as nn
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.traverse_util import flatten_dict, unflatten_dict
+from flax.linen import partitioning as nn_partitioning
+
+import sentencepiece as spm
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
+from transformers.modeling_flax_utils import FlaxPreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+
+from ml_collections import ConfigDict
+from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
+from lwm.ring_attention import blockwise_ffn, ring_flash_attention_tpu, \
+ ring_attention_standard, ring_attention
+
+
+LLAMA_STANDARD_CONFIGS = {
+ '200m': {
+ 'vocab_size': 32000,
+ 'hidden_size': 1024,
+ 'intermediate_size': 2048,
+ 'num_hidden_layers': 14,
+ 'num_attention_heads': 8,
+ 'max_sequence_length': 2048,
+ 'initializer_range': 0.02,
+ 'rms_norm_eps': 1e-6,
+ 'use_cache': True,
+ 'tie_word_embeddings': False,
+ },
+ '1b': {
+ 'vocab_size': 32000,
+ 'hidden_size': 2048,
+ 'intermediate_size': 5504,
+ 'num_hidden_layers': 22,
+ 'num_attention_heads': 16,
+ 'max_sequence_length': 2048,
+ 'initializer_range': 0.02,
+ 'rms_norm_eps': 1e-6,
+ 'use_cache': True,
+ 'tie_word_embeddings': False,
+ },
+ '3b': {
+ 'vocab_size': 32000,
+ 'hidden_size': 3200,
+ 'intermediate_size': 8640,
+ 'num_hidden_layers': 26,
+ 'num_attention_heads': 32,
+ 'max_sequence_length': 2048,
+ 'initializer_range': 0.02,
+ 'rms_norm_eps': 1e-6,
+ 'use_cache': True,
+ 'tie_word_embeddings': False,
+ },
+ '7b': {
+ 'vocab_size': 32000,
+ 'hidden_size': 4096,
+ 'intermediate_size': 11008,
+ 'num_hidden_layers': 32,
+ 'num_attention_heads': 32,
+ 'max_sequence_length': 4096,
+ 'initializer_range': 0.02,
+ 'rms_norm_eps': 1e-6,
+ 'use_cache': True,
+ 'tie_word_embeddings': False,
+ },
+ '13b': {
+ 'vocab_size': 32000,
+ 'hidden_size': 5120,
+ 'intermediate_size': 13824,
+ 'num_hidden_layers': 40,
+ 'num_attention_heads': 40,
+ 'max_sequence_length': 2048,
+ 'initializer_range': 0.02,
+ 'rms_norm_eps': 1e-6,
+ 'use_cache': True,
+ 'tie_word_embeddings': False,
+ },
+ '30b': {
+ 'vocab_size': 32000,
+ 'hidden_size': 6656,
+ 'intermediate_size': 17920,
+ 'num_hidden_layers': 60,
+ 'num_attention_heads': 52,
+ 'max_sequence_length': 2048,
+ 'initializer_range': 0.02,
+ 'rms_norm_eps': 1e-6,
+ 'use_cache': True,
+ 'tie_word_embeddings': False,
+ },
+ '65b': {
+ 'vocab_size': 32000,
+ 'hidden_size': 8192,
+ 'intermediate_size': 22016,
+ 'num_hidden_layers': 80,
+ 'num_attention_heads': 64,
+ 'max_sequence_length': 2048,
+ 'initializer_range': 0.02,
+ 'rms_norm_eps': 1e-5,
+ 'use_cache': True,
+ 'tie_word_embeddings': False,
+ },
+ 'debug': { # A small model for debugging
+ 'vocab_size': 32000,
+ 'hidden_size': 256,
+ 'intermediate_size': 256,
+ 'num_hidden_layers': 2,
+ 'num_attention_heads': 2,
+ 'max_sequence_length': 2048,
+ 'initializer_range': 0.02,
+ 'rms_norm_eps': 1e-6,
+ 'use_cache': True,
+ 'tie_word_embeddings': False,
+ },
+}
+
+
+class LLaMAConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate an LLaMA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the LLaMA-7B.
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`~LLaMAModel`] or [`~TFLLaMAModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_sequence_length (`int`, *optional*, defaults to 2048):
+ Max sequence length for model (for RoPE computation)
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ Example:
+ ```python
+ >>> from transformers import LLaMAModel, LLaMAConfig
+ >>> # Initializing a LLaMA llama-7b style configuration
+ >>> configuration = LLaMAConfig()
+ >>> # Initializing a model from the llama-7b style configuration
+ >>> model = LLaMAModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "llama"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ max_sequence_length=4096,
+ orig_sequence_length=4096,
+ rms_norm_eps=1e-6,
+ initializer_range=0.02,
+ use_cache=True,
+ bos_token_id=0,
+ eos_token_id=1,
+ resid_pdrop=0.0,
+ embd_pdrop=0.0,
+ attn_pdrop=0.0,
+ tie_word_embeddings=False,
+ remat_block='',
+ remat_attention='',
+ remat_mlp='',
+ scan_attention=False,
+ scan_mlp=False,
+ scan_query_chunk_size=1024,
+ scan_key_chunk_size=1024,
+ scan_mlp_chunk_size=1024,
+ fcm_min_ratio=0.0,
+ fcm_max_ratio=0.0,
+ scan_layers=True,
+ param_scan_axis=0,
+ mesh_dim=None,
+ use_flash_attention=True,
+ theta=10000,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.initializer_range = initializer_range
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.max_sequence_length = max_sequence_length
+ self.orig_sequence_length = orig_sequence_length
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.remat_block = remat_block
+ self.remat_attention = remat_attention
+ self.remat_mlp = remat_mlp
+ self.scan_attention = scan_attention
+ self.scan_mlp = scan_mlp
+ self.scan_query_chunk_size = scan_query_chunk_size
+ self.scan_key_chunk_size = scan_key_chunk_size
+ self.scan_mlp_chunk_size = scan_mlp_chunk_size
+ self.fcm_min_ratio = fcm_min_ratio
+ self.fcm_max_ratio = fcm_max_ratio
+ self.scan_layers = scan_layers
+ self.param_scan_axis = param_scan_axis
+ self.mesh_dim = mesh_dim
+ self.use_flash_attention = use_flash_attention
+ self.theta = theta
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ @classmethod
+ def get_default_config(cls, updates=None):
+ config = function_args_to_config(cls.__init__)
+
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+
+ return config
+
+ @staticmethod
+ def get_jax_mesh(axis_dims):
+ return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'tp', 'sp'))
+
+ @staticmethod
+ def get_ranks_and_size(mesh):
+ out = dict(mesh=mesh)
+ mp_size = mesh.shape['tp'] * mesh.shape['sp']
+ mp_node_size = max(1, mp_size // jax.local_device_count())
+ dp_node_size = jax.process_count() // mp_node_size
+ out.update(mp_node_size=mp_node_size,
+ dp_node_size=dp_node_size)
+
+ dp_node_rank = jax.process_index() // mp_node_size
+ mp_node_rank = jax.process_index() % mp_node_size
+ out.update(dp_node_rank=dp_node_rank,
+ mp_node_rank=mp_node_rank)
+ return out
+
+
+ @staticmethod
+ def get_partition_rules(scan_layers=False, scan_axis=0):
+ """ Parition rules for GPTJ. Note that these rules are orderd, so that
+ the beginning rules match first. It is important to use
+ PartitionSpec() instead of None here because JAX does not treat
+ None as a pytree leaf.
+ """
+ if scan_layers:
+ if scan_axis == 0:
+ return (
+ # embeddings
+ ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
+ # atention
+ ("attention/(wq|wk|wv)/kernel", PS(None, ("fsdp", "sp"), "tp")),
+ ("attention/wo/kernel", PS(None, "tp", ("fsdp", "sp"))),
+ # mlp
+ ("feed_forward/w1/kernel", PS(None, ("fsdp", "sp"), "tp")),
+ ("feed_forward/w2/kernel", PS(None, "tp", ("fsdp", "sp"))),
+ ("feed_forward/w3/kernel", PS(None, ("fsdp", "sp"), "tp")),
+ # layer norms
+ ("attention_norm/kernel", PS(None, None)),
+ ("ffn_norm/kernel", PS(None, None)),
+ # output head
+ ("transformer/ln_f/kernel", PS(None)),
+ ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ('.*', PS(None)),
+ )
+ elif scan_axis == 1:
+ return (
+ # embeddings
+ ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
+ # atention
+ ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), None, "tp")),
+ ("attention/wo/kernel", PS("tp", None, ("fsdp", "sp"))),
+ # mlp
+ ("feed_forward/w1/kernel", PS(("fsdp", "sp"), None, "tp")),
+ ("feed_forward/w2/kernel", PS("tp", None, ("fsdp", "sp"))),
+ ("feed_forward/w3/kernel", PS(("fsdp", "sp"), None, "tp")),
+ # layer norms
+ ("attention_norm/kernel", PS(None, None)),
+ ("ffn_norm/kernel", PS(None, None)),
+ # output head
+ ("transformer/ln_f/kernel", PS(None)),
+ ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ('.*', PS(None)),
+ )
+ else:
+ raise ValueError(f"Invalid scan_axis {scan_axis}")
+ else:
+ return (
+ # embeddings
+ ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
+ # atention
+ ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), "tp")),
+ ("attention/wo/kernel", PS("tp", ("fsdp", "sp"))),
+ # mlp
+ ("feed_forward/w1/kernel", PS(("fsdp", "sp"), "tp")),
+ ("feed_forward/w2/kernel", PS("tp", ("fsdp", "sp"))),
+ ("feed_forward/w3/kernel", PS(("fsdp", "sp"), "tp")),
+ # layer norms
+ ("attention_norm/kernel", PS(None)),
+ ("ffn_norm/kernel", PS(None)),
+ # output head
+ ("transformer/ln_f/kernel", PS(None)),
+ ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ('.*', PS(None)),
+ )
+
+ @staticmethod
+ def get_weight_decay_exclusions():
+ return tuple()
+
+ @staticmethod
+ def get_frozen_param_exclusions(freeze_base):
+ if freeze_base:
+ return ("vte", "vision_head")
+ else:
+ return tuple()
+
+ @staticmethod
+ def rng_keys():
+ return ('params', 'dropout', 'fcm')
+
+ @staticmethod
+ def get_tokenizer_config(updates=None):
+ config = ConfigDict()
+ config.vocab_file = ''
+ config.add_bos_token = False
+ config.add_eos_token = False
+
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+ return config
+
+ @classmethod
+ def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
+ config = cls.get_tokenizer_config(config)
+ assert config.vocab_file != '', 'vocab_file must be specified'
+ tokenizer = LLaMATokenizer(
+ vocab_file=config.vocab_file,
+ add_bos_token=config.add_bos_token,
+ add_eos_token=config.add_eos_token,
+ padding_side=padding_side,
+ truncation_side=truncation_side,
+ )
+ return tokenizer
+
+ @classmethod
+ def load_config(cls, path):
+ if path in LLAMA_STANDARD_CONFIGS:
+ return cls.from_dict(LLAMA_STANDARD_CONFIGS[path])
+ load_type, load_path = path.split('::', 1)
+ if load_type == 'pickle':
+ return cls.from_dict(load_pickle(load_path)['llama_config'])
+ elif load_type == 'json':
+ with open_file(load_path, 'r') as fin:
+ raw_config = fin.read()
+ return cls.from_dict(json.loads(raw_config))
+ else:
+ raise ValueError(f'Unsupported load config type: {load_type}')
+
+
+remat = nn_partitioning.remat
+
+logger = logging.get_logger(__name__)
+
+
+class RMSNorm(nn.Module):
+ dim: int
+ eps: float=1e-6
+ dtype: jnp.dtype=jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+
+ def setup(self) -> None:
+ self.weight = self.param(
+ 'kernel',
+ nn.initializers.ones,
+ (self.dim,),
+ self.param_dtype,
+ )
+
+ def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
+ return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
+
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+ x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
+ output = self._norm(x).astype(self.dtype)
+ weight = jnp.asarray(self.weight, self.dtype)
+ return output * weight
+
+
+def precompute_freqs_cis(dim: int, max_position_embedding: int, theta: float=10000.0, dtype: jnp.dtype=jnp.float32) -> jnp.ndarray:
+ freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
+ t = np.arange(max_position_embedding) # type: ignore
+ freqs = np.outer(t, freqs).astype(dtype) # type: ignore
+ sin, cos = np.sin(freqs), np.cos(freqs)
+ freqs_cis = np.complex64(cos + 1j * sin)
+ return jnp.asarray(freqs_cis)
+
+
+def apply_rotary_emb(
+ xq: jnp.ndarray,
+ xk: jnp.ndarray,
+ freqs_cis: jnp.ndarray,
+ dtype: jnp.dtype=jnp.float32,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+
+ reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
+ reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
+
+ xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
+ xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
+
+ # add head dim
+ freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
+
+ xq_out = xq_ * freqs_cis
+ xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
+
+ xk_out = xk_ * freqs_cis
+ xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
+
+ return xq_out.astype(dtype), xk_out.astype(dtype)
+
+
+class FlaxLLaMAAttention(nn.Module):
+ config: LLaMAConfig
+ dtype: jnp.dtype=jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+ precision: Optional[Union[jax.lax.Precision, str]]=None
+
+ def setup(self):
+ config = self.config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+
+ self.wq = nn.Dense(
+ config.num_attention_heads*self.head_dim,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ precision=self.precision,
+ )
+ self.wk = nn.Dense(
+ config.num_attention_heads*self.head_dim,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ precision=self.precision,
+ )
+ self.wv = nn.Dense(
+ config.num_attention_heads*self.head_dim,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ precision=self.precision,
+ )
+ self.wo = nn.Dense(
+ config.hidden_size,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ precision=self.precision,
+ )
+
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
+
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")
+
+ self.freqs_cis = precompute_freqs_cis(
+ self.head_dim,
+ config.max_sequence_length,
+ theta=config.theta,
+ dtype=self.dtype,
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slighly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ if query.shape[1] == 1:
+ mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)
+ def fn(cached_key, cached_value, key, value, cur_index):
+ assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
+ sp_size = max_length // mesh.shape['sp']
+ axis_index = jax.lax.axis_index('sp')
+ cur_index = cur_index - axis_index * sp_size
+ key, value = jax.lax.cond(
+ jnp.logical_and(cur_index >= 0, cur_index < sp_size),
+ lambda: (
+ cached_key.at[:, cur_index].set(key[:, -1]),
+ cached_value.at[:, cur_index].set(value[:, -1]),
+ ),
+ lambda: (cached_key, cached_value),
+ )
+ return key, value
+ fn = shard_map(
+ fn, mesh=mesh,
+ in_specs=(
+ PS(('dp', 'fsdp'), 'sp', 'tp', None),
+ PS(('dp', 'fsdp'), 'sp', 'tp', None),
+ PS(('dp', 'fsdp'), None, 'tp', None),
+ PS(('dp', 'fsdp'), None, 'tp', None),
+ PS()
+ ),
+ out_specs=(
+ PS(('dp', 'fsdp'), 'sp', 'tp', None),
+ PS(('dp', 'fsdp'), 'sp', 'tp', None)
+ ),
+ check_rep=False
+ )
+ key, value = fn(cached_key.value, cached_value.value, key, value, cur_index)
+ else:
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ segment_ids,
+ position_ids,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ fcm_mask=None,
+ ):
+ xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
+
+ if xq.shape[1] == 1:
+ xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "tp"))
+ else:
+ xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), "sp", "tp"))
+ xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), "sp", "tp"))
+ xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), "sp", "tp"))
+
+ xq = self._split_heads(xq)
+ xk = self._split_heads(xk)
+ xv = self._split_heads(xv)
+
+ freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)
+
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
+
+ dropout_rng = None
+ if not deterministic and self.config.attn_pdrop > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ if self.config.scan_attention and xq.shape[1] > max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size):
+ # attention mask without nxn materlization, blockwise_attn will handle the rest
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ if self.has_variable("cache", "cached_key") or init_cache:
+ xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
+
+ # transform boolean mask into float mask
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+ )
+ attn_weights = None
+
+ platform = xla_bridge.get_backend().platform
+ if self.config.use_flash_attention and platform == "tpu":
+ ring_attention_fn = ring_flash_attention_tpu
+ else:
+ ring_attention_fn = ring_attention # uses BPT attention
+ ring_attention_sharded = shard_map(
+ partial(
+ ring_attention_fn,
+ axis_name="sp",
+ float32_logits=True,
+ blockwise_kwargs=dict(
+ deterministic=deterministic,
+ dropout_rng=dropout_rng,
+ attn_pdrop=self.config.attn_pdrop,
+ causal=True,
+ query_chunk_size=self.config.scan_query_chunk_size,
+ key_chunk_size=self.config.scan_key_chunk_size,
+ dtype=self.dtype,
+ policy=get_gradient_checkpoint_policy('nothing_saveable'),
+ precision=self.precision,
+ prevent_cse=not self.config.scan_layers,
+ )
+ ),
+ mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
+ in_specs=(
+ PS(("dp", "fsdp"), "sp", "tp", None),
+ PS(("dp", "fsdp"), "sp", "tp", None),
+ PS(("dp", "fsdp"), "sp", "tp", None),
+ PS(("dp", "fsdp"), None, None, None),
+ PS(("dp", "fsdp"), None),
+ ),
+ out_specs=PS(("dp", "fsdp"), "sp", "tp", None),
+ check_rep=False
+ )
+ attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids)
+ attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), "sp", "tp", None))
+ else:
+ query_length, key_length = xq.shape[1], xk.shape[1]
+
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = jnp.arange(max_decoder_length)[None] <= (jnp.arange(query_length) + mask_shift)[:, None]
+ causal_mask = causal_mask[None, None]
+ segment_mask = None
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]
+ segment_mask = segment_mask[:, None]
+
+ batch_size = hidden_states.shape[0]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask, segment_mask)
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.has_variable("cache", "cached_key") or init_cache:
+ xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
+
+ q_sp_dim = None if xq.shape[1] == 1 else 'sp'
+ attn_weights = None
+ ring_attention_sharded = shard_map(
+ partial(ring_attention_standard, axis_name="sp"), mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
+ in_specs=(
+ PS(("dp", "fsdp"), q_sp_dim, "tp", None),
+ PS(("dp", "fsdp"), "sp", "tp", None),
+ PS(("dp", "fsdp"), "sp", "tp", None),
+ PS(("dp", "fsdp"), None, q_sp_dim, None)
+ ),
+ out_specs=PS(("dp", "fsdp"), q_sp_dim, "tp", None),
+ check_rep=False
+ )
+ attn_output = ring_attention_sharded(
+ xq, xk, xv, attention_mask
+ )
+
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.wo(attn_output)
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+ return outputs
+
+
+class FlaxLLaMAMLP(nn.Module):
+ config: LLaMAConfig
+ dtype: jnp.dtype=jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+ precision: Optional[Union[jax.lax.Precision, str]]=None
+
+ def setup(self) -> None:
+ config = self.config
+
+ self.w1 = nn.Dense(
+ config.intermediate_size,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ precision=self.precision,
+ )
+ self.w2 = nn.Dense(
+ config.hidden_size,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ precision=self.precision,
+ )
+ self.w3 = nn.Dense(
+ config.intermediate_size,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ precision=self.precision,
+ )
+ self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
+
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
+ x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
+ x = self.dropout(x, deterministic=deterministic)
+ return x
+
+
+class FlaxLLaMABlock(nn.Module):
+ config: LLaMAConfig
+ dtype: jnp.dtype=jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+ precision: Optional[Union[jax.lax.Precision, str]]=None
+
+ def setup(self) -> None:
+ attention_module = FlaxLLaMAAttention
+ mlp_module = FlaxLLaMAMLP
+ if self.config.remat_attention != '':
+ attention_module = remat(
+ FlaxLLaMAAttention, static_argnums=(4, 5, 6),
+ policy=get_gradient_checkpoint_policy(self.config.remat_attention),
+ prevent_cse=not self.config.scan_layers,
+ )
+ if self.config.remat_mlp != '':
+ mlp_module = remat(
+ FlaxLLaMAMLP, static_argnums=(1,),
+ policy=get_gradient_checkpoint_policy(self.config.remat_mlp),
+ prevent_cse=not self.config.scan_layers,
+ )
+
+ self.attention = attention_module(
+ self.config,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ precision=self.precision,
+ )
+ self.feed_forward = mlp_module(
+ self.config,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ precision=self.precision,
+ )
+ self.attention_norm = RMSNorm(
+ self.config.hidden_size,
+ eps=self.config.rms_norm_eps,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ )
+ self.ffn_norm = RMSNorm(
+ self.config.hidden_size,
+ eps=self.config.rms_norm_eps,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ )
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ segment_ids=None,
+ position_ids=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ fcm_mask: Optional[jnp.ndarray] = None,
+ ):
+ attn_outputs = self.attention(
+ self.attention_norm(hidden_states),
+ attention_mask,
+ segment_ids,
+ position_ids,
+ deterministic,
+ init_cache,
+ output_attentions,
+ fcm_mask,
+ )
+ attn_output = attn_outputs[0]
+ hidden_states = hidden_states + attn_output
+
+ feed_forward_input = self.ffn_norm(hidden_states)
+
+ if self.config.scan_mlp and hidden_states.shape[1] >= self.config.scan_mlp_chunk_size:
+ feed_forward_hidden_states = blockwise_ffn(
+ self.feed_forward,
+ feed_forward_input,
+ self.config.scan_mlp_chunk_size,
+ deterministic,
+ )
+ else:
+ feed_forward_hidden_states = self.feed_forward(
+ feed_forward_input,
+ deterministic,
+ )
+ feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS(("dp", "fsdp"), None, "tp"))
+
+ hidden_states = hidden_states + feed_forward_hidden_states
+
+ # return (hidden_states,) + attn_outputs[1:]
+ outputs = hidden_states
+ if self.config.scan_layers:
+ outputs = (outputs, None)
+ return outputs
+
+
+class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LLaMAConfig
+ base_model_prefix = "transformer"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: LLaMAConfig,
+ input_shape: Tuple = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ segment_ids = jnp.zeros_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ if self.config.add_cross_attention:
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
+ encoder_attention_mask = attention_mask
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ segment_ids,
+ position_ids,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ return_dict=False,
+ )
+ else:
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
+
+ random_params = module_init_outputs["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length))
+ attention_mask = jnp.ones_like(input_ids)
+ segment_ids = jnp.zeros_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True
+ )
+ return init_variables["cache"].unfreeze()
+
+ @add_start_docstrings_to_model_forward("")
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ segment_ids=None,
+ position_ids=None,
+ params: dict = None,
+ past_key_values: dict = None,
+ dropout_rng: jax.random.PRNGKey = None,
+ train: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ batch_size, sequence_length = input_ids.shape
+
+ if position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
+
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ if attention_mask is None:
+ attention_mask = jnp.ones((batch_size, sequence_length))
+ if segment_ids is None:
+ segment_ids = jnp.zeros((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ jnp.array(segment_ids, dtype="i4"),
+ jnp.array(position_ids, dtype="i4"),
+ not train,
+ False,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+class FlaxLLaMABlockCollection(nn.Module):
+ config: LLaMAConfig
+ dtype: jnp.dtype = jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+ precision: Optional[Union[jax.lax.Precision, str]]=None
+
+ @nn.compact
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ segment_ids=None,
+ position_ids=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if not deterministic and self.config.fcm_max_ratio > 0:
+ # Apply forgetful causal mask
+ batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
+ fcm_ratio = jax.random.uniform(
+ self.make_rng('fcm'), shape=(batch_size, 1, 1, 1),
+ minval=self.config.fcm_min_ratio,
+ maxval=self.config.fcm_max_ratio
+ )
+ fcm_mask = jax.random.uniform(
+ self.make_rng('fcm'),
+ shape=(batch_size, 1, seq_length, seq_length)
+ ) > fcm_ratio
+ fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
+ fcm_mask = fcm_mask.astype('bool')
+ else:
+ fcm_mask = None
+
+ block = FlaxLLaMABlock
+ if self.config.remat_block != '':
+ block = remat(
+ FlaxLLaMABlock, static_argnums=(4, 5, 6),
+ prevent_cse=not self.config.scan_layers,
+ policy=get_gradient_checkpoint_policy(self.config.remat_block)
+ )
+ if self.config.scan_layers:
+ initializing = self.is_mutable_collection('params')
+ params_spec = (
+ self.config.param_scan_axis if initializing else
+ nn_partitioning.ScanIn(self.config.param_scan_axis))
+ cache_spec = 0
+ hidden_states, _ = nn.scan(
+ block,
+ variable_axes={
+ 'params': params_spec,
+ 'cache': cache_spec,
+ 'intermediates': 0
+ },
+ split_rngs={
+ 'params': True,
+ 'dropout': True
+ },
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
+ length=self.config.num_hidden_layers,
+ metadata_params={nn.PARTITION_NAME: 'scan_decoder_layer'},
+ )(self.config, name='scan_decoder', dtype=self.dtype, param_dtype=self.param_dtype,)(
+ hidden_states,
+ attention_mask,
+ segment_ids,
+ position_ids,
+ deterministic,
+ init_cache,
+ output_attentions,
+ fcm_mask,
+ )
+ else:
+ blocks = [
+ block(
+ self.config,
+ name=str(i),
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ ) for i in range(self.config.num_hidden_layers)
+ ]
+ for block in blocks:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = block(
+ hidden_states,
+ attention_mask,
+ segment_ids,
+ position_ids,
+ deterministic,
+ init_cache,
+ output_attentions,
+ fcm_mask,
+ )
+ hidden_states = layer_outputs
+
+ if output_attentions:
+ all_attentions += (layer_outputs[1],)
+
+ # this contains possible `None` values - `FlaxGPTJModule` will filter them out
+ outputs = (hidden_states, all_hidden_states, all_attentions)
+
+ return outputs
+
+
+class FlaxLLaMAModule(nn.Module):
+ config: LLaMAConfig
+ dtype: jnp.dtype = jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+ precision: Optional[Union[jax.lax.Precision, str]]=None
+
+ def setup(self):
+ self.embed_dim = self.config.hidden_size
+
+ self.wte = nn.Embed(
+ self.config.vocab_size,
+ self.config.hidden_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ )
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
+ self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
+ self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ segment_ids,
+ position_ids,
+ deterministic=True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ input_embeds = self.wte(input_ids.astype("i4"))
+
+ hidden_states = self.dropout(input_embeds, deterministic=deterministic)
+
+ outputs = self.h(
+ hidden_states,
+ attention_mask,
+ segment_ids=segment_ids,
+ position_ids=position_ids,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = outputs[1] + (hidden_states,)
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=outputs[1],
+ attentions=outputs[-1],
+ )
+
+@add_start_docstrings("", "")
+class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel):
+ module_class = FlaxLLaMAModule
+
+class FlaxLLaMAForCausalLMModule(nn.Module):
+ config: LLaMAConfig
+ dtype: jnp.dtype = jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+ precision: Optional[Union[jax.lax.Precision, str]]=None
+
+ def setup(self):
+ self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ precision=self.precision,
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ segment_ids=None,
+ position_ids=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ batch_size, seq_length = input_ids.shape
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if segment_ids is None:
+ segment_ids = jnp.zeros_like(input_ids)
+ if position_ids is None:
+ position_ids = jnp.broadcast_to(
+ jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
+ (batch_size, seq_length)
+ )
+ outputs = self.transformer(
+ input_ids,
+ attention_mask,
+ segment_ids,
+ position_ids,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+@add_start_docstrings("", "")
+class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
+ module_class = FlaxLLaMAForCausalLMModule
+
+ def prepare_inputs_for_generation(
+ self, input_ids, max_length,
+ attention_mask: Optional[jax.Array] = None,
+ ):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since GPTJ uses a causal mask, those positions are masked anyways.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {}
+
+
+class LLaMATokenizer(PreTrainedTokenizer):
+ """
+ Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding.
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token="",
+ bos_token="",
+ eos_token="",
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ add_bos_token=False,
+ add_eos_token=False,
+ **kwargs,
+ ):
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
+ self.vocab_file = vocab_file
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+
+ with tempfile.NamedTemporaryFile() as tfile:
+ with open_file(self.vocab_file, 'rb') as fin:
+ tfile.write(fin.read())
+ tfile.flush()
+ tfile.seek(0)
+ self.sp_model.Load(tfile.name)
+ """ Initialisation"""
+ self.add_special_tokens(dict(
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ ))
+ self.pad_token_id = self.unk_token_id
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_model.get_piece_size()
+
+ @property
+ def bos_token_id(self) -> Optional[int]:
+ return self.sp_model.bos_id()
+
+ @property
+ def eos_token_id(self) -> Optional[int]:
+ return self.sp_model.eos_id()
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text):
+ """Returns a tokenized string."""
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.piece_to_id(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = self.sp_model.IdToPiece(index)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ""
+ prev_is_special = False
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ if not prev_is_special:
+ out_string += " "
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string.strip()
+
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ if self.add_bos_token:
+ bos_token_ids = [self.bos_token_id]
+ else:
+ bos_token_ids = []
+
+ output = bos_token_ids + token_ids_0
+
+ if token_ids_1 is not None:
+ output = output + token_ids_1
+
+ if self.add_eos_token:
+ output = output + [self.eos_token_id]
+
+ return output
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+ use of token type ids, therefore a list of zeros is returned.
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ eos = [self.eos_token_id]
+
+ if token_ids_1 is None:
+ return len(token_ids_0 + eos) * [0]
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
diff --git a/lwm/ring_attention.py b/lwm/ring_attention.py
new file mode 100644
index 0000000..a5d38a1
--- /dev/null
+++ b/lwm/ring_attention.py
@@ -0,0 +1,1989 @@
+"""This module contains ring attention forward and backward pass, supporting both blockwise computation and TPU-compatible fused attention.
+It features blockwise computation for feedforward networks to reduce memory cost.
+For more details, refer to 'RingAttention' at https://arxiv.org/abs/2305.19370 and 'Blockwise Parallel Transformers' at https://arxiv.org/abs/2310.01889.
+"""
+
+import numpy as np
+import flax.linen as nn
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+from einops import rearrange
+from functools import partial
+import dataclasses
+import functools
+from typing import Any, NamedTuple
+
+
+def _ring_attention_fwd(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs):
+ if float32_logits:
+ q, k = q.astype(jnp.float32), k.astype(jnp.float32)
+ batch, q_len, num_heads, dim_per_head = q.shape
+ batch, kv_len, num_heads, dim_per_head = k.shape
+ numerator = jnp.zeros((batch, q_len, num_heads, dim_per_head)).astype(q.dtype)
+ denominator = jnp.zeros((batch, num_heads, q_len)).astype(q.dtype)
+ axis_size = lax.psum(1, axis_name)
+ q_block_size, kv_block_size = q_len, kv_len # assumes this function is pre-sharded inside shard_map
+ query_chunk_size = blockwise_kwargs["query_chunk_size"]
+ key_chunk_size = blockwise_kwargs["key_chunk_size"]
+
+ def scan_kv_block(carry, idx):
+ prev_max_score, numerator, denominator, k, v = carry
+ q_block_idx = lax.axis_index(axis_name)
+ k_block_idx = (lax.axis_index(axis_name) - idx) % axis_size
+ q_chunk_idx_start = q_block_idx * (q_block_size // query_chunk_size)
+ k_chunk_idx_start = k_block_idx * (kv_block_size // key_chunk_size)
+ numerator, denominator, max_score = _blockwise_attention_fwd(q, k, v, (numerator, denominator, prev_max_score), q_chunk_idx_start, k_chunk_idx_start, bias=attn_bias, segment_ids=segment_ids, **blockwise_kwargs)
+ k, v = map(lambda x: lax.ppermute(x, axis_name, perm=[(i, (i + 1) % axis_size) for i in range(axis_size)]), (k, v))
+ return (max_score, numerator, denominator, k, v), None
+ prev_max_score = jnp.full((batch, num_heads, q_len), -jnp.inf).astype(q.dtype)
+ (max_score, numerator, denominator, _, _), _ = lax.scan(scan_kv_block,
+ init=(prev_max_score, numerator, denominator, k, v), xs=jnp.arange(0, axis_size))
+ output = numerator / rearrange(denominator, 'b h q -> b q h')[..., None]
+ return output.astype(v.dtype), (output, q, k, v, attn_bias, segment_ids, denominator, max_score)
+
+def _ring_attention_bwd(axis_name, float32_logits, blockwise_kwargs, res, g):
+ del float32_logits
+ output, q, k, v, attn_bias, segment_ids, denominator, max_score = res
+ batch, q_len, num_heads, dim_per_head = q.shape
+ batch, kv_len, num_heads, dim_per_head = k.shape
+ axis_size = lax.psum(1, axis_name)
+ dq = jnp.zeros_like(q, dtype=q.dtype)
+ dk = jnp.zeros_like(k, dtype=k.dtype)
+ dv = jnp.zeros_like(v, dtype=k.dtype)
+ query_chunk_size = blockwise_kwargs["query_chunk_size"]
+ key_chunk_size = blockwise_kwargs["key_chunk_size"]
+ q_block_size, kv_block_size = q_len, kv_len # assumes this function is pre-sharded inside shard_map
+ def scan_kv_block(carry, idx):
+ dq, dk, dv, k, v = carry
+ q_block_idx = lax.axis_index(axis_name)
+ k_block_idx = (lax.axis_index(axis_name) - idx) % axis_size
+ q_chunk_idx_start = q_block_idx * (q_block_size // query_chunk_size)
+ k_chunk_idx_start = k_block_idx * (kv_block_size // key_chunk_size)
+ dq, dk, dv = _blockwise_attention_bwd(q, k, v, g, (dq, dk, dv, output, denominator, max_score), q_chunk_idx_start, k_chunk_idx_start, bias=attn_bias, segment_ids=segment_ids, **blockwise_kwargs)
+ k, v, dk, dv = map(lambda x: lax.ppermute(x, axis_name, perm=[(i,
+ (i + 1) % axis_size) for i in range(axis_size)]), (k, v, dk, dv))
+ return (dq, dk, dv, k, v), None
+ (dq, dk, dv, k, v), _ = lax.scan(scan_kv_block, init=(dq, dk, dv, k, v), xs=jnp.arange(0, axis_size))
+ dq, dk, dv = dq.astype(q.dtype), dk.astype(k.dtype), dv.astype(k.dtype)
+ return dq, dk, dv, None, None
+
+@partial(jax.custom_vjp, nondiff_argnums=[5, 6, 7])
+def ring_attention(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs):
+ y, _ = _ring_attention_fwd(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs)
+ return y
+
+ring_attention.defvjp(_ring_attention_fwd, _ring_attention_bwd)
+
+
+def _ring_attention_standard_fwd(q, k, v, attn_mask, axis_name, float32_logits):
+ if float32_logits:
+ q, k = q.astype(jnp.float32), k.astype(jnp.float32)
+ batch, q_len, num_heads, _ = q.shape
+ batch, kv_len, num_heads, dim_per_head = k.shape
+ numerator = jnp.zeros((batch, q_len, num_heads, dim_per_head)).astype(q.dtype)
+ denominator = jnp.zeros((batch, num_heads, q_len)).astype(q.dtype)
+ axis_size = lax.psum(1, axis_name)
+ scale = jnp.sqrt(q.shape[-1])
+ def scan_kv_block(carry, idx):
+ prev_max_score, numerator, denominator, k, v = carry
+ mask = lax.dynamic_slice_in_dim(attn_mask,
+ (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1)
+ attn_weights = jnp.einsum("bqhd,bkhd->bhqk", q, k) / scale
+ attn_weights = jnp.where(mask, attn_weights, jnp.finfo(attn_weights.dtype).min)
+ max_score = jnp.maximum(prev_max_score, jnp.max(attn_weights, axis=-1))
+ exp_weights = jnp.exp(attn_weights - max_score[..., None])
+ correction = rearrange(jnp.exp(prev_max_score - max_score), 'b h q -> b q h')[..., None]
+ numerator = numerator * correction + jnp.einsum("bhqk,bkhd->bqhd", exp_weights, v)
+ denominator = denominator * jnp.exp(prev_max_score - max_score) + jnp.sum(exp_weights, axis=-1)
+ k, v = map(lambda x: lax.ppermute(x, axis_name, perm=[(i,
+ (i + 1) % axis_size) for i in range(axis_size)]), (k, v))
+ return (max_score, numerator, denominator, k, v), None
+ prev_max_score = jnp.full((batch, num_heads, q_len), -jnp.inf).astype(q.dtype)
+ (max_score, numerator, denominator, _, _), _ = lax.scan(scan_kv_block,
+ init=(prev_max_score, numerator, denominator, k, v), xs=jnp.arange(0, axis_size))
+ output = numerator / rearrange(denominator, 'b h q -> b q h')[..., None]
+ return output.astype(v.dtype), (output, q, k, v, attn_mask, numerator, denominator, max_score)
+
+def _ring_attention_standard_bwd(axis_name, float32_logits, res, g):
+ del float32_logits
+ axis_size = lax.psum(1, axis_name)
+ output, q, k, v, attn_mask, numerator, denominator, max_score = res
+ dq = jnp.zeros_like(q, dtype=jnp.float32)
+ dk = jnp.zeros_like(k, dtype=jnp.float32)
+ dv = jnp.zeros_like(v, dtype=jnp.float32)
+ batch, kv_len, num_heads, dim_per_head = k.shape
+ scale = jnp.sqrt(q.shape[-1])
+ def scan_kv_block(carry, idx):
+ dq, dk, dv, k, v = carry
+ mask = lax.dynamic_slice_in_dim(attn_mask,
+ (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1)
+ attn_weights = jnp.einsum("bqhd,bkhd->bhqk", q, k) / scale
+ attn_weights = jnp.where(mask, attn_weights, jnp.finfo(attn_weights.dtype).min)
+ exp_weights = jnp.exp(attn_weights - max_score[..., None]) / denominator[..., None]
+ ds = jnp.einsum("bqhd,bkhd->bhqk", g, v)
+ dl = (ds - jnp.einsum("bqhd,bqhd->bhq", g, output)[..., None]) * exp_weights
+ dq = dq + jnp.einsum("bhqk,bkhd->bqhd", dl, k) / scale
+ dk = dk + jnp.einsum("bqhd,bhqk->bkhd", q, dl) / scale
+ dv = dv + jnp.einsum("bhqk,bqhd->bkhd", exp_weights, g)
+ k, v, dk, dv = map(lambda x: lax.ppermute(x, axis_name, perm=[(i,
+ (i + 1) % axis_size) for i in range(axis_size)]), (k, v, dk, dv))
+ return (dq, dk, dv, k, v), None
+ (dq, dk, dv, k, v), _ = lax.scan(scan_kv_block, init=(dq, dk, dv, k, v), xs=jnp.arange(0, axis_size))
+ dq, dk, dv = dq.astype(q.dtype), dk.astype(k.dtype), dv.astype(v.dtype)
+ return dq, dk, dv, None
+
+@partial(jax.custom_vjp, nondiff_argnums=[4, 5])
+def ring_attention_standard(q, k, v, attn_mask, axis_name, float32_logits=True):
+ y, _ = _ring_attention_standard_fwd(q, k, v, attn_mask, axis_name, float32_logits)
+ return y
+
+ring_attention_standard.defvjp(_ring_attention_standard_fwd, _ring_attention_standard_bwd)
+
+
+def _blockwise_attention_fwd(q, k, v, carry, q_chunk_idx_start, k_chunk_idx_start, bias, segment_ids, causal, query_chunk_size,
+ key_chunk_size, deterministic, dropout_rng, attn_pdrop, dtype, policy, precision, prevent_cse):
+ batch, q_len, num_heads, dim_per_head = q.shape
+ batch, kv_len, num_heads, dim_per_head = k.shape
+ batch, kv_len, num_heads, dim_per_head = v.shape
+ num_q = q_len // query_chunk_size
+ num_kv = kv_len // key_chunk_size
+ q = q.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+ k = k.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+ v = v.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+ q, k, v = map(lambda x: jnp.moveaxis(x, 1, 0), (q, k, v))
+
+ numerator, denominator, max_score = carry
+ numerator = numerator.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+ numerator = jnp.moveaxis(numerator, 1, 0)
+ denominator = denominator.reshape((batch, num_heads, num_q, query_chunk_size))
+ max_score = max_score.reshape((batch, num_heads, num_q, query_chunk_size))
+ denominator, max_score = map(lambda x: rearrange(x, 'b h n c -> n b h c'), (denominator, max_score))
+
+ scale = jnp.sqrt(q.shape[-1])
+ if not deterministic and attn_pdrop > 0.0:
+ attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
+ attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
+ else:
+ attn_dropout = None
+ _chunk_bias_fn = partial(
+ _chunk_attention_bias,
+ query_chunk_size, key_chunk_size, bias, segment_ids, deterministic,
+ attn_dropout, attn_pdrop, causal, dtype)
+ def scan_attention(_, scan):
+ q_chunk, numerator_chunk, denominator_chunk, max_score_chunk, q_chunk_idx = scan
+ @partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
+ def scan_kv_block(carry, scan):
+ k_chunk, value_chunk, k_chunk_idx = scan
+ numerator_chunk, denominator_chunk, prev_max_score_chunk = carry
+ attn_weights = jnp.einsum('bqhd,bkhd->bhqk', q_chunk, k_chunk, precision=precision) / scale
+ bias_chunk = _chunk_bias_fn(q_chunk_idx_start + q_chunk_idx, k_chunk_idx_start + k_chunk_idx)
+ attn_weights = attn_weights + bias_chunk
+
+ max_score_chunk = jnp.maximum(prev_max_score_chunk, jnp.max(attn_weights, axis=-1))
+ max_score_chunk = lax.stop_gradient(max_score_chunk)
+ exp_weights = jnp.exp(attn_weights - max_score_chunk[..., None])
+ exp_values = jnp.einsum('bhqk,bkhd->bqhd', exp_weights, value_chunk, precision=precision)
+ correction = rearrange(jnp.exp(prev_max_score_chunk - max_score_chunk), 'b h q -> b q h')[..., None]
+ numerator_chunk = numerator_chunk * correction + exp_values
+ denominator_chunk = denominator_chunk * jnp.exp(prev_max_score_chunk - max_score_chunk) + exp_weights.sum(axis=-1)
+ return (numerator_chunk, denominator_chunk, max_score_chunk), None
+
+ def skip_upper_half(carry, args):
+ key_chunk, value_chunk, k_chunk_idx = args
+ skip_block = jnp.array(False)
+ if causal:
+ skip_block = q_chunk_idx_start + q_chunk_idx < k_chunk_idx_start + k_chunk_idx
+ return jax.lax.cond(
+ skip_block,
+ lambda carry, args: (carry, None),
+ scan_kv_block,
+ carry,
+ args
+ )
+
+ (numerator_chunk, denominator_chunk, max_score_chunk), _ = lax.scan(
+ skip_upper_half, init=(numerator_chunk, denominator_chunk, max_score_chunk), xs=(k, v, jnp.arange(0, num_kv))
+ )
+ output_chunk = numerator_chunk / rearrange(denominator_chunk, 'b h q -> b q h')[..., None].astype(dtype)
+ return (), (output_chunk, numerator_chunk, denominator_chunk, max_score_chunk)
+ _, (_, numerator, denominator, max_score) = lax.scan(scan_attention, init=(), xs=(q, numerator, denominator, max_score, jnp.arange(0, num_q)))
+
+ numerator = jnp.moveaxis(numerator, 1, 0)
+ numerator = numerator.reshape((batch, q_len, num_heads, dim_per_head))
+ denominator, max_score = map(lambda x: rearrange(x, 'n b h c -> b h n c'), (denominator, max_score))
+ denominator = denominator.reshape((batch, num_heads, q_len))
+ max_score = max_score.reshape((batch, num_heads, q_len))
+
+ return numerator, denominator, max_score
+
+def _blockwise_attention_bwd(q, k, v, g, carry, q_chunk_idx_start, k_chunk_idx_start, bias, segment_ids, causal, query_chunk_size, key_chunk_size, deterministic, dropout_rng, attn_pdrop, dtype, policy, precision, prevent_cse):
+ batch, q_len, num_heads, dim_per_head = q.shape
+ batch, kv_len, num_heads, dim_per_head = k.shape
+ batch, kv_len, num_heads, dim_per_head = v.shape
+ num_q = q_len // query_chunk_size
+ num_kv = kv_len // key_chunk_size
+ dq, dk, dv, output, denominator, max_score = carry
+
+ g = g.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+ dq = dq.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+ dk = dk.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+ dv = dv.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+ output = output.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+ g, dq, dk, dv, output = map(lambda x: jnp.moveaxis(x, 1, 0), (g, dq, dk, dv, output))
+
+ denominator = denominator.reshape((batch, num_heads, num_q, query_chunk_size))
+ max_score = max_score.reshape((batch, num_heads, num_q, query_chunk_size))
+ denominator, max_score = map(lambda x: rearrange(x, 'b h n c -> n b h c'), (denominator, max_score))
+
+ q = q.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+ k = k.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+ v = v.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+ q, k, v = map(lambda x: jnp.moveaxis(x, 1, 0), (q, k, v))
+
+ scale = jnp.sqrt(q.shape[-1])
+ if not deterministic and attn_pdrop > 0.0:
+ attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
+ attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
+ else:
+ attn_dropout = None
+ _chunk_bias_fn = partial(
+ _chunk_attention_bias,
+ query_chunk_size, key_chunk_size, bias, segment_ids, deterministic,
+ attn_dropout, attn_pdrop, causal, dtype)
+ def scan_attention(carry, scan):
+ dk, dv = carry
+ q_chunk, dq_chunk, g_chunk, output_chunk, denominator_chunk, max_score_chunk, q_chunk_idx = scan
+ dl_part = jnp.einsum("bqhd,bqhd->bhq", g_chunk, output_chunk)[..., None]
+ @partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
+ def scan_kv_block(carry, scan):
+ k_chunk, value_chunk, k_chunk_idx = scan
+ dq_chunk = carry
+ attn_weights = jnp.einsum('bqhd,bkhd->bhqk', q_chunk, k_chunk, precision=precision) / scale
+ bias_chunk = _chunk_bias_fn(q_chunk_idx_start + q_chunk_idx, k_chunk_idx_start + k_chunk_idx)
+ attn_weights = attn_weights + bias_chunk
+ exp_weights = jnp.exp(attn_weights - max_score_chunk[..., None]) / denominator_chunk[..., None]
+
+ ds = jnp.einsum("bqhd,bkhd->bhqk", g_chunk, value_chunk)
+ dl = (ds - dl_part) * exp_weights
+ dq_chunk = dq_chunk + jnp.einsum("bhqk,bkhd->bqhd", dl, k_chunk) / scale
+ dk_chunk = jnp.einsum("bqhd,bhqk->bkhd", q_chunk, dl) / scale
+ dv_chunk = jnp.einsum("bhqk,bqhd->bkhd", exp_weights, g_chunk)
+ return dq_chunk, (dk_chunk, dv_chunk)
+
+ def skip_upper_half(carry, args):
+ key_chunk, value_chunk, k_chunk_idx = args
+ skip_block = jnp.array(False)
+ if causal:
+ skip_block = q_chunk_idx_start + q_chunk_idx < k_chunk_idx_start + k_chunk_idx
+ return lax.cond(
+ skip_block,
+ lambda carry, args: (
+ carry, (
+ jnp.zeros((batch, key_chunk_size, num_heads, dim_per_head), dtype=dk.dtype),
+ jnp.zeros((batch, key_chunk_size, num_heads, dim_per_head), dtype=dk.dtype),
+ )
+ ),
+ scan_kv_block,
+ carry,
+ args
+ )
+
+ dq_chunk, (dk_part, dv_part) = lax.scan(
+ skip_upper_half, init=dq_chunk, xs=(k, v, jnp.arange(0, num_kv))
+ )
+ return (dk + dk_part, dv + dv_part), dq_chunk
+ (dk, dv), dq = lax.scan(scan_attention, init=(dk, dv), xs=(q, dq, g, output, denominator, max_score, jnp.arange(0, num_q)))
+
+ dq, dk, dv = map(lambda x: jnp.moveaxis(x, 1, 0), (dq, dk, dv))
+ dq = dq.reshape((batch, q_len, num_heads, dim_per_head))
+ dk = dk.reshape((batch, kv_len, num_heads, dim_per_head))
+ dv = dv.reshape((batch, kv_len, num_heads, dim_per_head))
+
+ return dq, dk, dv
+
+
+# Blockwise feedforward network for memory-efficient training
+def blockwise_ffn(remat_ffn, inputs, chunk_size, deterministic):
+ inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
+ def scan_ffn(remat_ffn, carry, hidden_states):
+ outputs = remat_ffn(hidden_states, deterministic=deterministic)
+ return carry, outputs
+ scan_axis = inputs.ndim - 2
+ _, output = nn.scan(
+ scan_ffn,
+ variable_broadcast="params",
+ split_rngs={"params": False, "dropout": True},
+ in_axes=scan_axis,
+ out_axes=scan_axis,
+ )(remat_ffn, None, inputs)
+ output = rearrange(output, 'b c n d -> b (c n) d')
+ return output
+
+
+def _chunk_attention_bias(query_chunk_size, key_chunk_size,
+ bias, segment_ids, deterministic, attn_dropout, attn_pdrop, causal,
+ dtype, query_chunk_idx, key_chunk_idx):
+ query_offset = query_chunk_idx * query_chunk_size
+ key_offset = key_chunk_idx * key_chunk_size
+ chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
+ if bias is not None:
+ chunk_bias = lax.dynamic_slice(
+ bias,
+ start_indices=(0, 0, 0, key_offset),
+ slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
+ )
+
+ if segment_ids is not None:
+ q_segment_ids = lax.dynamic_slice(
+ segment_ids,
+ start_indices=(0, query_offset),
+ slice_sizes=(segment_ids.shape[0], query_chunk_size)
+ )
+ k_segment_ids = lax.dynamic_slice(
+ segment_ids,
+ start_indices=(0, key_offset),
+ slice_sizes=(segment_ids.shape[0], key_chunk_size)
+ )
+ segment_ids_mask = q_segment_ids[:, :, None] != k_segment_ids[:, None, :]
+ segment_ids_mask = segment_ids_mask[:, None] # B1QK
+ segment_ids_bias = segment_ids_mask * jnp.finfo(dtype).min
+ chunk_bias += segment_ids_bias
+
+ if causal:
+ query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
+ key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
+ offset = query_offset - key_offset
+ query_idx += offset
+ causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
+ chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
+
+ if not deterministic and attn_pdrop > 0.0:
+ attn_dropout_slice = lax.dynamic_slice(
+ attn_dropout,
+ start_indices=(0, 0, query_offset, key_offset),
+ slice_sizes=(
+ *attn_dropout.shape[:2],
+ min(attn_dropout.shape[-2], query_chunk_size),
+ min(attn_dropout.shape[-1], key_chunk_size),
+ ),
+ )
+ chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
+ return chunk_bias.astype(dtype)
+
+
+def _ring_flash_attention_fwd_tpu(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs):
+ if float32_logits:
+ q, k = q.astype(jnp.float32), k.astype(jnp.float32)
+ q, k, v = map(lambda x: rearrange(x, 'b q h d -> b h q d'), [q, k, v])
+ batch, num_heads, q_len, dim_per_head = q.shape
+ batch, num_heads, kv_len, dim_per_head = k.shape
+ attn_bias = attn_bias[:, 0, 0] # (batch, k_len)
+
+ o = jnp.zeros((batch, num_heads, q_len, dim_per_head)).astype(q.dtype)
+ l = jnp.zeros((batch, num_heads, q_len)).astype(q.dtype)
+ m = jnp.full((batch, num_heads, q_len), -jnp.inf).astype(q.dtype)
+
+ axis_size = lax.psum(1, axis_name)
+ q_block_size, kv_block_size = q_len, kv_len # assumes this function is pre-sharded inside shard_map
+ query_chunk_size = blockwise_kwargs["query_chunk_size"]
+ key_chunk_size = blockwise_kwargs["key_chunk_size"]
+ if segment_ids is not None:
+ q_segment_ids = lax.dynamic_slice_in_dim(
+ segment_ids,
+ lax.axis_index(axis_name) * q_len, q_len, axis=-1
+ )
+
+ block_sizes = BlockSizes(
+ block_q=query_chunk_size,
+ block_k_major=key_chunk_size,
+ block_k=key_chunk_size,
+ block_b=1,
+ block_q_major_dkv=query_chunk_size,
+ block_k_major_dkv=key_chunk_size,
+ block_k_dkv=key_chunk_size,
+ block_q_dkv=query_chunk_size,
+ block_k_major_dq=key_chunk_size,
+ block_k_dq=key_chunk_size,
+ block_q_dq=query_chunk_size,
+ )
+
+ scale = q.shape[-1] ** -0.5
+ def scan_kv_block(carry, idx):
+ o, l, m, k, v = carry
+ attn_bias_slice = lax.dynamic_slice_in_dim(attn_bias,
+ (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1
+ )
+ attn_bias_slice = None # TODO
+ if segment_ids is not None:
+ kv_segment_ids = lax.dynamic_slice_in_dim(
+ segment_ids,
+ (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1
+ )
+ segment_ids_slice = SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+ else:
+ segment_ids_slice = None
+ q_block_idx = lax.axis_index(axis_name)
+ k_block_idx = (lax.axis_index(axis_name) - idx) % axis_size
+ q_chunk_idx_start = q_block_idx * (q_block_size // query_chunk_size)
+ k_chunk_idx_start = k_block_idx * (kv_block_size // key_chunk_size)
+ o, l, m = _flash_attention_fwd(
+ q, k, v,
+ carry=(o, l, m),
+ q_chunk_idx_start=q_chunk_idx_start,
+ k_chunk_idx_start=k_chunk_idx_start,
+ ab=attn_bias_slice,
+ segment_ids=segment_ids_slice,
+ save_residuals=False,
+ causal=blockwise_kwargs["causal"],
+ sm_scale=scale,
+ block_sizes=block_sizes,
+ debug=False
+ )
+ k, v = map(lambda x: lax.ppermute(x, axis_name, perm=[(i, (i + 1) % axis_size) for i in range(axis_size)]), (k, v))
+ return (o, l, m, k, v), None
+ (o, l, m, _, _), _ = lax.scan(scan_kv_block,
+ init=(o, l, m, k, v), xs=jnp.arange(0, axis_size))
+ output = rearrange(o.astype(v.dtype), 'b h q d -> b q h d')
+ return output, (o, q, k, v, attn_bias, segment_ids, l, m)
+
+def _ring_flash_attention_bwd_tpu(axis_name, float32_logits, blockwise_kwargs, res, g):
+ del float32_logits
+ o, q, k, v, attn_bias, segment_ids, l, m = res
+ batch, num_heads, kv_len, dim_per_head = k.shape
+ axis_size = lax.psum(1, axis_name)
+ dq = jnp.zeros_like(q, dtype=jnp.float32)
+ dk = jnp.zeros_like(k, dtype=jnp.float32)
+ dv = jnp.zeros_like(v, dtype=jnp.float32)
+ query_chunk_size = blockwise_kwargs["query_chunk_size"]
+ key_chunk_size = blockwise_kwargs["key_chunk_size"]
+ q_block_size, kv_block_size = q.shape[2], k.shape[2] # assumes this function is pre-sharded inside shard_map
+ scale = q.shape[-1] ** -0.5
+
+ if segment_ids is not None:
+ q_segment_ids = lax.dynamic_slice_in_dim(
+ segment_ids,
+ lax.axis_index(axis_name) * q_block_size, q_block_size, axis=-1
+ )
+
+ g = rearrange(g, 'b q h d -> b h q d')
+
+ block_sizes = BlockSizes(
+ block_q=query_chunk_size,
+ block_k_major=key_chunk_size,
+ block_k=key_chunk_size,
+ block_b=1,
+ block_q_major_dkv=query_chunk_size,
+ block_k_major_dkv=key_chunk_size,
+ block_k_dkv=key_chunk_size,
+ block_q_dkv=query_chunk_size,
+ block_k_major_dq=key_chunk_size,
+ block_k_dq=key_chunk_size,
+ block_q_dq=query_chunk_size,
+ )
+
+ def scan_kv_block(carry, idx):
+ dq, dk, dv, k, v = carry
+ attn_bias_slice = lax.dynamic_slice_in_dim(attn_bias,
+ (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1
+ )
+ attn_bias_slice = None # TODO
+ if segment_ids is not None:
+ kv_segment_ids = lax.dynamic_slice_in_dim(
+ segment_ids,
+ (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1
+ )
+ segment_ids_slice = SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+ else:
+ segment_ids_slice = None
+ q_block_idx = lax.axis_index(axis_name)
+ k_block_idx = (lax.axis_index(axis_name) - idx) % axis_size
+ q_chunk_idx_start = q_block_idx * (q_block_size // query_chunk_size)
+ k_chunk_idx_start = k_block_idx * (kv_block_size // key_chunk_size)
+ dq_i, dk_i, dv_i, = _flash_attention_bwd(
+ save_residuals=False,
+ causal=blockwise_kwargs["causal"],
+ sm_scale=scale,
+ block_sizes=block_sizes,
+ debug=False,
+ q_chunk_idx_start=q_chunk_idx_start,
+ k_chunk_idx_start=k_chunk_idx_start,
+ residuals=(q, k, v, attn_bias_slice, segment_ids_slice, o, l, m),
+ do=g
+ )
+ dq += dq_i
+ dk += dk_i
+ dv += dv_i
+ k, v, dk, dv = map(lambda x: lax.ppermute(x, axis_name, perm=[(i,
+ (i + 1) % axis_size) for i in range(axis_size)]), (k, v, dk, dv))
+ return (dq, dk, dv, k, v), None
+ (dq, dk, dv, k, v), _ = lax.scan(scan_kv_block, init=(dq, dk, dv, k, v), xs=jnp.arange(0, axis_size))
+ dq, dk, dv = dq.astype(q.dtype), dk.astype(k.dtype), dv.astype(v.dtype)
+ dq, dk, dv = map(lambda x: rearrange(x, 'b h q d -> b q h d'), (dq, dk, dv))
+ return dq, dk, dv, None, None
+
+@partial(jax.custom_vjp, nondiff_argnums=[5, 6, 7])
+def ring_flash_attention_tpu(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs):
+ y, _ = _ring_flash_attention_fwd_tpu(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs)
+ return y
+
+ring_flash_attention_tpu.defvjp(_ring_flash_attention_fwd_tpu, _ring_flash_attention_bwd_tpu)
+
+# TPU-compatible fused attention functions for RingAttention
+DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
+NUM_LANES = 128
+NUM_SUBLANES = 8
+
+class SegmentIds(NamedTuple):
+ """SegmentIds for Q and KV sequences.
+
+ SegmentIds are used to generate segment mask, which prevents attention between
+ different segments in the input sequence. Each array is a list of ids
+ (integers).
+ Only the token with the same id can attend to each other.
+
+ Attributes:
+ q: segment ids along the Q sequence.
+ kv: segment ids along the KV sequence.
+ """
+
+ q: jax.Array # [q_seq_len]
+ kv: jax.Array # [kv_seq_len]
+
+
+@dataclasses.dataclass(frozen=True)
+class BlockSizes:
+ block_q: int
+ block_k_major: int
+ block_k: int
+ block_b: int
+
+ block_q_major_dkv: int | None = None
+ block_k_major_dkv: int | None = None
+ block_k_dkv: int | None = None
+ block_q_dkv: int | None = None
+
+ block_k_major_dq: int | None = None
+ block_k_dq: int | None = None
+ block_q_dq: int | None = None
+
+ def __post_init__(self):
+ def verify_major_minor(prefix, suffix, major, minor):
+ if minor > major:
+ raise ValueError(
+ f"{prefix}{suffix}={minor} should be smaller than"
+ f" {prefix}_major{suffix}={major}"
+ )
+ if major % minor != 0:
+ raise ValueError(
+ f"{prefix}{suffix}={minor} should divide"
+ f" {prefix}_major{suffix}={major}"
+ )
+
+ verify_major_minor("block_k", "", self.block_k_major, self.block_k)
+ if self.block_q_major_dkv is not None and self.block_q_dkv is not None:
+ verify_major_minor(
+ "block_q", "_dkv", self.block_q_major_dkv, self.block_q_dkv
+ )
+ if self.block_k_major_dkv is not None and self.block_k_dkv is not None:
+ verify_major_minor(
+ "block_k", "_dkv", self.block_k_major_dkv, self.block_k_dkv
+ )
+ if self.block_k_major_dq is not None and self.block_k_dq is not None:
+ verify_major_minor("block_k", "_dq", self.block_k_major_dq, self.block_k_dq)
+
+ @property
+ def has_backward_blocks(self) -> bool:
+ backward_blocks = (
+ self.block_q_major_dkv,
+ self.block_k_major_dkv,
+ self.block_q_dkv,
+ self.block_k_dkv,
+ self.block_k_major_dq,
+ self.block_k_dq,
+ self.block_q_dq,
+ )
+ return all(b is not None for b in backward_blocks)
+
+ @classmethod
+ def get_default(cls, batch_size, num_heads, q_seq_len, kv_len, d_model):
+ del batch_size, num_heads, q_seq_len, kv_len, d_model # Unused.
+ return BlockSizes(
+ block_q=128,
+ block_k_major=128,
+ block_k=128,
+ block_b=1,
+ block_q_major_dkv=128,
+ block_k_major_dkv=128,
+ block_k_dkv=128,
+ block_q_dkv=128,
+ block_k_major_dq=128,
+ block_k_dq=128,
+ block_q_dq=128,
+ )
+
+
+def _flash_attention(
+ q,
+ k,
+ v,
+ carry,
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ ab,
+ segment_ids,
+ save_residuals,
+ causal,
+ sm_scale,
+ block_sizes,
+ debug,
+):
+ return _flash_attention_impl(
+ q,
+ k,
+ v,
+ carry,
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ ab,
+ segment_ids,
+ save_residuals,
+ causal,
+ sm_scale,
+ block_sizes.block_b,
+ block_sizes.block_q,
+ block_sizes.block_k_major,
+ block_sizes.block_k,
+ debug,
+ )
+
+
+def _flash_attention_fwd(
+ q,
+ k,
+ v,
+ carry,
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ ab,
+ segment_ids,
+ save_residuals,
+ causal,
+ sm_scale,
+ block_sizes,
+ debug,
+):
+ if save_residuals:
+ raise NotImplementedError("Higher-order AD not supported")
+ o, l, m = _flash_attention(
+ q,
+ k,
+ v,
+ carry,
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ ab,
+ segment_ids,
+ True,
+ causal,
+ sm_scale,
+ block_sizes,
+ debug,
+ )
+ return o, l, m
+
+
+def _flash_attention_bwd(
+ save_residuals: bool,
+ causal: bool,
+ sm_scale: float,
+ block_sizes: BlockSizes,
+ debug: bool,
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ residuals,
+ do,
+):
+ """VJP rule for FlashAttention."""
+ if save_residuals:
+ raise NotImplementedError("Higher-order AD not supported")
+ (q, k, v, ab, segment_ids, o, l, m) = residuals
+ if not block_sizes.has_backward_blocks:
+ raise ValueError(
+ "Program is being differentiated, but not all backward blocks are"
+ " specified"
+ )
+
+ di = jnp.sum(
+ o.astype(jnp.float32) * do.astype(jnp.float32), axis=-1
+ ) # [batch_size, num_heads, q_seq_len]
+
+ dk, dv = _flash_attention_bwd_dkv(
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ q,
+ k,
+ v,
+ ab,
+ segment_ids,
+ l,
+ m,
+ do,
+ di,
+ block_q_major=block_sizes.block_q_major_dkv,
+ block_k_major=block_sizes.block_k_major_dkv,
+ block_k=block_sizes.block_k_dkv,
+ block_q=block_sizes.block_q_dkv,
+ sm_scale=sm_scale,
+ causal=causal,
+ mask_value=DEFAULT_MASK_VALUE,
+ debug=debug,
+ )
+
+ dq, ds = _flash_attention_bwd_dq(
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ q,
+ k,
+ v,
+ ab,
+ segment_ids,
+ l,
+ m,
+ do,
+ di,
+ block_q_major=block_sizes.block_q_dq,
+ block_k_major=block_sizes.block_k_major_dq,
+ block_k=block_sizes.block_k_dq,
+ sm_scale=sm_scale,
+ causal=causal,
+ mask_value=DEFAULT_MASK_VALUE,
+ debug=debug,
+ )
+ return dq, dk, dv
+
+
+MIN_BLOCK_SIZE = 128
+TRANS_B_DIM_NUMBERS = (((1,), (1,)), ((), ()))
+
+
+def below_or_on_diag(r, r_blk_size, c, c_blk_size):
+ # A block is considered below or on diagonal as long as the bottom left
+ # corner of the block is below or on diagonal.
+ return ((r + 1) * r_blk_size - 1) > (c * c_blk_size)
+
+
+def _flash_attention_kernel(
+ q_idx_chunk_start, k_idx_chunk_start, q_tile_ref, *args, **kwargs
+):
+ block_b = q_tile_ref.shape[0]
+ # If we're not going to tile the softmax, then we can avoid a bunch of VPU ops.
+ if kwargs["block_k"] == kwargs["kv_seq_len"]:
+ assert False
+ kernel = _flash_attention_kernel_single_batch_single_step
+ else:
+ kernel = _flash_attention_kernel_single_batch
+ for batch_idx in range(block_b):
+ kernel(
+ (batch_idx, 0),
+ q_idx_chunk_start,
+ k_idx_chunk_start,
+ q_tile_ref,
+ *args,
+ **kwargs,
+ )
+
+
+def _flash_attention_kernel_single_batch(
+ batch_idx: tuple[int, ...],
+ q_chunk_idx_start_ref,
+ k_chunk_idx_start_ref,
+ q_tile_ref,
+ k_tile_ref,
+ v_tile_ref,
+ acc_tile_ref,
+ l_tile_ref,
+ m_tile_ref,
+ ab_tile_ref,
+ q_segment_ids_tile_ref,
+ kv_segment_ids_tile_ref, # Input arrays
+ o_tile_ref, # Output arrays
+ m_scratch_ref,
+ l_scratch_ref,
+ acc_scratch_ref,
+ l_ref: Any | None = None,
+ m_ref: Any | None = None,
+ *,
+ causal,
+ sm_scale,
+ block_k,
+ kv_seq_len,
+ mask_value,
+ block_q,
+):
+ block_k_major = k_tile_ref.shape[2]
+ block_q = q_tile_ref.shape[2]
+ head_dim = q_tile_ref.shape[-1]
+
+ kv_seq_idx = pl.program_id(3)
+
+ @pl.when(kv_seq_idx == 0)
+ def start_new_sequence():
+ m_scratch_ref[batch_idx] = m_tile_ref[batch_idx]
+ l_scratch_ref[batch_idx] = l_tile_ref[batch_idx]
+ acc_scratch_ref[batch_idx] = acc_tile_ref[batch_idx]
+
+ q_chunk_idx_start = q_chunk_idx_start_ref[0]
+ k_chunk_idx_start = k_chunk_idx_start_ref[0]
+
+ q_seq_idx = pl.program_id(2)
+ if causal:
+ should_run = below_or_on_diag(
+ q_seq_idx + q_chunk_idx_start,
+ block_q,
+ kv_seq_idx + k_chunk_idx_start,
+ block_k_major,
+ )
+ else:
+ should_run = True
+
+ @pl.when(should_run)
+ def run():
+ @functools.partial(
+ lax.fori_loop, 0, block_k_major // block_k, init_val=None, unroll=True
+ )
+ def body(i, _):
+ m_prev = m_scratch_ref[batch_idx]
+ l_prev = l_scratch_ref[batch_idx]
+ q = q_tile_ref[batch_idx] # [block_q, head_dim]
+ start_k = i * block_k
+ k = pl.load(
+ k_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None))
+ ) # [block_k, head_dim]
+
+ s = jax.lax.dot_general(
+ q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32
+ ) # [block_q, block_k]
+
+ # Add attention bias if needed.
+ if ab_tile_ref is not None:
+ ab = pl.load(
+ ab_tile_ref,
+ (batch_idx[0], pl.dslice(0, block_q), pl.dslice(start_k, block_k)),
+ ).astype(jnp.float32)
+ s += ab
+
+ if sm_scale != 1.0:
+ s *= sm_scale
+
+ mask = None
+ if q_segment_ids_tile_ref is not None:
+ repeats, rem = divmod(block_k, NUM_LANES)
+ if rem:
+ raise NotImplementedError(
+ f"kv block size must be a multiple of {NUM_LANES}"
+ )
+ q_segment_ids = pltpu.repeat(
+ q_segment_ids_tile_ref[batch_idx[0]], repeats, axis=1
+ ) # [block_q, block_k].
+ kv_segment_ids = pl.load(
+ kv_segment_ids_tile_ref,
+ (batch_idx[0], pl.dslice(1), pl.dslice(start_k, block_k)),
+ ) # [1, block_k].
+ mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)
+
+ if causal:
+ mask_shape = (block_q, block_k)
+ row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+ row_ids += (q_seq_idx + q_chunk_idx_start) * block_q
+ col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+ col_ids += (kv_seq_idx + k_chunk_idx_start) * block_k_major + start_k
+ causal_mask = col_ids <= row_ids
+ mask = (
+ causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
+ )
+
+ s = s if mask is None else s + jnp.where(mask, 0.0, mask_value)
+
+ m_curr = jnp.max(s, axis=1)[:, None] # Row max, shape [block_q, 1].
+ m_next = jnp.maximum(m_prev, m_curr) # Shape [block_q, 128].
+
+ block_k_repeats, rem = divmod(block_k, MIN_BLOCK_SIZE)
+ if rem:
+ raise NotImplementedError(
+ f"{block_k=} should be a multiple of {MIN_BLOCK_SIZE}"
+ )
+ p = jnp.exp(s - pltpu.repeat(m_next, block_k_repeats, 1))
+
+ alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128].
+
+ l_corr = alpha * l_prev
+
+ l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128]
+
+ head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE)
+ l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1)
+ if rem:
+ if head_dim_repeats == 0:
+ l_broadcast = lambda l: l[:, :head_dim]
+ else:
+ raise NotImplementedError(
+ f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger"
+ )
+ l_scratch_ref[batch_idx] = l_next
+ m_scratch_ref[batch_idx] = m_next
+
+ l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
+ acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe)
+ v = pl.load(
+ v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None))
+ )
+ o_curr = jax.lax.dot(
+ p.astype(v.dtype), v, preferred_element_type=jnp.float32
+ )
+ acc_scratch_ref[batch_idx] += o_curr * l_broadcast(l_next_inv_safe)
+
+ @pl.when(kv_seq_idx == (kv_seq_len // block_k_major) - 1)
+ def store_output():
+ o_tile_ref[batch_idx] = acc_scratch_ref[batch_idx].astype(o_tile_ref.dtype)
+ if l_ref is not None:
+ l_ref[batch_idx] = l_scratch_ref[batch_idx].astype(l_ref.dtype)
+ if m_ref is not None:
+ m_ref[batch_idx] = m_scratch_ref[batch_idx].astype(m_ref.dtype)
+
+
+def _flash_attention_impl(
+ q,
+ k,
+ v,
+ carry,
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ ab,
+ segment_ids,
+ save_residuals,
+ causal,
+ sm_scale,
+ block_b,
+ block_q,
+ block_k_major,
+ block_k,
+ debug,
+):
+ assert block_k_major == block_k, (block_k_major, block_k)
+ batch_size, num_heads, q_seq_len, head_dim = q.shape
+ _, _, kv_seq_len, _ = k.shape
+ acc, l_prev, m_prev = carry
+ l_prev, m_prev = map(
+ lambda x: jnp.broadcast_to(x[..., None], (*x.shape, MIN_BLOCK_SIZE)),
+ (l_prev, m_prev),
+ )
+ q_chunk_idx_start, k_chunk_idx_start = (
+ q_chunk_idx_start[None],
+ k_chunk_idx_start[None],
+ )
+ _verify_block("block_q", "q_seq_len", block_q, q_seq_len, should_divide=False)
+ _verify_block("block_k_major", "kv_seq_len", block_k_major, kv_seq_len)
+ _verify_block("block_k", "kv_seq_len", block_k, kv_seq_len)
+ _verify_block("block_b", "batch", block_b, batch_size, should_divide=False)
+
+ grid = (
+ pl.cdiv(batch_size, block_b),
+ num_heads,
+ pl.cdiv(q_seq_len, block_q),
+ kv_seq_len // block_k_major,
+ )
+
+ def q_index_map(batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref):
+ return (batch_index, head_index, q_seq_index, 0)
+
+ def kv_index_map(
+ batch_index, head_index, q_seq_index, kv_seq_index, q_idx_ref, k_idx_ref
+ ):
+ if causal:
+ # If the kv block is skipped, prefetch the next valid kv block, i.e. the
+ # 0th one to be used for the next block_q rows.
+ next_kv_index = lax.select(
+ below_or_on_diag(
+ q_seq_index + q_idx_ref[0],
+ block_q,
+ kv_seq_index + k_idx_ref[0],
+ block_k_major,
+ ),
+ kv_seq_index,
+ 0,
+ )
+ else:
+ next_kv_index = kv_seq_index
+ return (batch_index, head_index, next_kv_index, 0)
+
+ def ab_index_map(
+ batch_index, head_index, q_seq_index, kv_seq_index, q_idx_ref, k_idx_ref
+ ):
+ if causal:
+ should_run = below_or_on_diag(
+ q_seq_index + q_idx_ref[0],
+ block_q,
+ kv_seq_index + k_idx_ref[0],
+ block_k_major,
+ )
+ next_kv_index = lax.select(should_run, kv_seq_index, 0)
+ else:
+ next_kv_index = kv_seq_index
+
+ return (batch_index, 0, next_kv_index)
+
+ def o_index_map(batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref):
+ return (batch_index, head_index, q_seq_index, 0)
+
+ def lm_index_map(batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref):
+ return (batch_index, head_index, q_seq_index, 0)
+
+ kernel = functools.partial(
+ _flash_attention_kernel,
+ causal=causal,
+ mask_value=DEFAULT_MASK_VALUE,
+ sm_scale=sm_scale,
+ block_k=block_k,
+ kv_seq_len=kv_seq_len,
+ block_q=block_q,
+ )
+ out_shape = [jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)]
+ out_specs = [pl.BlockSpec(o_index_map, (block_b, 1, block_q, head_dim))]
+
+ if block_k != kv_seq_len:
+ scratch_shape = functools.partial(jax.ShapeDtypeStruct, dtype=jnp.float32)
+ m_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE))
+ l_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE))
+ acc_scratch = scratch_shape((block_b, 1, block_q, head_dim))
+ out_shape += [m_scratch, l_scratch, acc_scratch]
+ out_specs += [
+ pl.BlockSpec(lambda *_: (0, 0, 0, 0), m_scratch.shape),
+ pl.BlockSpec(lambda *_: (0, 0, 0, 0), l_scratch.shape),
+ pl.BlockSpec(lambda *_: (0, 0, 0, 0), acc_scratch.shape),
+ ]
+ else:
+ assert False
+ out_shape += [None, None, None]
+ out_specs += [None, None, None]
+
+ if save_residuals:
+ out_specs = [
+ *out_specs,
+ pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)),
+ pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)),
+ ]
+ l = jax.ShapeDtypeStruct(
+ (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32
+ )
+ m = jax.ShapeDtypeStruct(
+ (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32
+ )
+ out_shape = (*out_shape, l, m)
+
+ ab_block_spec = (
+ pl.BlockSpec(ab_index_map, (block_b, block_q, block_k_major))
+ if ab is not None
+ else None
+ )
+
+ if ab is not None:
+ ab = ab[:, None].repeat(block_q, axis=1)
+
+ q_segment_ids_spec = kv_segment_ids_spec = None
+ q_segment_ids = kv_segment_ids = None
+ if segment_ids is not None:
+
+ def q_segment_ids_index_map(
+ batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref
+ ):
+ del head_index
+ return (batch_index, q_seq_index, 0)
+
+ def kv_segment_ids_index_map(
+ batch_index, head_index, q_seq_index, kv_seq_index, q_idx_ref, k_idx_ref
+ ):
+ del head_index
+ if causal:
+ next_kv_index = lax.select(
+ below_or_on_diag(
+ q_seq_index + q_idx_ref[0],
+ block_q,
+ kv_seq_index + k_idx_ref[0],
+ block_k_major,
+ ),
+ kv_seq_index,
+ 0,
+ )
+ else:
+ next_kv_index = kv_seq_index
+ return (batch_index, 0, next_kv_index)
+
+ q_segment_ids_spec = pl.BlockSpec(
+ q_segment_ids_index_map, (block_b, block_q, NUM_LANES)
+ )
+ kv_segment_ids_spec = pl.BlockSpec(
+ kv_segment_ids_index_map, (block_b, NUM_SUBLANES, block_k_major)
+ )
+
+ q_segment_ids = jax.lax.broadcast_in_dim(
+ segment_ids.q,
+ (batch_size, q_seq_len, NUM_LANES),
+ (
+ 0,
+ 1,
+ ),
+ )
+ kv_segment_ids = jax.lax.broadcast_in_dim(
+ segment_ids.kv,
+ (batch_size, NUM_SUBLANES, kv_seq_len),
+ (
+ 0,
+ 2,
+ ),
+ )
+
+ in_specs = [
+ pl.BlockSpec(q_index_map, (block_b, 1, block_q, head_dim)),
+ pl.BlockSpec(kv_index_map, (block_b, 1, block_k_major, head_dim)),
+ pl.BlockSpec(kv_index_map, (block_b, 1, block_k_major, head_dim)),
+ pl.BlockSpec(q_index_map, (block_b, 1, block_q, head_dim)),
+ pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)),
+ pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)),
+ ab_block_spec,
+ q_segment_ids_spec,
+ kv_segment_ids_spec,
+ ]
+
+ o, *aux = pl.pallas_call(
+ kernel,
+ out_shape=out_shape,
+ grid_spec=pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=2, in_specs=in_specs, out_specs=out_specs, grid=grid
+ ),
+ debug=debug,
+ mosaic_params=dict(
+ dimension_semantics=("parallel", "parallel", "parallel", "arbitrary")
+ ),
+ )(
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ q,
+ k,
+ v,
+ acc,
+ l_prev,
+ m_prev,
+ ab,
+ q_segment_ids,
+ kv_segment_ids,
+ )
+ if save_residuals:
+ l, m = (v[..., 0] for v in aux[-2:])
+ return (o, l, m)
+ else:
+ return o
+
+
+def _flash_attention_dkv_kernel(
+ q_chunk_idx_start_ref,
+ k_chunk_idx_start_ref,
+ q_tile_ref,
+ k_tile_ref,
+ v_tile_ref,
+ ab_tile_ref,
+ q_segment_ids_tile_ref,
+ kv_segment_ids_tile_ref,
+ l_tile_ref,
+ m_tile_ref,
+ do_tile_ref,
+ di_tile_ref,
+ dk_tile_ref,
+ dv_tile_ref,
+ dk_scratch_ref,
+ dv_scratch_ref,
+ *,
+ sm_scale: float,
+ causal: bool,
+ mask_value: float,
+ q_seq_len: int,
+ block_q: int,
+ block_k: int,
+):
+ _, _, block_q_major, _ = q_tile_ref.shape
+ _, _, block_k_major, _ = k_tile_ref.shape
+
+ q_seq_index = pl.program_id(axis=3)
+ kv_seq_index = pl.program_id(axis=2)
+
+ q_chunk_idx_start = q_chunk_idx_start_ref[0]
+ k_chunk_idx_start = k_chunk_idx_start_ref[0]
+
+ @pl.when(q_seq_index == 0)
+ def start_new_sequence():
+ dk_scratch_ref[:, :] = jnp.zeros(dk_scratch_ref.shape, dk_scratch_ref.dtype)
+ dv_scratch_ref[:, :] = jnp.zeros(dv_scratch_ref.shape, dv_scratch_ref.dtype)
+
+ def q_body(j, _):
+ start_q = j * block_q
+
+ def k_body(i, _):
+ start_k = i * block_k
+ k = pl.load(k_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None)))
+ v = pl.load(v_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None)))
+ q = pl.load(
+ q_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
+ ) # [block_q, head_dim]
+ l = pl.load(
+ l_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
+ ) # [block_q, 128]
+ m = pl.load(
+ m_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
+ ) # [block_q, 128]
+ do = pl.load(
+ do_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
+ ) # [block_q, 128]
+ di = pl.load(
+ di_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
+ ).astype(
+ jnp.float32
+ ) # [block_q, 128]
+
+ capped_logits = lax.dot_general(
+ q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32
+ ) # [block_q_major, block_k]
+
+ if ab_tile_ref is not None:
+ ab = pl.load(
+ ab_tile_ref,
+ (
+ 0,
+ pl.dslice(0, block_q),
+ pl.dslice(i * block_k, block_k),
+ ),
+ ).astype(jnp.float32)
+ capped_logits += ab
+
+ if sm_scale != 1.0:
+ capped_logits *= sm_scale
+
+ mask = None
+ if q_segment_ids_tile_ref is not None:
+ repeats, rem = divmod(block_k, NUM_LANES)
+ if rem:
+ raise NotImplementedError()
+ q_segment_ids = pl.load(
+ q_segment_ids_tile_ref, (0, pl.ds(start_q, block_q), slice(None))
+ ) # [block_q, NUM_LANES].
+ q_segment_ids = pltpu.repeat(
+ q_segment_ids, repeats, axis=1
+ ) # [block_q, block_k].
+ kv_segment_ids = pl.load(
+ kv_segment_ids_tile_ref, (slice(None), 0, pl.ds(start_k, block_k))
+ ) # [1, block_k].
+ mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)
+
+ if causal:
+ mask_shape = (block_q, block_k)
+ row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+ row_ids += (q_seq_index + q_chunk_idx_start) * block_q_major + start_q
+ col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+ col_ids += (kv_seq_index + k_chunk_idx_start) * block_k_major + start_k
+ causal_mask = col_ids <= row_ids
+ mask = (
+ causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
+ )
+
+ capped_logits = (
+ capped_logits
+ if mask is None
+ else capped_logits + jnp.where(mask, 0.0, mask_value)
+ )
+
+ p = jnp.exp(
+ capped_logits - pltpu.repeat(m, block_k // MIN_BLOCK_SIZE, axis=1)
+ )
+ p = p * pltpu.repeat(
+ 1 / l, block_k // MIN_BLOCK_SIZE, axis=1
+ ) # [block_q_major, block_k_major]
+ dv = lax.dot(p.T.astype(do.dtype), do, preferred_element_type=jnp.float32)
+ pl.store(
+ dv_scratch_ref,
+ (pl.ds(start_k, block_k), slice(None)),
+ pl.load(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None)))
+ + dv.astype(dv_scratch_ref.dtype),
+ )
+
+ # di: [block_q, 128]
+ # do: [block_q, head_dim]
+ # v: [block_k_major, head_dim]
+ dp = lax.dot_general(
+ do, v, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32
+ )
+ ds = (dp - pltpu.repeat(di, block_k // MIN_BLOCK_SIZE, axis=1)) * p
+
+ if sm_scale != 1.0:
+ ds = ds * sm_scale
+
+ # ds: [block_q_major, block_k_major]
+ # q: [block_q_major, head_dim]
+ dk = lax.dot(ds.T.astype(do.dtype), q, preferred_element_type=jnp.float32)
+ pl.store(
+ dk_scratch_ref,
+ (pl.ds(start_k, block_k), slice(None)),
+ pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)))
+ + dk.astype(dk_scratch_ref.dtype),
+ )
+
+ lax.fori_loop(0, block_k_major // block_k, k_body, None, unroll=True)
+
+ if causal:
+ should_run = below_or_on_diag(
+ q_seq_index + q_chunk_idx_start,
+ block_q_major,
+ kv_seq_index + k_chunk_idx_start,
+ block_k_major,
+ )
+ else:
+ should_run = True
+
+ @pl.when(should_run)
+ def run():
+ lax.fori_loop(0, block_q_major // block_q, q_body, None, unroll=True)
+
+ @pl.when(q_seq_index == q_seq_len // block_q_major - 1)
+ def end_of_q_sequence():
+ dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref)
+ dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref)
+
+
+def _flash_attention_bwd_dkv(
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ q,
+ k,
+ v,
+ ab,
+ segment_ids,
+ l,
+ m,
+ do,
+ di,
+ *,
+ block_q_major: int | None,
+ block_q: int | None,
+ block_k_major: int | None,
+ block_k: int | None,
+ sm_scale: float,
+ causal: bool = False,
+ mask_value: float = DEFAULT_MASK_VALUE,
+ debug: bool = False,
+):
+ batch_size, num_heads, q_seq_len, head_dim = q.shape
+ _, _, kv_seq_len, _ = k.shape
+ q_chunk_idx_start, k_chunk_idx_start = (
+ q_chunk_idx_start[None],
+ k_chunk_idx_start[None],
+ )
+ _verify_block("block_q_major_dkv", "q_seq_len", block_q_major, q_seq_len)
+ _verify_block("block_q_dkv", "q_seq_len", block_q, q_seq_len)
+ _verify_block("block_k_major_dkv", "kv_seq_len", block_k_major, kv_seq_len)
+ _verify_block("block_k_dkv", "kv_seq_len", block_k, kv_seq_len)
+
+ # Broadcast out scalar values
+ m = jnp.broadcast_to(m[..., None], (*m.shape, MIN_BLOCK_SIZE))
+ l = jnp.broadcast_to(l[..., None], (*l.shape, MIN_BLOCK_SIZE))
+ # Preprocess contraction for bwd pass
+ di = jnp.broadcast_to(di[..., None], (*di.shape, MIN_BLOCK_SIZE))
+
+ # kv index needs to be before q index since q index is the contractng
+ # dimension.
+ grid = (
+ batch_size,
+ num_heads,
+ kv_seq_len // block_k_major,
+ q_seq_len // block_q_major,
+ )
+
+ def qo_index_map(
+ batch_index, head_index, kv_seq_index, q_seq_index, q_idx_ref, k_idx_ref
+ ):
+ if causal:
+ # If the q block is skipped, stay at the 0th q block.
+ next_q_index = lax.select(
+ below_or_on_diag(
+ q_seq_index + q_idx_ref[0],
+ block_q_major,
+ kv_seq_index + k_idx_ref[0],
+ block_k_major,
+ ),
+ q_seq_index,
+ 0,
+ )
+ else:
+ next_q_index = q_seq_index
+
+ return (batch_index, head_index, next_q_index, 0)
+
+ qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
+ assert qo_spec.block_shape is not None
+ assert q.ndim == len(qo_spec.block_shape)
+ do_spec = qo_spec
+ assert do.ndim == len(qo_spec.block_shape)
+
+ def kv_index_map(batch_index, head_index, kv_seq_index, _, q_idx_ref, k_idx_ref):
+ return (batch_index, head_index, kv_seq_index, 0)
+
+ kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
+ assert kv_spec.block_shape is not None
+ assert k.ndim == len(kv_spec.block_shape)
+ assert v.ndim == len(kv_spec.block_shape)
+
+ def lm_index_map(batch_index, head_index, _, q_seq_index, q_idx_ref, k_idx_ref):
+ return (batch_index, head_index, q_seq_index, 0)
+
+ lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+ assert lm_spec.block_shape is not None
+ assert l.ndim == len(lm_spec.block_shape)
+ assert m.ndim == len(lm_spec.block_shape)
+
+ di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+ assert di_spec.block_shape is not None
+ assert di.ndim == len(di_spec.block_shape)
+
+ def ab_index_map(
+ batch_index, head_index, kv_seq_index, q_seq_index, q_idx_ref, k_idx_ref
+ ):
+ return (batch_index, 0, kv_seq_index)
+
+ if ab is not None:
+ ab = ab[:, None].repeat(block_q_major, axis=1)
+
+ dab_spec = (
+ pl.BlockSpec(ab_index_map, (1, block_q_major, block_k_major))
+ if ab is not None
+ else None
+ )
+
+ q_segment_ids_spec = kv_segment_ids_spec = None
+ q_segment_ids = kv_segment_ids = None
+ if segment_ids is not None:
+
+ def q_segment_ids_index_map(
+ batch_index, head_index, kv_seq_index, q_seq_index, q_idx_ref, k_idx_ref
+ ):
+ del head_index
+ if causal:
+ next_q_index = lax.select(
+ below_or_on_diag(
+ q_seq_index + q_idx_ref[0],
+ block_q_major,
+ kv_seq_index + k_idx_ref[0],
+ block_k_major,
+ ),
+ q_seq_index,
+ 0,
+ )
+ else:
+ next_q_index = q_seq_index
+ return (batch_index, next_q_index, 0)
+
+ def kv_segment_ids_index_map(
+ batch_index, head_index, kv_seq_index, _, q_idx_ref, k_idx_ref
+ ):
+ del head_index
+ return (batch_index, 0, kv_seq_index)
+
+ q_segment_ids_spec = pl.BlockSpec(
+ q_segment_ids_index_map, (1, block_q_major, NUM_LANES)
+ )
+ kv_segment_ids_spec = pl.BlockSpec(
+ kv_segment_ids_index_map, (1, NUM_SUBLANES, block_k_major)
+ )
+
+ q_segment_ids = jax.lax.broadcast_in_dim(
+ segment_ids.q,
+ (batch_size, q_seq_len, NUM_LANES),
+ (
+ 0,
+ 1,
+ ),
+ )
+ kv_segment_ids = jax.lax.broadcast_in_dim(
+ segment_ids.kv,
+ (batch_size, NUM_SUBLANES, kv_seq_len),
+ (
+ 0,
+ 2,
+ ),
+ )
+
+ in_specs = [
+ qo_spec,
+ kv_spec,
+ kv_spec,
+ dab_spec,
+ q_segment_ids_spec,
+ kv_segment_ids_spec,
+ lm_spec,
+ lm_spec,
+ do_spec,
+ di_spec,
+ ]
+
+ out_shapes = [
+ jax.ShapeDtypeStruct((batch_size, num_heads, kv_seq_len, head_dim), k.dtype),
+ jax.ShapeDtypeStruct((batch_size, num_heads, kv_seq_len, head_dim), v.dtype),
+ jax.ShapeDtypeStruct((block_k_major, head_dim), jnp.float32),
+ jax.ShapeDtypeStruct((block_k_major, head_dim), jnp.float32),
+ ]
+
+ def dkv_index_map(batch_index, head_index, kv_seq_index, _, q_idx_ref, k_idx_ref):
+ return (batch_index, head_index, kv_seq_index, 0)
+
+ dkv_spec = pl.BlockSpec(dkv_index_map, (1, 1, block_k_major, head_dim))
+ out_specs = [
+ dkv_spec,
+ dkv_spec,
+ pl.BlockSpec(lambda *_: (0, 0), (block_k_major, head_dim)),
+ pl.BlockSpec(lambda *_: (0, 0), (block_k_major, head_dim)),
+ ]
+
+ kernel = functools.partial(
+ _flash_attention_dkv_kernel,
+ block_q=block_q,
+ block_k=block_k,
+ sm_scale=sm_scale,
+ causal=causal,
+ mask_value=mask_value,
+ q_seq_len=q_seq_len,
+ )
+ name_scope = (
+ f"flash_mha_bwd_dkv_{block_q_major=}_{block_q=}_{block_k_major=}_{block_k=}"
+ )
+ with jax.named_scope(name_scope):
+ dk, dv, _, _ = pl.pallas_call(
+ kernel,
+ out_shape=out_shapes,
+ grid_spec=pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=2, in_specs=in_specs, out_specs=out_specs, grid=grid
+ ),
+ debug=debug,
+ mosaic_params=dict(
+ dimension_semantics=(
+ "parallel",
+ "parallel",
+ "parallel",
+ "arbitrary",
+ )
+ ),
+ )(
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ q,
+ k,
+ v,
+ ab,
+ q_segment_ids,
+ kv_segment_ids,
+ l,
+ m,
+ do,
+ di,
+ )
+ assert dk.shape == k.shape
+ assert dv.shape == v.shape
+ return dk, dv
+
+
+def _flash_attention_dq_kernel(
+ q_chunk_idx_start_ref,
+ k_chunk_idx_start_ref,
+ q_tile_ref,
+ k_tile_ref,
+ v_tile_ref,
+ ab_tile_ref,
+ q_segment_ids_tile_ref,
+ kv_segment_ids_tile_ref,
+ l_tile_ref,
+ m_tile_ref,
+ do_tile_ref,
+ di_tile_ref,
+ dq_tile_ref,
+ dq_scratch_ref,
+ ds_tile_ref,
+ *,
+ sm_scale: float,
+ causal: bool,
+ mask_value: float,
+ kv_seq_len: int,
+ block_k: int,
+):
+ _, _, block_k_major, _ = k_tile_ref.shape
+ _, _, block_q_major, _ = q_tile_ref.shape
+
+ kv_seq_index = pl.program_id(axis=3)
+ q_seq_index = pl.program_id(axis=2)
+
+ q_chunk_idx_start = q_chunk_idx_start_ref[0]
+ k_chunk_idx_start = k_chunk_idx_start_ref[0]
+
+ @pl.when(kv_seq_index == 0)
+ def start_new_sequence():
+ dq_scratch_ref[:, :] = jnp.zeros(dq_scratch_ref.shape, dq_scratch_ref.dtype)
+
+ def body(i, _):
+ k_slice = pl.ds(i * block_k, block_k)
+ q = q_tile_ref[0, 0, :, :]
+ k = pl.load(
+ k_tile_ref,
+ (0, 0, k_slice, slice(None)),
+ ) # [block_k, head_dim]
+ v = pl.load(
+ v_tile_ref,
+ (0, 0, k_slice, slice(None)),
+ ) # [block_k, head_dim]
+ l = l_tile_ref[0, 0, :, :] # [block_q_major, 128]
+ m = m_tile_ref[0, 0, :, :] # [block_q_major, 128]
+ do = do_tile_ref[0, 0, :, :] # [block_q_major, head_dim]
+ di = di_tile_ref[0, 0, :].astype(jnp.float32) # [block_q_major, 128]
+
+ capped_logits = jax.lax.dot_general(
+ q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32
+ )
+
+ if ab_tile_ref is not None:
+ ab = pl.load(
+ ab_tile_ref,
+ (0, pl.dslice(0, block_q_major), pl.dslice(i * block_k, block_k)),
+ ).astype(jnp.float32)
+ capped_logits += ab
+
+ if sm_scale != 1.0:
+ capped_logits *= sm_scale
+
+ mask = None
+ if q_segment_ids_tile_ref is not None:
+ repeats, rem = divmod(block_k, NUM_LANES)
+ if rem:
+ raise NotImplementedError(
+ f"kv block size must be a multiple of {NUM_LANES}"
+ )
+ q_segment_ids = pltpu.repeat(
+ q_segment_ids_tile_ref[0], repeats, axis=1
+ ) # [block_q, block_k].
+ kv_segment_ids = pl.load(
+ kv_segment_ids_tile_ref, (slice(None), 0, k_slice)
+ ) # [1, block_k].
+ mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)
+
+ if causal:
+ mask_shape = (block_q_major, block_k)
+ row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+ row_ids += (q_seq_index + q_chunk_idx_start) * block_q_major
+ col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+ col_ids += (kv_seq_index + k_chunk_idx_start) * block_k_major + i * block_k
+ causal_mask = col_ids <= row_ids
+ mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
+ capped_logits = (
+ capped_logits
+ if mask is None
+ else capped_logits + jnp.where(mask, 0.0, mask_value)
+ )
+
+ p = jnp.exp(capped_logits - pltpu.repeat(m, block_k // MIN_BLOCK_SIZE, axis=1))
+ p = p * pltpu.repeat(
+ 1 / l, block_k // MIN_BLOCK_SIZE, axis=1
+ ) # [block_q_major, block_k]
+
+ # di: [block_q_major, 128]
+ # do: [block_q_major, head_dim]
+ # v: [block_k_major, head_dim]
+ dp = jax.lax.dot_general(
+ do,
+ v,
+ TRANS_B_DIM_NUMBERS,
+ preferred_element_type=jnp.float32,
+ )
+ ds = (dp - pltpu.repeat(di, block_k // MIN_BLOCK_SIZE, axis=1)) * p
+
+ if sm_scale != 1.0:
+ ds = ds * sm_scale
+
+ if ds_tile_ref is not None:
+ pl.store(
+ ds_tile_ref,
+ (0, pl.dslice(None), pl.dslice(i * block_k, block_k)),
+ ds.astype(ds_tile_ref.dtype),
+ )
+
+ # dp: [block_q_major, block_k]
+ # k: [block_k, head_dim]
+ dq_scratch_ref[:, :] += lax.dot(
+ ds.astype(k.dtype),
+ k,
+ preferred_element_type=jnp.float32,
+ ).astype(dq_scratch_ref.dtype)
+
+ if causal:
+ should_run = below_or_on_diag(
+ q_seq_index + q_chunk_idx_start,
+ block_q_major,
+ kv_seq_index + k_chunk_idx_start,
+ block_k_major,
+ )
+ should_not_run = lax.select(should_run, False, True)
+ else:
+ should_run = True
+ should_not_run = False # type: ignore
+
+ @pl.when(should_run)
+ def run():
+ lax.fori_loop(0, block_k_major // block_k, body, None, unroll=True)
+
+ @pl.when(should_not_run)
+ def zero_out_ds():
+ if ds_tile_ref is not None:
+ ds_tile_ref[...] = jnp.zeros_like(ds_tile_ref)
+
+ @pl.when(kv_seq_index == kv_seq_len // block_k_major - 1)
+ def end_of_kv_sequence():
+ dq_tile_ref[0, 0, :, :] = dq_scratch_ref[...].astype(dq_tile_ref)
+ dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)
+
+
+def _flash_attention_bwd_dq(
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ q,
+ k,
+ v,
+ ab,
+ segment_ids,
+ l,
+ m,
+ do,
+ di,
+ *,
+ block_q_major: int | None,
+ block_k_major: int | None,
+ block_k: int | None,
+ sm_scale: float,
+ causal: bool,
+ mask_value: float,
+ debug: bool,
+):
+ batch_size, num_heads, q_seq_len, head_dim = q.shape
+ _, _, kv_seq_len, _ = k.shape
+ q_chunk_idx_start, k_chunk_idx_start = (
+ q_chunk_idx_start[None],
+ k_chunk_idx_start[None],
+ )
+ _verify_block("block_q_dq", "q_seq_len", block_q_major, q_seq_len)
+ _verify_block("block_k_major_dq", "kv_seq_len", block_k_major, kv_seq_len)
+ _verify_block("block_k_dq", "block_k", block_k, kv_seq_len)
+
+ # Broadcast out scalar values
+ m = jnp.broadcast_to(m[..., None], (*m.shape, MIN_BLOCK_SIZE))
+ l = jnp.broadcast_to(l[..., None], (*l.shape, MIN_BLOCK_SIZE))
+ # Preprocess contraction for bwd pass
+ di = jnp.broadcast_to(di[..., None], (*di.shape, block_k_major))
+
+ grid = (
+ batch_size,
+ num_heads,
+ q_seq_len // block_q_major,
+ kv_seq_len // block_k_major,
+ )
+
+ def qo_index_map(batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref):
+ return (batch_index, head_index, q_seq_index, 0)
+
+ qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
+ do_spec = qo_spec
+
+ def kv_index_map(
+ batch_index, head_index, q_seq_index, kv_seq_index, q_idx_ref, k_idx_ref
+ ):
+ if causal:
+ # If the kv block is skipped, prefetch the next valid kv block, i.e. the
+ # 0th one to be used for the next block_q rows.
+ next_kv_index = lax.select(
+ below_or_on_diag(
+ q_seq_index + q_idx_ref[0],
+ block_q_major,
+ kv_seq_index + k_idx_ref[0],
+ block_k_major,
+ ),
+ kv_seq_index,
+ 0,
+ )
+ else:
+ next_kv_index = kv_seq_index
+ return (batch_index, head_index, next_kv_index, 0)
+
+ kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
+ assert kv_spec.block_shape is not None
+ assert k.ndim == len(kv_spec.block_shape)
+ assert v.ndim == len(kv_spec.block_shape)
+
+ def lm_index_map(batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref):
+ return (batch_index, head_index, q_seq_index, 0)
+
+ lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+ assert lm_spec.block_shape is not None
+ assert l.ndim == len(lm_spec.block_shape)
+ assert m.ndim == len(lm_spec.block_shape)
+
+ di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+ assert di_spec.block_shape is not None
+ assert di.ndim == len(di_spec.block_shape)
+
+ def ab_index_map(
+ batch_index, head_index, q_seq_index, kv_seq_index, q_idx_ref, k_idx_ref
+ ):
+ return (batch_index, 0, kv_seq_index)
+
+ if ab is not None:
+ ab = ab[:, None].repeat(block_q_major, axis=1)
+
+ dab_spec = (
+ pl.BlockSpec(ab_index_map, (1, block_q_major, block_k_major))
+ if ab is not None
+ else None
+ )
+
+ q_segment_ids_spec = kv_segment_ids_spec = None
+ q_segment_ids = kv_segment_ids = None
+ if segment_ids is not None:
+
+ def q_segment_ids_index_map(
+ batch_index, head_index, q_seq_index, _, q_idx_ref, k_idx_ref
+ ):
+ del head_index
+ return (batch_index, q_seq_index, 0)
+
+ def kv_segment_ids_index_map(
+ batch_index, head_index, q_seq_index, kv_seq_index, q_idx_ref, k_idx_ref
+ ):
+ del head_index
+ if causal:
+ # If the kv block is skipped, prefetch the next valid kv block, i.e. the
+ # 0th one to be used for the next block_q rows.
+ next_kv_index = lax.select(
+ below_or_on_diag(
+ q_seq_index + q_idx_ref[0],
+ block_q_major,
+ kv_seq_index + k_idx_ref[0],
+ block_k_major,
+ ),
+ kv_seq_index,
+ 0,
+ )
+ else:
+ next_kv_index = kv_seq_index
+ return (batch_index, 0, next_kv_index)
+
+ q_segment_ids_spec = pl.BlockSpec(
+ q_segment_ids_index_map, (1, block_q_major, NUM_LANES)
+ )
+ kv_segment_ids_spec = pl.BlockSpec(
+ kv_segment_ids_index_map, (1, NUM_SUBLANES, block_k_major)
+ )
+
+ q_segment_ids = jax.lax.broadcast_in_dim(
+ segment_ids.q,
+ (batch_size, q_seq_len, NUM_LANES),
+ (
+ 0,
+ 1,
+ ),
+ )
+ kv_segment_ids = jax.lax.broadcast_in_dim(
+ segment_ids.kv,
+ (batch_size, NUM_SUBLANES, kv_seq_len),
+ (
+ 0,
+ 2,
+ ),
+ )
+
+ in_specs = [
+ qo_spec,
+ kv_spec,
+ kv_spec,
+ dab_spec,
+ q_segment_ids_spec,
+ kv_segment_ids_spec,
+ lm_spec,
+ lm_spec,
+ do_spec,
+ di_spec,
+ ]
+
+ out_shapes = [
+ jax.ShapeDtypeStruct(q.shape, q.dtype),
+ jax.ShapeDtypeStruct((block_q_major, head_dim), jnp.float32),
+ jax.ShapeDtypeStruct(ab.shape, ab.dtype) if ab is not None else None,
+ ]
+ dq_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
+ out_specs = [
+ dq_spec,
+ pl.BlockSpec(lambda *_: (0, 0), (block_q_major, head_dim)),
+ dab_spec,
+ ]
+
+ kernel = functools.partial(
+ _flash_attention_dq_kernel,
+ sm_scale=sm_scale,
+ causal=causal,
+ mask_value=mask_value,
+ block_k=block_k,
+ kv_seq_len=kv_seq_len,
+ )
+ name_scope = f"flash_mha_bwd_dq_{block_q_major=}_{block_k_major=}_{block_k=}"
+ with jax.named_scope(name_scope):
+ dq, _, ds = pl.pallas_call(
+ kernel,
+ out_shape=out_shapes,
+ grid_spec=pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=2, in_specs=in_specs, out_specs=out_specs, grid=grid
+ ),
+ debug=debug,
+ mosaic_params=dict(
+ dimension_semantics=(
+ "parallel",
+ "parallel",
+ "parallel",
+ "arbitrary",
+ )
+ ),
+ )(
+ q_chunk_idx_start,
+ k_chunk_idx_start,
+ q,
+ k,
+ v,
+ ab,
+ q_segment_ids,
+ kv_segment_ids,
+ l,
+ m,
+ do,
+ di,
+ )
+
+ return dq, ds
+
+
+def _verify_block(block_name, dim_name, block, dim, should_divide=True):
+ if block > dim:
+ raise ValueError(
+ f"{block_name}={block} should be smaller or equal to {dim_name}={dim}"
+ )
+ if should_divide and dim % block != 0:
+ raise ValueError(
+ f"{dim_name}={dim} should be divisible by {block_name}={block}"
+ )
diff --git a/lwm/train.py b/lwm/train.py
new file mode 100644
index 0000000..1e1fe28
--- /dev/null
+++ b/lwm/train.py
@@ -0,0 +1,396 @@
+import pprint
+import os
+from functools import partial
+
+from tqdm import tqdm, trange
+import numpy as np
+from absl.app import run
+import absl.logging as logging
+import tux
+
+import jax
+import jax.numpy as jnp
+from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec as PS
+from flax.training.train_state import TrainState
+
+from lwm.data import DatasetFactory
+from tux import (
+ JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
+ cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
+ set_random_seed, average_metrics, get_mask,
+ make_shard_and_gather_fns, with_sharding_constraint, define_flags_with_default,
+ OptimizerFactory, StreamingCheckpointer
+)
+from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLMModule
+from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLMModule
+
+
+FLAGS, FLAGS_DEF = define_flags_with_default(
+ modality='text',
+ use_data_sharded_loader=True,
+ seed=42,
+ mesh_dim='1,-1,1,1',
+ dtype='fp32',
+ total_steps=10000,
+ load_llama_config='',
+ update_llama_config='',
+ load_checkpoint='',
+ load_dataset_state='',
+ log_freq=50,
+ save_model_freq=0,
+ save_milestone_freq=0,
+ eval_steps=0,
+ tokenizer=VideoLLaMAConfig.get_tokenizer_config(),
+ train_dataset=DatasetFactory.get_default_config(),
+ eval_dataset=DatasetFactory.get_default_config(),
+ optimizer=OptimizerFactory.get_default_config(),
+ checkpointer=StreamingCheckpointer.get_default_config(),
+ llama=VideoLLaMAConfig.get_default_config(),
+ logger=tux.WandBLogger.get_default_config(),
+ log_all_worker=False,
+ jax_distributed=JaxDistributedConfig.get_default_config(),
+ autoresume=False,
+)
+
+
+def main(argv):
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
+ variant = tux.get_user_flags(FLAGS, FLAGS_DEF)
+ flags_config_dict = tux.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
+
+ logger = tux.WandBLogger(
+ config=FLAGS.logger,
+ variant=variant,
+ enable=FLAGS.log_all_worker or (jax.process_index() == 0),
+ )
+ set_random_seed(FLAGS.seed)
+
+ if jax.process_index() == 0:
+ output_dir = logger.output_dir
+ else:
+ output_dir = os.path.join(logger.output_dir, logger.experiment_id)
+
+ if FLAGS.modality == 'text':
+ config_cls = LLaMAConfig
+ llama_cls = FlaxLLaMAForCausalLMModule
+ elif FLAGS.modality == 'vision,text':
+ config_cls = VideoLLaMAConfig
+ llama_cls = FlaxVideoLLaMAForCausalLMModule
+ else:
+ raise ValueError(f"Unsupported modality: {FLAGS.modality}")
+
+ mesh = config_cls.get_jax_mesh(FLAGS.mesh_dim)
+ node_info = config_cls.get_ranks_and_size(mesh)
+
+ tokenizer = config_cls.get_tokenizer(FLAGS.tokenizer)
+ dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer, node_info=node_info)
+ if FLAGS.autoresume and tux.check_exists(output_dir):
+ logging.info('Found existing output. Resuming dataset from latest checkpoint...')
+ resume_path = f"{output_dir}/dataset.pkl"
+ dataset.load_state_dict(tux.load_pickle(resume_path))
+ elif FLAGS.load_dataset_state != '':
+ dataset.load_state_dict(tux.load_pickle(FLAGS.load_dataset_state))
+
+ if FLAGS.eval_steps > 0:
+ eval_dataset = DatasetFactory.load_dataset(
+ FLAGS.eval_dataset, dataset.tokenizer
+ )
+ eval_iterator = iter(eval_dataset)
+
+ seq_length = dataset.seq_length
+
+ if FLAGS.load_llama_config != '':
+ llama_config = config_cls.load_config(FLAGS.load_llama_config)
+ updates = config_cls(**FLAGS.llama)
+ llama_config.update(dict(
+ remat_block=updates.remat_block,
+ remat_attention=updates.remat_attention,
+ remat_mlp=updates.remat_mlp,
+ scan_attention=updates.scan_attention,
+ scan_mlp=updates.scan_mlp,
+ scan_query_chunk_size=updates.scan_query_chunk_size,
+ scan_key_chunk_size=updates.scan_key_chunk_size,
+ scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
+ scan_layers=updates.scan_layers,
+ param_scan_axis=updates.param_scan_axis,
+ ))
+ else:
+ llama_config = config_cls(**FLAGS.llama)
+
+ if FLAGS.update_llama_config != '':
+ llama_config.update(dict(eval(FLAGS.update_llama_config)))
+
+ llama_config.update(dict(
+ bos_token_id=dataset.tokenizer.bos_token_id,
+ eos_token_id=dataset.tokenizer.eos_token_id,
+ ))
+ if llama_config.vocab_size < dataset.vocab_size:
+ llama_config.update(dict(vocab_size=dataset.vocab_size))
+ llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
+
+ model = llama_cls(
+ llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
+ )
+
+ optimizer, optimizer_info = OptimizerFactory.get_optimizer(
+ FLAGS.optimizer,
+ get_mask(config_cls.get_weight_decay_exclusions()),
+ None,
+ )
+
+ def create_trainstate_from_params(params):
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
+
+ def init_fn(rng):
+ rng_generator = JaxRNG(rng)
+ batch = 512
+ if FLAGS.modality == 'text':
+ params = model.init(
+ input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
+ position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
+ attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32),
+ rngs=rng_generator(llama_config.rng_keys()),
+ )
+ elif FLAGS.modality == 'vision,text':
+ params = model.init(
+ input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
+ vision_masks=jnp.zeros((batch, seq_length), dtype=bool),
+ position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
+ attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32),
+ rngs=rng_generator(llama_config.rng_keys()),
+ )
+ else:
+ raise ValueError(f"Unsupported modality: {FLAGS.modality}")
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
+
+ def train_step(train_state, rng, batch):
+ rng_generator = JaxRNG(rng)
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
+ def loss_and_accuracy(params):
+ if FLAGS.modality == 'text':
+ logits = model.apply(
+ params,
+ batch['input_tokens'],
+ deterministic=False,
+ rngs=rng_generator(llama_config.rng_keys()),
+ ).logits
+ loss, acc = cross_entropy_loss_and_accuracy(
+ logits,
+ batch['target_tokens'],
+ batch['loss_masks']
+ )
+ metrics = dict(acc=acc)
+ return loss, metrics
+ elif FLAGS.modality == 'vision,text':
+ vision_logits, text_logits = model.apply(
+ params,
+ batch['input_tokens'],
+ batch['input_vision_masks'],
+ deterministic=False,
+ rngs=rng_generator(llama_config.rng_keys()),
+ ).logits
+ vision_loss, vision_acc = cross_entropy_loss_and_accuracy(
+ vision_logits,
+ jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0),
+ batch['loss_masks'] * batch['target_vision_masks']
+ )
+ text_loss, text_acc = cross_entropy_loss_and_accuracy(
+ text_logits,
+ jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']),
+ batch['loss_masks'] * (1.0 - batch['target_vision_masks'])
+ )
+ loss = 0.5 * (vision_loss + text_loss)
+
+ metrics = dict(
+ vision_loss=vision_loss,
+ vision_acc=vision_acc,
+ text_loss=text_loss,
+ text_acc=text_acc,
+ )
+ else:
+ raise ValueError(f"Unsupported modality: {FLAGS.modality}")
+ return loss, metrics
+ grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
+ (loss, loss_metrics), grads = grad_fn(train_state.params)
+ train_state = train_state.apply_gradients(grads=grads)
+ metrics = dict(
+ loss=loss,
+ learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
+ param_norm=global_norm(train_state.params),
+ gradient_norm=global_norm(grads),
+ **loss_metrics
+ )
+ return train_state, rng_generator(), metrics
+
+ def eval_step(train_state, rng, batch):
+ rng_generator = JaxRNG(rng)
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
+ if FLAGS.modality == 'text':
+ logits = model.apply(
+ train_state.params,
+ batch['input_tokens'],
+ deterministic=True,
+ rngs=rng_generator(llama_config.rng_keys()),
+ ).logits
+ loss, acc = cross_entropy_loss_and_accuracy(
+ logits,
+ batch['target_tokens'],
+ batch['loss_masks']
+ )
+ metrics = dict(
+ eval_loss=loss,
+ eval_acc=acc,
+ )
+ elif FLAGS.modality == 'vision,text':
+ vision_logits, text_logits = model.apply(
+ train_state.params,
+ batch['input_tokens'],
+ batch['input_vision_masks'],
+ deterministic=True,
+ rngs=rng_generator(llama_config.rng_keys()),
+ ).logits
+ vision_loss, vision_acc = cross_entropy_loss_and_accuracy(
+ vision_logits,
+ jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0),
+ batch['loss_masks'] * batch['target_vision_masks']
+ )
+ text_loss, text_acc = cross_entropy_loss_and_accuracy(
+ text_logits,
+ jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']),
+ batch['loss_masks'] * (1.0 - batch['target_vision_masks'])
+ )
+ loss = 0.5 * (vision_loss + text_loss)
+ metrics = dict(
+ eval_loss=loss,
+ eval_vision_accuracy=vision_acc,
+ eval_vision_loss=vision_loss,
+ eval_text_accuracy=text_acc,
+ eval_text_loss=text_loss,
+ )
+ return rng_generator(), metrics
+
+ train_state_shapes = jax.eval_shape(init_fn, next_rng())
+ train_state_partition = match_partition_rules(
+ config_cls.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), train_state_shapes
+ )
+
+ shard_fns, gather_fns = make_shard_and_gather_fns(
+ train_state_partition, train_state_shapes
+ )
+ checkpointer = StreamingCheckpointer(
+ FLAGS.checkpointer, logger.output_dir,
+ enable=jax.process_index() == 0,
+ )
+
+ sharded_init_fn = pjit(
+ init_fn,
+ in_shardings=PS(),
+ out_shardings=train_state_partition
+ )
+
+ sharded_create_trainstate_from_params = pjit(
+ create_trainstate_from_params,
+ in_shardings=(train_state_partition.params, ),
+ out_shardings=train_state_partition,
+ donate_argnums=(0, ),
+ )
+
+ if FLAGS.use_data_sharded_loader:
+ batch_spec = PS(('dp', 'fsdp'), 'sp')
+ else:
+ batch_spec = PS()
+ sharded_train_step = pjit(
+ train_step,
+ in_shardings=(train_state_partition, PS(), batch_spec),
+ out_shardings=(train_state_partition, PS(), PS()),
+ donate_argnums=(0, 1),
+ )
+
+ sharded_eval_step = pjit(
+ eval_step,
+ in_shardings=(train_state_partition, PS(), PS()),
+ out_shardings=(PS(), PS()),
+ donate_argnums=(1,),
+ )
+
+ def save_checkpoint(train_state, milestone=False):
+ step = int(jax.device_get(train_state.step))
+ metadata = dict(
+ step=step,
+ variant=variant,
+ flags=flags_config_dict,
+ llama_config=llama_config.to_dict(),
+ )
+ checkpointer.save_all(
+ train_state=train_state,
+ gather_fns=gather_fns,
+ metadata=metadata,
+ dataset=dataset.get_state_dict(),
+ milestone=milestone,
+ )
+
+ with mesh:
+ train_state, restored_params = None, None
+
+ if FLAGS.autoresume and tux.check_exists(output_dir):
+ logging.info('Found existing output. Resuming model from latest checkpoint...')
+ resume_path = f"trainstate::{output_dir}/streaming_train_state"
+ train_state, restored_params = checkpointer.load_trainstate_checkpoint(
+ resume_path, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30
+ )
+ elif FLAGS.load_checkpoint != '':
+ train_state, restored_params = checkpointer.load_trainstate_checkpoint(
+ FLAGS.load_checkpoint, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30
+ )
+
+ if train_state is None and restored_params is None:
+ # Initialize from scratch
+ train_state = sharded_init_fn(next_rng())
+ elif train_state is None and restored_params is not None:
+ # Restore from params but initialize train_state
+ train_state = sharded_create_trainstate_from_params(restored_params)
+ del restored_params
+
+ start_step = int(jax.device_get(train_state.step))
+
+ if FLAGS.save_model_freq > 0:
+ save_checkpoint(train_state)
+
+ sharded_rng = next_rng()
+
+ step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
+ for step, (batch, dataset_metrics) in zip(step_counter, dataset):
+ train_state, sharded_rng, metrics = sharded_train_step(
+ train_state, sharded_rng, batch
+ )
+ if step % FLAGS.log_freq == 0:
+ if FLAGS.eval_steps > 0:
+ eval_metric_list = []
+ for _ in range(FLAGS.eval_steps):
+ eval_batch, _ = next(eval_iterator)
+ sharded_rng, eval_metrics = sharded_eval_step(
+ train_state, sharded_rng, eval_batch
+ )
+ eval_metrics = jax.device_get(eval_metrics)
+ eval_metric_list.append(eval_metrics)
+ metrics.update(average_metrics(eval_metric_list))
+
+ log_metrics = {"step": step}
+ log_metrics.update(metrics)
+ log_metrics.update(dataset_metrics)
+ log_metrics = jax.device_get(log_metrics)
+ logger.log(log_metrics)
+ tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
+
+ if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
+ save_checkpoint(train_state, milestone=True)
+ elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
+ save_checkpoint(train_state)
+
+ if FLAGS.save_model_freq > 0:
+ save_checkpoint(train_state)
+
+
+if __name__ == "__main__":
+ run(main)
diff --git a/lwm/vision_chat.py b/lwm/vision_chat.py
new file mode 100644
index 0000000..f398d11
--- /dev/null
+++ b/lwm/vision_chat.py
@@ -0,0 +1,254 @@
+from absl.app import run
+import math
+from tqdm import tqdm
+from PIL import Image
+import decord
+from functools import cached_property
+import numpy as np
+import jax
+from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec as PS
+from transformers import GenerationConfig
+from tux import (
+ define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
+ set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
+ match_partition_rules, make_shard_and_gather_fns,
+ with_sharding_constraint, tree_apply, open_file
+)
+from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
+from lwm.vqgan import VQGAN
+
+
+FLAGS, FLAGS_DEF = define_flags_with_default(
+ prompt="",
+ input_file="",
+ vqgan_checkpoint="",
+ temperature=0.2,
+ max_n_frames=8,
+ seed=1234,
+ mesh_dim='1,-1,1,1',
+ dtype='fp32',
+ load_llama_config='',
+ update_llama_config='',
+ load_checkpoint='',
+ tokenizer=VideoLLaMAConfig.get_tokenizer_config(),
+ llama=VideoLLaMAConfig.get_default_config(),
+ jax_distributed=JaxDistributedConfig.get_default_config(),
+)
+
+
+class Sampler:
+ def __init__(self):
+ self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
+ self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
+ self.prefix_tokenizer = VideoLLaMAConfig.get_tokenizer(
+ FLAGS.tokenizer, truncation_side='left', padding_side='left'
+ )
+ self.tokenizer = VideoLLaMAConfig.get_tokenizer(FLAGS.tokenizer)
+ self.n_tokens_per_frame = 257
+ self.min_buffer_size = 256
+ self.sharded_rng = next_rng()
+ self._load_model()
+
+ @property
+ def block_size(self):
+ return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
+
+ @property
+ def data_dim(self):
+ return self.mesh.shape['dp'] * self.mesh.shape['fsdp']
+
+ def _process_frame(self, image, size):
+ width, height = image.size
+ if width < height:
+ new_width = size
+ new_height = int(size * height / width)
+ else:
+ new_height = size
+ new_width = int(size * width / height)
+ image = image.resize((new_width, new_height))
+
+ left = (new_width - size) / 2
+ top = (new_height - size) / 2
+ right = (new_width + size) / 2
+ bottom = (new_height + size) / 2
+ image = image.crop((left, top, right, bottom))
+ return np.array(image, dtype=np.float32) / 127.5 - 1
+
+ def _read_process_vision(self, path, max_n_frames):
+ f = open_file(path, 'rb')
+ if path.endswith('.png') or path.endswith('.jpg'):
+ image = Image.open(f)
+ vision = self._process_frame(image, 256)[None]
+ else:
+ vr = decord.VideoReader(f, ctx=decord.cpu(0))
+ duration = len(vr)
+ if duration <= max_n_frames:
+ frame_id_list = list(range(duration))
+ else:
+ frame_id_list = np.linspace(0, duration - 1, max_n_frames, dtype=int).tolist()
+ video = vr.get_batch(frame_id_list).asnumpy()
+ vision = np.stack([self._process_frame(Image.fromarray(frame), 256) for frame in video])
+
+ B = 1
+ encodings = []
+ for i in range(0, len(vision), 1):
+ v = vision[i:i + B]
+ if len(v) % B == 0:
+ n_pad = 0
+ else:
+ n_pad = B - len(v) % B
+ v = np.pad(v, ((n_pad, 0), (0, 0), (0, 0), (0, 0)))
+ enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
+ enc = enc[n_pad:]
+ for t in range(len(enc)):
+ encodings.extend(enc[t].reshape(-1).tolist())
+ if t == len(enc) - 1:
+ encodings.append(8193)
+ else:
+ encodings.append(8192)
+ return encodings
+
+ def construct_input(self, prompts, max_n_frames):
+ max_input_length = max_n_frames * self.n_tokens_per_frame + self.min_buffer_size
+ max_input_length = int(math.ceil(max_input_length / self.block_size) * self.block_size)
+
+ vision_start = self.tokenizer.encode('')
+ vision_end = self.tokenizer.encode('')
+
+ input_ids = np.zeros((len(prompts), max_input_length), dtype=int)
+ vision_masks = np.zeros((len(prompts), max_input_length), dtype=bool)
+ attention_mask = np.zeros((len(prompts), max_input_length), dtype=int)
+ for i, prompt in enumerate(tqdm(prompts)):
+ vision = self._read_process_vision(prompt['input_path'], max_n_frames)
+ text_1 = self.tokenizer.encode(f"You are a helpful assistant. USER: {prompt['question']}\n")
+ tail = self.tokenizer.encode(" ASSISTANT:")
+
+ tokens, vm = [], []
+ tokens.extend(text_1)
+ vm.extend([False] * len(text_1))
+ tokens.extend(vision_start)
+ vm.extend([False] * len(vision_start))
+ tokens.extend(vision)
+ vm.extend([True] * len(vision))
+ tokens.extend(vision_end)
+ vm.extend([False] * len(vision_end))
+ tokens.extend(tail)
+ vm.extend([False] * len(tail))
+ assert len(tokens) < max_input_length, (len(tokens), max_input_length)
+ assert len(tokens) == len(vm)
+ input_ids[i, -len(tokens):] = tokens
+ vision_masks[i, -len(tokens):] = vm
+ attention_mask[i, -len(tokens):] = 1
+ return {
+ 'input_ids': input_ids,
+ 'vision_masks': vision_masks,
+ 'attention_mask': attention_mask
+ }
+
+
+ def _load_model(self):
+ if FLAGS.load_llama_config != '':
+ llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
+ updates = VideoLLaMAConfig(**FLAGS.llama)
+ llama_config.update(dict(
+ remat_block=updates.remat_block,
+ remat_attention=updates.remat_attention,
+ remat_mlp=updates.remat_mlp,
+ scan_attention=updates.scan_attention,
+ scan_mlp=updates.scan_mlp,
+ scan_query_chunk_size=updates.scan_query_chunk_size,
+ scan_key_chunk_size=updates.scan_key_chunk_size,
+ scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
+ scan_layers=updates.scan_layers,
+ param_scan_axis=updates.param_scan_axis,
+ ))
+ else:
+ llama_config = VideoLLaMAConfig(**FLAGS.llama)
+
+ if FLAGS.update_llama_config != '':
+ llama_config.update(dict(eval(FLAGS.update_llama_config)))
+
+ llama_config.update(dict(
+ bos_token_id=self.tokenizer.bos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ ))
+ llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
+ self.config = llama_config
+
+ self.model = FlaxVideoLLaMAForCausalLM(
+ llama_config,
+ input_shape=(512, self.block_size),
+ seed=FLAGS.seed,
+ _do_init=False,
+ dtype=get_float_dtype_by_name(FLAGS.dtype),
+ )
+
+ with jax.default_device(jax.devices("cpu")[0]):
+ _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
+ FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
+ )
+ self.model_ps = match_partition_rules(
+ VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
+ )
+ shard_fns, _ = make_shard_and_gather_fns(
+ self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
+ )
+
+ with self.mesh:
+ self.params = tree_apply(shard_fns, self.params)
+
+ @cached_property
+ def _forward_generate(self):
+ def fn(params, rng, batch):
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
+ rng_generator = JaxRNG(rng)
+ output = self.model.generate(
+ batch['input_ids'],
+ vision_masks=batch['vision_masks'],
+ attention_mask=batch['attention_mask'],
+ params=params['params'],
+ prng_key=rng_generator(),
+ generation_config=GenerationConfig(
+ max_new_tokens=self.block_size,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ temperature=FLAGS.temperature,
+ do_sample=True,
+ )
+ ).sequences[:, batch['input_ids'].shape[1]:]
+ return output, rng_generator()
+ return pjit(
+ fn,
+ in_shardings=(self.model_ps, PS(), PS()),
+ out_shardings=(PS(), PS())
+ )
+
+ def __call__(self, prompts, max_n_frames):
+ batch = self.construct_input(prompts, max_n_frames)
+ with self.mesh:
+ output, self.sharded_rng = self._forward_generate(
+ self.params, self.sharded_rng, batch
+ )
+ output = jax.device_get(output)
+ output_text = []
+ for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):
+ if self.tokenizer.eos_token in text:
+ text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
+ output_text.append(text)
+ return output_text
+
+def main(argv):
+ assert FLAGS.prompt != ''
+ assert FLAGS.input_file != ''
+
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
+ set_random_seed(FLAGS.seed)
+
+ prompts = [{'input_path': FLAGS.input_file, 'question': FLAGS.prompt}]
+ sampler = Sampler()
+ output = sampler(prompts, FLAGS.max_n_frames)[0]
+ print(f"Question: {FLAGS.prompt}\nAnswer: {output}")
+
+if __name__ == "__main__":
+ run(main)
diff --git a/lwm/vision_generation.py b/lwm/vision_generation.py
new file mode 100644
index 0000000..c903402
--- /dev/null
+++ b/lwm/vision_generation.py
@@ -0,0 +1,258 @@
+from absl.app import run
+from tqdm import tqdm
+import imageio
+import numpy as np
+from PIL import Image
+from transformers import GenerationConfig
+import jax
+import jax.numpy as jnp
+from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec as PS
+from tux import (
+ define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
+ set_random_seed, get_float_dtype_by_name, JaxRNG,
+ match_partition_rules, make_shard_and_gather_fns,
+ with_sharding_constraint, tree_apply, next_rng
+)
+from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
+from lwm.vqgan import VQGAN
+
+
+FLAGS, FLAGS_DEF = define_flags_with_default(
+ prompt='Fireworks over the city',
+ output_file='',
+ temperature_image=1.0,
+ temperature_video=1.0,
+ top_k_image=8192,
+ top_k_video=100,
+ cfg_scale_image=1.0,
+ cfg_scale_video=1.0,
+ vqgan_checkpoint='',
+ n_frames=1,
+ seed=1234,
+ mesh_dim='1,-1,1,1',
+ dtype='fp32',
+ load_llama_config='',
+ update_llama_config='',
+ load_checkpoint='',
+ tokenizer=VideoLLaMAConfig.get_tokenizer_config(),
+ llama=VideoLLaMAConfig.get_default_config(),
+ jax_distributed=JaxDistributedConfig.get_default_config(),
+)
+
+
+def main(argv):
+ assert FLAGS.output_file != ''
+ if FLAGS.output_file.endswith('mp4'):
+ assert FLAGS.n_frames > 1
+ elif FLAGS.output_file.endswith('png') or FLAGS.output_file.endswith('jpg'):
+ assert FLAGS.n_frames == 1
+ else:
+ raise ValueError(f"Unsupported output file extension: {FLAGS.output_file}")
+
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
+ set_random_seed(FLAGS.seed)
+
+ tokens_per_frame = 257
+ vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
+ mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
+ tokenizer = VideoLLaMAConfig.get_tokenizer(FLAGS.tokenizer)
+ prefix_tokenizer = VideoLLaMAConfig.get_tokenizer(
+ FLAGS.tokenizer, truncation_side='left', padding_side='left'
+ )
+ if FLAGS.load_llama_config != '':
+ llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
+ updates = VideoLLaMAConfig(**FLAGS.llama)
+ llama_config.update(dict(
+ remat_block=updates.remat_block,
+ remat_attention=updates.remat_attention,
+ remat_mlp=updates.remat_mlp,
+ scan_attention=updates.scan_attention,
+ scan_mlp=updates.scan_mlp,
+ scan_query_chunk_size=updates.scan_query_chunk_size,
+ scan_key_chunk_size=updates.scan_key_chunk_size,
+ scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
+ scan_layers=updates.scan_layers,
+ param_scan_axis=updates.param_scan_axis,
+ ))
+ else:
+ llama_config = VideoLLaMAConfig(**FLAGS.llama)
+
+ if FLAGS.update_llama_config != '':
+ llama_config.update(dict(eval(FLAGS.update_llama_config)))
+
+ llama_config.update(dict(
+ bos_token_id=tokenizer.bos_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ ))
+ llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
+
+ with jax.default_device(jax.devices("cpu")[0]):
+ _, params = StreamingCheckpointer.load_trainstate_checkpoint(
+ FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
+ )
+ model = FlaxVideoLLaMAForCausalLM(
+ llama_config,
+ input_shape=(512, 8192),
+ seed=FLAGS.seed,
+ _do_init=False,
+ dtype=get_float_dtype_by_name(FLAGS.dtype),
+ )
+ model_ps = match_partition_rules(
+ VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), params
+ )
+ shard_fns, _ = make_shard_and_gather_fns(
+ model_ps, get_float_dtype_by_name(FLAGS.dtype)
+ )
+
+ with mesh:
+ params = tree_apply(shard_fns, params)
+
+ def _forward_generate(params, rng, batch, n_tokens, cfg_scale, top_k, temperature):
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
+ cfg_scales = jnp.ones((batch['input_ids'].shape[0] // 2,), dtype=jnp.float32) * cfg_scale
+ cfg_scales = with_sharding_constraint(cfg_scales, PS(('dp', 'fsdp')))
+ rng_generator = JaxRNG(rng)
+ output = model.generate_vision(
+ batch['input_ids'],
+ cfg_scales,
+ attention_mask=batch['attention_mask'],
+ vision_masks=batch['vision_masks'],
+ params=params['params'],
+ prng_key=rng_generator(),
+ generation_config=GenerationConfig(
+ max_new_tokens=n_tokens,
+ min_new_tokens=n_tokens,
+ pad_token_id=tokenizer.pad_token_id,
+ temperature=temperature,
+ do_sample=True,
+ top_k=top_k,
+ )
+ ).sequences[:, batch['input_ids'].shape[1]:]
+ return output, rng_generator()
+ _sharded_forward_generate = pjit(
+ _forward_generate,
+ in_shardings=(model_ps, PS(), PS()),
+ out_shardings=(PS(), PS()),
+ static_argnums=(3, 4, 5, 6)
+ )
+
+ # Generate an image or first frame (for video)
+ def generate_first_frame(prompts, max_input_length):
+ nonlocal sharded_rng
+ uncond_prompts = [""] * len(prompts)
+ prompts = prompts + uncond_prompts
+ inputs = prefix_tokenizer(
+ prompts,
+ padding='max_length',
+ truncation=True,
+ max_length=max_input_length,
+ return_tensors='np'
+ )
+ batch = dict(
+ input_ids=inputs.input_ids,
+ attention_mask=inputs.attention_mask,
+ vision_masks=np.zeros(inputs.input_ids.shape, dtype=bool),
+ )
+ with mesh:
+ output, sharded_rng = _sharded_forward_generate(
+ params, sharded_rng, batch,
+ tokens_per_frame, FLAGS.cfg_scale_image,
+ FLAGS.top_k_image, FLAGS.temperature_image
+ )
+ output = jax.device_get(output)
+ output = np.split(output, 2, axis=0)[0]
+ output = output.reshape(len(prompts) // 2, tokens_per_frame)
+ image = vqgan.decode(output[:, :-1].reshape(-1, 16, 16))
+ image = ((jax.device_get(image) + 1) * 127.5).astype(np.uint8)
+ return output, image
+
+ sharded_rng = next_rng()
+ prompts = [FLAGS.prompt]
+ entries = []
+ for prompt in prompts:
+ entries.append({
+ 'caption': prompt,
+ 'prompt': f"You are a helpful assistant. USER: Generate an image of {prompt} ASSISTANT: ",
+ })
+
+ B = 1
+ images, image_encodings = [], []
+ for i in tqdm(list(range(0, len(entries), B))):
+ entries_i = entries[i:i + B]
+ prompts = [entry['prompt'] for entry in entries_i]
+ img_enc, img = generate_first_frame(prompts, max_input_length=128)
+ image_encodings.extend(img_enc)
+ images.extend(img)
+
+ if FLAGS.n_frames == 1:
+ image = images[0]
+ Image.fromarray(image).save(FLAGS.output_file)
+ return
+
+ # Generate the rest of the video
+ def generate_video_pred(prompts, images, max_input_length):
+ nonlocal sharded_rng
+ images = np.concatenate([images, images], axis=0)
+ uncond_prompts = [""] * len(prompts)
+ prompts = prompts + uncond_prompts
+ inputs = prefix_tokenizer(
+ prompts,
+ padding='max_length',
+ truncation=True,
+ max_length=max_input_length,
+ return_tensors='np'
+ )
+ batch = dict(
+ input_ids=np.concatenate([inputs.input_ids, images], axis=1),
+ attention_mask=np.concatenate([inputs.attention_mask, np.ones(images.shape, dtype=inputs.attention_mask.dtype)], axis=1),
+ vision_masks=np.concatenate([
+ np.zeros(inputs.input_ids.shape, dtype=bool),
+ np.ones(images.shape, dtype=bool)
+ ], axis=1),
+ )
+ with mesh:
+ output, sharded_rng = _sharded_forward_generate(
+ params, sharded_rng, batch,
+ (FLAGS.n_frames - 1) * tokens_per_frame, FLAGS.cfg_scale_video,
+ FLAGS.top_k_video, FLAGS.temperature_video
+ )
+ output = jax.device_get(output)
+ output = np.split(output, 2, axis=0)[0]
+ output = output.reshape(len(prompts) // 2, FLAGS.n_frames - 1, tokens_per_frame)
+ output = np.concatenate([images[:len(prompts) // 2, None], output], axis=1)
+ output = output[:, :, :-1].reshape(-1, FLAGS.n_frames, 16, 16)
+ vision = []
+ for v in output:
+ v = vqgan.decode(v)
+ v = ((jax.device_get(v) + 1) * 127.5).astype(np.uint8)
+ vision.append(v)
+ return vision
+
+ new_entries = []
+ for img_enc, entry in zip(image_encodings, entries):
+ new_entries.append({
+ 'caption': entry['caption'],
+ 'prompt': f"You are a helpful assistant. USER: Generate a video of {entry['caption']} ASSISTANT: ",
+ 'image': np.array(img_enc, dtype=np.int32),
+ })
+ entries = new_entries
+
+ B = 1
+ videos = []
+ for i in tqdm(list(range(0, len(entries), B))):
+ entries_i = entries[i:i + B]
+ prompts = [entry['prompt'] for entry in entries_i]
+ images = np.array([entry['image'] for entry in entries_i], dtype=np.int32)
+ videos.extend(generate_video_pred(prompts, images, max_input_length=128))
+
+ video = videos[0]
+ writer = imageio.get_writer(FLAGS.output_file, fps=4)
+ for frame in video:
+ writer.append_data(frame)
+ writer.close()
+
+ print('done')
+
+if __name__ == "__main__":
+ run(main)
diff --git a/lwm/vision_llama.py b/lwm/vision_llama.py
new file mode 100644
index 0000000..154bec5
--- /dev/null
+++ b/lwm/vision_llama.py
@@ -0,0 +1,734 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+import json
+import warnings
+import copy
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax.sharding import PartitionSpec as PS
+import flax.linen as nn
+from flax.core.frozen_dict import unfreeze, freeze
+from flax.traverse_util import flatten_dict, unflatten_dict
+
+from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
+from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
+from transformers.generation.flax_utils import SampleState, FlaxLogitsProcessorList, FlaxSampleOutput, logger
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from transformers import GenerationConfig
+
+from tux import load_pickle, open_file
+from lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm
+
+
+VIDEO_LLAMA_STANDARD_CONFIGS = LLAMA_STANDARD_CONFIGS
+
+
+class VideoLLaMAConfig(LLaMAConfig):
+ model_type = "video_llama"
+
+ def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False, sample_mode='all', **kwargs):
+ super().__init__(**kwargs)
+ self.vision_vocab_size = vision_vocab_size # 8192 + 256
+ self.tie_vision_embeddings = tie_vision_embeddings
+ self.sample_mode = sample_mode
+
+ @staticmethod
+ def get_partition_rules(scan_layers=False, scan_axis=0):
+ """ Parition rules for GPTJ. Note that these rules are orderd, so that
+ the beginning rules match first. It is important to use
+ PartitionSpec() instead of None here because JAX does not treat
+ None as a pytree leaf.
+ """
+ if scan_layers:
+ if scan_axis == 0:
+ return (
+ # embeddings
+ ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
+ ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
+ # atention
+ ("attention/(wq|wk|wv)/kernel", PS(None, ("fsdp", "sp"), "tp")),
+ ("attention/wo/kernel", PS(None, "tp", ("fsdp", "sp"))),
+ # mlp
+ ("feed_forward/w1/kernel", PS(None, ("fsdp", "sp"), "tp")),
+ ("feed_forward/w2/kernel", PS(None, "tp", ("fsdp", "sp"))),
+ ("feed_forward/w3/kernel", PS(None, ("fsdp", "sp"), "tp")),
+ # layer norms
+ ("attention_norm/kernel", PS(None, None)),
+ ("ffn_norm/kernel", PS(None, None)),
+ # output head
+ ("transformer/ln_f/kernel", PS(None)),
+ ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ('.*', PS(None)),
+ )
+ elif scan_axis == 1:
+ return (
+ # embeddings
+ ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
+ ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
+ # atention
+ ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), None, "tp")),
+ ("attention/wo/kernel", PS("tp", None, ("fsdp", "sp"))),
+ # mlp
+ ("feed_forward/w1/kernel", PS(("fsdp", "sp"), None, "tp")),
+ ("feed_forward/w2/kernel", PS("tp", None, ("fsdp", "sp"))),
+ ("feed_forward/w3/kernel", PS(("fsdp", "sp"), None, "tp")),
+ # layer norms
+ ("attention_norm/kernel", PS(None, None)),
+ ("ffn_norm/kernel", PS(None, None)),
+ # output head
+ ("transformer/ln_f/kernel", PS(None)),
+ ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ('.*', PS(None)),
+ )
+ else:
+ raise ValueError(f"Invalid scan_axis {scan_axis}")
+ else:
+ return (
+ # embeddings
+ ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
+ ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
+ # atention
+ ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), "tp")),
+ ("attention/wo/kernel", PS("tp", ("fsdp", "sp"))),
+ # mlp
+ ("feed_forward/w1/kernel", PS(("fsdp", "sp"), "tp")),
+ ("feed_forward/w2/kernel", PS("tp", ("fsdp", "sp"))),
+ ("feed_forward/w3/kernel", PS(("fsdp", "sp"), "tp")),
+ # layer norms
+ ("attention_norm/kernel", PS(None)),
+ ("ffn_norm/kernel", PS(None)),
+ # output head
+ ("transformer/ln_f/kernel", PS(None)),
+ ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
+ ('.*', PS(None)),
+ )
+
+ @classmethod
+ def load_config(cls, path):
+ if path in VIDEO_LLAMA_STANDARD_CONFIGS:
+ return cls.from_dict(VIDEO_LLAMA_STANDARD_CONFIGS[path])
+ load_type, load_path = path.split('::', 1)
+ if load_type == 'pickle':
+ return cls.from_dict(load_pickle(load_path)['llama_config'])
+ elif load_type == 'json':
+ with open_file(load_path, 'r') as fin:
+ raw_config = fin.read()
+ return cls.from_dict(json.loads(raw_config))
+ else:
+ raise ValueError(f'Unsupported load config type: {load_type}')
+
+
+class FlaxVideoLLaMAPreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = VideoLLaMAConfig
+ base_model_prefix = "transformer"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: VideoLLaMAConfig,
+ input_shape: Tuple = (4, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_cache(self, batch_size, max_length):
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length))
+ attention_mask = jnp.ones_like(input_ids)
+ segment_ids = jnp.zeros_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+ vision_masks = jnp.ones((batch_size, max_length), dtype=bool)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True
+ )
+ return init_variables["cache"]
+
+ def init_weights(self, rng, input_shape, params=None):
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ vision_masks = jnp.ones(input_ids.shape, dtype=bool)
+ segment_ids = jnp.zeros_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ @add_start_docstrings_to_model_forward("")
+ def __call__(
+ self,
+ input_ids,
+ vision_masks,
+ attention_mask=None,
+ segment_ids=None,
+ position_ids=None,
+ params: dict = None,
+ past_key_values: dict = None,
+ dropout_rng: jax.random.PRNGKey = None,
+ train: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ batch_size, sequence_length = input_ids.shape
+
+ if position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
+
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ if attention_mask is None:
+ attention_mask = jnp.ones((batch_size, sequence_length))
+
+ if segment_ids is None:
+ segment_ids = jnp.zeros((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(vision_masks, dtype="f4"),
+ jnp.array(attention_mask, dtype="i4"),
+ jnp.array(segment_ids, dtype="i4"),
+ jnp.array(position_ids, dtype="i4"),
+ not train,
+ False,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+class FlaxVideoLLaMAModule(nn.Module):
+ config: VideoLLaMAConfig
+ dtype: jnp.dtype = jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+ precision: Optional[Union[jax.lax.Precision, str]]=None
+
+ def setup(self):
+ self.embed_dim = self.config.hidden_size
+
+ self.vte = nn.Embed(
+ self.config.vision_vocab_size,
+ self.config.hidden_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ )
+
+ self.wte = nn.Embed(
+ self.config.vocab_size,
+ self.config.hidden_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ )
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
+ self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
+ self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ vision_masks,
+ attention_mask,
+ segment_ids,
+ position_ids,
+ deterministic=True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ input_ids = input_ids.astype("i4")
+
+ if input_ids.shape[1] == 1:
+ if self.config.sample_mode == 'text':
+ input_embeds = self.wte(input_ids)
+ elif self.config.sample_mode == 'vision':
+ input_embeds = self.vte(input_ids)
+ elif self.config.sample_mode == 'all':
+ raise NotImplementedError
+ else:
+ raise ValueError(f"Invalid sample_mode: {self.config.sample_mode}")
+ else:
+ input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids))
+ input_vision_embeds = self.vte(jnp.where(vision_masks, input_ids, 0))
+ vision_masks = vision_masks[..., None].astype("f4") # 1 is vision, 0 is text
+ input_embeds = input_text_embeds * (1 - vision_masks) + input_vision_embeds * vision_masks
+
+ hidden_states = self.dropout(input_embeds, deterministic=deterministic)
+
+ outputs = self.h(
+ hidden_states,
+ attention_mask,
+ segment_ids,
+ position_ids=position_ids,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = outputs[1] + (hidden_states,)
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=outputs[1],
+ attentions=outputs[-1],
+ )
+
+
+class FlaxVideoLLaMAForCausalLMModule(nn.Module):
+ config: VideoLLaMAConfig
+ dtype: jnp.dtype = jnp.float32
+ param_dtype: jnp.dtype=jnp.float32
+ precision: Optional[Union[jax.lax.Precision, str]]=None
+
+ def setup(self):
+ self.transformer = FlaxVideoLLaMAModule(self.config, dtype=self.dtype)
+ self.vision_head = nn.Dense(
+ self.config.vision_vocab_size,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ precision=self.precision,
+ )
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ precision=self.precision,
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ vision_masks,
+ attention_mask=None,
+ segment_ids=None,
+ position_ids=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ batch_size, seq_length = input_ids.shape
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if segment_ids is None:
+ segment_ids = jnp.zeros_like(input_ids)
+ if position_ids is None:
+ position_ids = jnp.broadcast_to(
+ jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
+ (batch_size, seq_length)
+ )
+
+
+ outputs = self.transformer(
+ input_ids,
+ vision_masks,
+ attention_mask,
+ segment_ids,
+ position_ids,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_vision_embeddings:
+ shared_kernel = self.transformer.variables["params"]["vte"]["embedding"].T
+ vision_logits = self.vision_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
+ else:
+ vision_logits = self.vision_head(hidden_states)
+
+ if self.config.tie_word_embeddings:
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if self.config.sample_mode == 'all':
+ if not return_dict:
+ return (vision_logits, lm_logits,) + outputs[1:]
+
+ return FlaxCausalLMOutput(logits=(vision_logits, lm_logits), hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+ elif self.config.sample_mode == 'vision':
+ if not return_dict:
+ return (vision_logits,) + outputs[1:]
+
+ return FlaxCausalLMOutput(logits=vision_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+ elif self.config.sample_mode == 'text':
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+ else:
+ raise ValueError(f"Invalid sample_mode: {self.config.sample_mode}")
+
+
+
+@add_start_docstrings("", "")
+class FlaxVideoLLaMAForCausalLM(FlaxVideoLLaMAPreTrainedModel):
+ module_class = FlaxVideoLLaMAForCausalLMModule
+
+ def prepare_inputs_for_generation(
+ self, input_ids, max_length, attention_mask: Optional[jax.Array] = None, vision_masks = None
+ ):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since GPTJ uses a causal mask, those positions are masked anyways.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ "vision_masks": vision_masks
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ return {
+ "past_key_values": model_outputs.past_key_values,
+ "position_ids": model_kwargs["position_ids"][:, -1:] + 1,
+ "attention_mask": model_kwargs["attention_mask"],
+ "vision_masks": model_kwargs["vision_masks"]
+ }
+
+ def _sample_vision(
+ self,
+ input_ids: None,
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[int] = None,
+ prng_key: Optional[jnp.ndarray] = None,
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
+ logits_warper: Optional[FlaxLogitsProcessorList] = None,
+ cfg_scales: jnp.ndarray = 1.0,
+ trace: bool = True,
+ params: Optional[Dict[str, jnp.ndarray]] = None,
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
+ ):
+ # init values
+ max_length = max_length if max_length is not None else self.generation_config.max_length
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
+
+ batch_size, cur_len = input_ids.shape
+ initial_len = cur_len
+
+ eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
+ pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
+ cur_len = jnp.array(cur_len)
+
+ # per batch-item holding current token in loop.
+ sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
+ sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
+
+ # per batch-item state bit indicating if sentence has finished.
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
+
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
+ model = self.decode if self.config.is_encoder_decoder else self
+
+ # initialize model specific kwargs
+ model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
+
+ # initialize state
+ state = SampleState(
+ cur_len=cur_len,
+ sequences=sequences,
+ running_token=input_ids,
+ is_sent_finished=is_sent_finished,
+ prng_key=prng_key,
+ model_kwargs=model_kwargs,
+ )
+
+ def sample_search_cond_fn(state):
+ """state termination condition fn."""
+ has_reached_max_length = state.cur_len == max_length
+ all_sequence_finished = jnp.all(state.is_sent_finished)
+ finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
+ return ~finish_generation
+
+ def sample_search_body_fn(state):
+ """state update fn."""
+ prng_key, prng_key_next = jax.random.split(state.prng_key)
+ model_outputs = model(state.running_token, params=params, **state.model_kwargs)
+
+ logits = model_outputs.logits[:, -1]
+ cond_logits, uncond_logits = jnp.split(logits, 2, axis=0)
+ logits = uncond_logits + cfg_scales[:, None] * (cond_logits - uncond_logits)
+
+ # apply min_length, ...
+ logits = logits_processor(state.sequences, logits, state.cur_len)
+ # apply top_p, top_k, temperature
+ logits = logits_warper(logits, logits, state.cur_len)
+
+ next_token = jax.random.categorical(prng_key, logits, axis=-1)
+ next_token = jax.lax.cond(
+ (state.cur_len - initial_len + 1) % 257 == 0,
+ lambda: jnp.full_like(next_token, 8192),
+ lambda: next_token
+ )
+ next_token = jnp.concatenate([next_token, next_token], axis=0)
+
+ #next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
+ next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
+ next_token = next_token[:, None]
+
+ next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
+ next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
+
+ return SampleState(
+ cur_len=state.cur_len + 1,
+ sequences=next_sequences,
+ running_token=next_token,
+ is_sent_finished=next_is_sent_finished,
+ model_kwargs=next_model_kwargs,
+ prng_key=prng_key_next,
+ )
+
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
+ if input_ids.shape[1] > 1:
+ state = sample_search_body_fn(state)
+
+ if not trace:
+ state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
+ else:
+ state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
+
+ return FlaxSampleOutput(sequences=state.sequences)
+
+ def generate_vision(
+ self,
+ input_ids: jnp.ndarray,
+ cfg_scales: jnp.ndarray,
+ generation_config: Optional[GenerationConfig] = None,
+ prng_key: Optional[jnp.ndarray] = None,
+ trace: bool = True,
+ params: Optional[Dict[str, jnp.ndarray]] = None,
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
+ **kwargs,
+ ):
+ # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+ self._validate_model_class()
+
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+ if generation_config is None:
+ # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+ # two conditions must be met
+ # 1) the generation config must have been created from the model config (`_from_model_config` field);
+ # 2) the generation config must have seen no modification since its creation (the hash is the same).
+ if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
+ self.generation_config
+ ):
+ new_generation_config = GenerationConfig.from_model_config(self.config)
+ if new_generation_config != self.generation_config:
+ warnings.warn(
+ "You have modified the pretrained model configuration to control generation. This is a"
+ " deprecated strategy to control generation and will be removed soon, in a future version."
+ " Please use and modify the model generation configuration (see"
+ " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
+ )
+ self.generation_config = new_generation_config
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
+
+ # set init values
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
+
+ if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+ if model_kwargs.get("attention_mask") is None:
+ logger.warning(
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+ )
+ eos_token_id = generation_config.eos_token_id
+ if isinstance(eos_token_id, list):
+ eos_token_id = eos_token_id[0]
+ logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+ generation_config.pad_token_id = eos_token_id
+
+ if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
+ raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
+
+ # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
+ if not self.config.is_encoder_decoder and not trace:
+ if (
+ generation_config.pad_token_id is not None
+ and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
+ ):
+ logger.warning(
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
+ )
+
+ batch_size = input_ids.shape[0]
+
+ if self.config.is_encoder_decoder:
+ # add encoder_outputs to model_kwargs
+ if model_kwargs.get("encoder_outputs") is None:
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
+ # prepare decoder_input_ids for generation
+ input_ids = self._prepare_decoder_input_ids_for_generation(
+ batch_size,
+ decoder_start_token_id=generation_config.decoder_start_token_id,
+ bos_token_id=generation_config.bos_token_id,
+ model_kwargs=model_kwargs,
+ )
+
+ # Prepare `max_length` depending on other stopping criteria.
+ input_ids_seq_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
+ # 20 is the default max_length of the generation config
+ warnings.warn(
+ f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
+ "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ if not has_default_max_length and generation_config.max_length is not None:
+ logger.warning(
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+ "Please refer to the documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+
+ if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
+ raise ValueError(
+ f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
+ f" the maximum length ({generation_config.max_length})"
+ )
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing`max_new_tokens`."
+ )
+
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ logits_processor=logits_processor,
+ )
+
+ if not generation_config.do_sample and generation_config.num_beams == 1:
+ raise NotImplementedError
+ elif generation_config.do_sample and generation_config.num_beams == 1:
+ logits_warper = self._get_logits_warper(generation_config=generation_config)
+ return self._sample_vision(
+ input_ids,
+ generation_config.max_length,
+ generation_config.pad_token_id,
+ generation_config.eos_token_id,
+ prng_key,
+ logits_warper=logits_warper,
+ logits_processor=logits_processor,
+ cfg_scales=cfg_scales,
+ trace=trace,
+ params=params,
+ model_kwargs=model_kwargs,
+ )
+ elif not generation_config.do_sample and generation_config.num_beams > 1:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError("`Beam sampling is currently not implemented.")
diff --git a/lwm/vqgan.py b/lwm/vqgan.py
new file mode 100644
index 0000000..77715f8
--- /dev/null
+++ b/lwm/vqgan.py
@@ -0,0 +1,351 @@
+from typing import Optional
+from functools import cached_property, partial
+import pickle
+import numpy as np
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+from flax import jax_utils
+from transformers.configuration_utils import PretrainedConfig
+from ml_collections import ConfigDict
+from tux import function_args_to_config, open_file
+
+
+class VQGAN:
+ def __init__(self, vqgan_checkpoint, replicate=False):
+ assert vqgan_checkpoint != ''
+ self.replicate = replicate
+ self.config = VQGANConfig.get_default_config()
+ self.params = pickle.load(open_file(vqgan_checkpoint, 'rb'))
+ if replicate:
+ self.params = jax_utils.replicate(self.params)
+ else:
+ self.params = jax.jit(lambda x: x)(self.params)
+ self.model = VQGANModel(self.config)
+
+ def _wrap_fn(self, fn):
+ if self.replicate:
+ return jax.pmap(fn, devices=jax.local_devices())
+ else:
+ return jax.jit(fn)
+
+ @cached_property
+ def _encode(self):
+ def fn(pixel_values, params):
+ return self.model.apply(
+ {'params': params},
+ pixel_values,
+ method=self.model.encode
+ )
+ return partial(self._wrap_fn(fn), params=self.params)
+
+ @cached_property
+ def _decode(self):
+ def fn(encoding, params):
+ return self.model.apply(
+ {'params': params},
+ encoding,
+ method=self.model.decode
+ )
+ return partial(self._wrap_fn(fn), params=self.params)
+
+ def encode(self, pixel_values):
+ return self._encode(pixel_values)
+
+ def decode(self, encoding):
+ return self._decode(encoding)
+
+
+class VQGANConfig(PretrainedConfig):
+ model_type = "vqgan"
+
+ def __init__(
+ self,
+ resolution=256,
+ num_channels=3,
+ hidden_channels=128,
+ channel_mult=(1, 2, 2, 4, 6),
+ num_res_blocks=2,
+ attn_resolutions=(),
+ no_attn_mid_block=True,
+ z_channels=64,
+ num_embeddings=8192,
+ quantized_embed_dim=64,
+ dropout=0.0,
+ resample_with_conv=True,
+ commitment_cost=0.25
+ ):
+ self.resolution = resolution
+ self.num_channels = num_channels
+ self.hidden_channels = hidden_channels
+ self.channel_mult = channel_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_resolutions = attn_resolutions
+ self.no_attn_mid_block = no_attn_mid_block
+ self.z_channels = z_channels
+ self.num_embeddings = num_embeddings
+ self.quantized_embed_dim = quantized_embed_dim
+ self.dropout = dropout
+ self.resample_with_conv = resample_with_conv
+ self.commitment_cost = commitment_cost
+
+ @classmethod
+ def get_default_config(cls, updates=None):
+ config = function_args_to_config(cls.__init__)
+ if updates is not None:
+ config.update(ConfigDict(updates).copy_and_resolve_references())
+ config.num_resolutions = len(config.channel_mult)
+ return config
+
+ @classmethod
+ def load_config(cls, path):
+ return cls.get_default_config(cls)
+
+
+class VQGANModel(nn.Module):
+ config: VQGANConfig
+
+ def setup(self):
+ self.encoder = Encoder(self.config)
+ self.decoder = Decoder(self.config)
+ self.quantize = VectorQuantizer(
+ self.config.num_embeddings, self.config.quantized_embed_dim
+ )
+ self.quant_conv = nn.Conv(self.config.quantized_embed_dim, [1, 1])
+ self.post_quant_conv = nn.Conv(self.config.z_channels, [1, 1])
+
+ def encode(self, pixel_values):
+ T = None
+ if len(pixel_values.shape) == 5: # video
+ T = pixel_values.shape[1]
+ pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])
+ hidden_states = self.encoder(pixel_values)
+ hidden_states = self.quant_conv(hidden_states)
+ quantized_states, codebook_indices = self.quantize(hidden_states)
+ if T is not None:
+ quantized_states = quantized_states.reshape(-1, T, *quantized_states.shape[1:])
+ codebook_indices = codebook_indices.reshape(-1, T, *codebook_indices.shape[1:])
+ return quantized_states, codebook_indices
+
+ def decode(self, encoding, is_codebook_indices=True):
+ if is_codebook_indices:
+ encoding = self.quantize(None, encoding)
+ T = None
+ if len(encoding.shape) == 5:
+ T = encoding.shape[1]
+ encoding = encoding.reshape(-1, *encoding.shape[2:])
+ hidden_states = self.post_quant_conv(encoding)
+ reconstructed_pixel_values = self.decoder(hidden_states)
+ if T is not None:
+ reconstructed_pixel_values = reconstructed_pixel_values.reshape(-1, T, *reconstructed_pixel_values.shape[1:])
+ return jnp.clip(reconstructed_pixel_values, -1, 1)
+
+ def __call__(self, pixel_values):
+ encoding = self.encode(pixel_values)[1]
+ recon = self.decode(encoding)
+ return recon
+
+
+class Encoder(nn.Module):
+ config: VQGANConfig
+
+ @nn.compact
+ def __call__(self, pixel_values):
+ assert pixel_values.shape[1] == pixel_values.shape[2] == self.config.resolution, pixel_values.shape
+ hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values)
+ for i_level in range(self.config.num_resolutions):
+ hidden_states = DownsamplingBlock(self.config, i_level)(hidden_states)
+ hidden_states = MidBlock(
+ self.config, self.config.no_attn_mid_block, self.config.dropout
+ )(hidden_states)
+ hidden_states = nn.GroupNorm()(hidden_states)
+ hidden_states = nn.silu(hidden_states)
+ hidden_states = nn.Conv(self.config.z_channels, [3, 3])(hidden_states)
+ return hidden_states
+
+
+class Decoder(nn.Module):
+ config: VQGANConfig
+
+ @nn.compact
+ def __call__(self, hidden_states):
+ hidden_states = nn.Conv(
+ self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1],
+ [3, 3]
+ )(hidden_states)
+ hidden_states = MidBlock(
+ self.config, self.config.no_attn_mid_block, self.config.dropout
+ )(hidden_states)
+ for i_level in reversed(range(self.config.num_resolutions)):
+ hidden_states = UpsamplingBlock(self.config, i_level)(hidden_states)
+ hidden_states = nn.GroupNorm()(hidden_states)
+ hidden_states = nn.silu(hidden_states)
+ hidden_states = nn.Conv(self.config.num_channels, [3, 3])(hidden_states)
+ return hidden_states
+
+
+class VectorQuantizer(nn.Module):
+ n_e: int
+ e_dim: int
+
+ @nn.compact
+ def __call__(self, z, encoding_indices=None):
+ def quantize(encoding_indices):
+ w = jax.device_put(embeddings)
+ return w[(encoding_indices,)]
+ embeddings = self.param(
+ 'embeddings',
+ lambda rng, shape, dtype: jax.random.uniform(
+ rng, shape, dtype, minval=-1.0 / self.n_e, maxval=1.0 / self.n_e
+ ),
+ [self.n_e, self.e_dim], jnp.float32
+ )
+
+ if encoding_indices is not None:
+ return quantize(encoding_indices)
+
+ z_flattened = z.reshape(-1, z.shape[-1])
+ d = jnp.sum(z_flattened ** 2, axis=1, keepdims=True) + \
+ jnp.sum(embeddings.T ** 2, axis=0, keepdims=True) - \
+ 2 * jnp.einsum('bd,nd->bn', z_flattened, embeddings)
+
+ min_encoding_indices = jnp.argmin(d, axis=1)
+ z_q = quantize(min_encoding_indices)
+ z_q = jnp.reshape(z_q, z.shape)
+ z_q = z + jax.lax.stop_gradient(z_q - z)
+
+ encodings_one_hot = jax.nn.one_hot(min_encoding_indices, num_classes=self.n_e)
+ assert len(encodings_one_hot.shape) == 2
+ min_encoding_indices = jnp.reshape(min_encoding_indices, z.shape[:-1])
+
+ return z_q, min_encoding_indices
+
+
+class DownsamplingBlock(nn.Module):
+ config: VQGANConfig
+ block_idx: int
+
+ @nn.compact
+ def __call__(self, hidden_states):
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
+ for _ in range(self.config.num_res_blocks):
+ hidden_states = ResnetBlock(
+ block_out, dropout_prob=self.config.dropout
+ )(hidden_states)
+ if hidden_states.shape[1] in self.config.attn_resolutions:
+ hidden_states = AttnBlock()(hidden_states)
+ if self.block_idx != self.config.num_resolutions - 1:
+ hidden_states = Downsample(self.config.resample_with_conv)(hidden_states)
+ return hidden_states
+
+
+class ResnetBlock(nn.Module):
+ out_channels: Optional[int] = None
+ use_conv_shortcut: bool = False
+ dropout_prob: float = 0.0
+
+ @nn.compact
+ def __call__(self, hidden_states):
+ out_channels = self.out_channels or hidden_states.shape[-1]
+ residual = hidden_states
+ hidden_states = nn.GroupNorm()(hidden_states)
+ hidden_states = nn.silu(hidden_states)
+ hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states)
+ hidden_states = nn.GroupNorm()(hidden_states)
+ hidden_states = nn.silu(hidden_states)
+ hidden_states = nn.Dropout(self.dropout_prob, deterministic=True)(hidden_states)
+ hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states)
+ if out_channels != residual.shape[-1]:
+ if self.use_conv_shortcut:
+ residual = nn.Conv(out_channels, [3, 3])(residual)
+ else:
+ residual = nn.Conv(out_channels, [1, 1])(residual)
+ return hidden_states + residual
+
+
+class AttnBlock(nn.Module):
+ @nn.compact
+ def __call__(self, hidden_states):
+ residual = hidden_states
+ hidden_states = nn.GroupNorm()(hidden_states)
+ query = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
+ key = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
+ value = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
+ query, key, value = map(
+ lambda x: x.reshape(x.shape[0], -1, x.shape[-1]),
+ [query, key, value]
+ )
+ attn_weights = jnp.einsum("bqd,bkd->bqk", query, key)
+ attn_weights *= hidden_states.shape[-1] ** -0.5
+ attn_weights = jax.nn.softmax(attn_weights, axis=-1)
+ hidden_states = jnp.einsum("bqk,bkd->bqd", attn_weights, value)
+ hidden_states = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
+ return hidden_states + residual
+
+
+class Downsample(nn.Module):
+ with_conv: bool
+
+ @nn.compact
+ def __call__(self, hidden_states):
+ if self.with_conv:
+ hidden_states = jnp.pad(
+ hidden_states,
+ [(0, 0), (0, 1), (0, 1), (0, 0)]
+ )
+ hidden_states = nn.Conv(
+ hidden_states.shape[-1], [3, 3],
+ strides=[2, 2],
+ padding="VALID"
+ )(hidden_states)
+ else:
+ hidden_states = nn.avg_pool(hidden_states, [2, 2], [2, 2])
+ return hidden_states
+
+
+class Upsample(nn.Module):
+ with_conv: bool
+
+ @nn.compact
+ def __call__(self, hidden_states):
+ B, H, W, C = hidden_states.shape
+ hidden_states = jax.image.resize(
+ hidden_states,
+ (B, H * 2, W * 2, C),
+ method="nearest"
+ )
+ if self.with_conv:
+ hidden_states = nn.Conv(hidden_states.shape[-1], [3, 3])(hidden_states)
+ return hidden_states
+
+
+class UpsamplingBlock(nn.Module):
+ config: VQGANConfig
+ block_idx: int
+
+ @nn.compact
+ def __call__(self, hidden_states):
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
+ for _ in range(self.config.num_res_blocks + 1):
+ hidden_states = ResnetBlock(
+ block_out, dropout_prob=self.config.dropout
+ )(hidden_states)
+ if hidden_states.shape[1] in self.config.attn_resolutions:
+ hidden_states = AttnBlock()(hidden_states)
+ if self.block_idx != 0:
+ hidden_states = Upsample(self.config.resample_with_conv)(hidden_states)
+ return hidden_states
+
+
+class MidBlock(nn.Module):
+ config: VQGANConfig
+ no_attn: bool
+ dropout: float
+
+ @nn.compact
+ def __call__(self, hidden_states):
+ hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states)
+ if not self.no_attn:
+ hidden_states = AttnBlock()(hidden_states)
+ hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states)
+ return hidden_states
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f6648b0
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+tensorflow==2.11.0
+tensorboard-plugin-profile
+flax==0.7.0
+optax==0.1.7
+chex==0.1.82
+einops
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.0.0
+torchvision==0.15.0
+transformers==4.29.2
+datasets==2.13.0
+tqdm
+ml_collections
+wandb
+gcsfs
+requests
+typing-extensions
+sentencepiece
+tux @ git+https://github.com/lhao499/tux.git
+Pillow
+ipdb
+imageio[ffmpeg]
+decord
+tiktoken
\ No newline at end of file
diff --git a/scripts/eval_needle.py b/scripts/eval_needle.py
new file mode 100644
index 0000000..798c7eb
--- /dev/null
+++ b/scripts/eval_needle.py
@@ -0,0 +1,447 @@
+from absl.app import run
+import time
+import json
+import math
+import os
+from tqdm import tqdm
+import random
+from functools import cached_property
+import numpy as np
+import jax
+from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec as PS
+import gcsfs
+import tiktoken
+from transformers import GenerationConfig
+from tux import (
+ define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
+ set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
+ match_partition_rules, make_shard_and_gather_fns,
+ with_sharding_constraint, tree_apply, open_file
+)
+from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM
+
+
+FLAGS, FLAGS_DEF = define_flags_with_default(
+ haystack_file="",
+ max_tokens_per_batch=2000000,
+ output_file="results.json",
+ context_lengths_min=1000,
+ context_lengths_max=32000,
+ n_context_length_intervals=3,
+ n_document_depth_intervals=3,
+ n_rounds=2,
+ seed=1234,
+ mesh_dim='1,-1,1,1',
+ dtype='fp32',
+ load_llama_config='',
+ update_llama_config='',
+ load_checkpoint='',
+ tokenizer=LLaMAConfig.get_tokenizer_config(),
+ checkpointer=StreamingCheckpointer.get_default_config(),
+ llama=LLaMAConfig.get_default_config(),
+ jax_distributed=JaxDistributedConfig.get_default_config(),
+)
+
+
+class LLMNeedleHaystackTester:
+ OURS_TEMPLATE = "You are a helpful assistant. USER: {context} {question} Don't give information outside the document or repeat your findings. Keep your response short and direct. ASSISTANT: "
+ RANDOM_NEEDLE_CITIES = [
+ 'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',
+ 'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',
+ 'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',
+ 'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',
+ 'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',
+ 'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',
+ 'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',
+ 'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',
+ 'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',
+ 'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',
+ 'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',
+ 'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'
+ ]
+
+ def __init__(self,
+ needle="",
+ haystack_file="",
+ retrieval_question="What is the special magic {} number?",
+ results_version = 1,
+ rnd_number_digits = 7,
+ context_lengths_min = 1000,
+ context_lengths_max = 126000,
+ context_lengths_num_intervals = 10,
+ document_depth_percent_min = 0,
+ document_depth_percent_max = 100,
+ document_depth_percent_intervals = 10,
+ document_depth_percent_interval_type = "linear",
+ save_results = False,
+ final_context_length_buffer = 200,
+ print_ongoing_status = True):
+ needle="\nThe special magic {city} number is: {rnd_number}\n"
+ self.needle = needle
+ if not needle or not haystack_file or not retrieval_question:
+ raise ValueError("Needle, haystack, and retrieval_question must be provided.")
+
+ self.rnd_number_digits = rnd_number_digits
+ self.context_lengths_num_intervals = context_lengths_num_intervals
+ self.document_depth_percent_intervals = document_depth_percent_intervals
+ self.haystack_file = haystack_file
+ self.retrieval_question = retrieval_question
+ self.results_version = results_version
+ self.save_results = save_results
+ self.final_context_length_buffer = final_context_length_buffer
+ self.print_ongoing_status = print_ongoing_status
+ self.testing_results = []
+
+ self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
+ if document_depth_percent_interval_type == 'linear':
+ self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
+ elif document_depth_percent_interval_type == 'sigmoid':
+ self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]
+ else:
+ raise ValueError(f"Unsupported document_depth_percent_interval_type: {document_depth_percent_interval_type}")
+
+ self.model = Sampler()
+
+ self.enc = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
+ self.enc_tiktoken = tiktoken.encoding_for_model("gpt-4-1106-preview")
+
+ def generate_random_number(self, num_digits):
+ lower_bound = 10**(num_digits - 1)
+ upper_bound = 10**num_digits - 1
+ return random.randint(lower_bound, upper_bound)
+
+ def logistic(self, x, L=100, x0=50, k=.1):
+ if x == 0:
+ return 0
+ if x == 100:
+ return 100
+ return np.round(L / (1 + np.exp(-k * (x - x0))), 3)
+
+ def read_context_files(self, n):
+ max_context_length = max(self.context_lengths)
+ contexts = []
+ f = open_file(self.haystack_file, 'r')
+ for _ in range(n):
+ context = ""
+ toks = 0
+ while toks < max_context_length:
+ text = json.loads(f.readline())['text']
+ context += text
+ toks += len(self.enc.encode(text))
+ contexts.append(context)
+ return contexts
+
+ def encode_and_trim(self, context, context_length):
+ tokens = self.enc.encode(context)
+ if len(tokens) > context_length:
+ context = self.enc.decode(tokens[:context_length])
+ return context
+
+ def create_contexts(self, needle_rnd_number, insert_needle, random_city, trim_context, context_length, depth_percent, seed):
+ if self.save_results:
+ if self.result_exists(context_length, depth_percent):
+ return
+ needle = self.needle.format(city=random_city, rnd_number=needle_rnd_number)
+ question = self.retrieval_question.format(random_city)
+ if not insert_needle:
+ needle = " " #replace needle with a space
+ context = self.generate_context(needle, trim_context, context_length, depth_percent)
+ results = {
+ 'context' : context,
+ 'context_length' : int(context_length),
+ 'depth_percent' : float(depth_percent),
+ 'needle' : needle,
+ 'question' : question,
+ 'insert_needle' : insert_needle,
+ 'needle_rnd_number' : needle_rnd_number,
+ 'seed': seed,
+ }
+ return results
+
+ def insert_needle(self, needle, context, depth_percent, context_length):
+ tokens_needle = self.enc_tiktoken.encode(needle)
+ tokens_context = self.enc_tiktoken.encode(context)
+
+ # Reducing the context length by 150 buffer. This is to account for system message, the user question, and response.
+ context_length -= self.final_context_length_buffer
+
+ # If your context + needle are longer than the context length (which it will be), then reduce tokens from the context by the needle length
+ if len(tokens_context) + len(tokens_needle) > context_length:
+ tokens_context = tokens_context[:context_length - len(tokens_needle)]
+
+ if depth_percent == 100:
+ # If your depth percent is 100 (which means your needle is the last thing in the doc), throw it at the end
+ tokens_new_context = tokens_context + tokens_needle
+ else:
+ # Go get the position (in terms of tokens) to insert your needle
+ insertion_point = int(len(tokens_context) * (depth_percent / 100))
+
+ # tokens_new_context represents the tokens before the needle
+ tokens_new_context = tokens_context[:insertion_point]
+
+ # We want to make sure that we place our needle at a sentence break so we first see what token a '.' is
+ period_tokens = self.enc_tiktoken.encode('.')
+
+ # Then we iteration backwards until we find the first period
+ while tokens_new_context and tokens_new_context[-1] not in period_tokens:
+ insertion_point -= 1
+ tokens_new_context = tokens_context[:insertion_point]
+
+ # Once we get there, then add in your needle, and stick the rest of your context in on the other end.
+ # Now we have a needle in a haystack
+ tokens_new_context += tokens_needle + tokens_context[insertion_point:]
+
+ # Convert back to a string and return it
+ new_context = self.enc_tiktoken.decode(tokens_new_context)
+ return new_context
+
+ def generate_context(self, needle, trim_context, context_length, depth_percent):
+ context = self.insert_needle(needle, trim_context, depth_percent, context_length)
+ return context
+
+ def compute_max_input_length(self, context_length, buffer=1024):
+ block_size = self.model.block_size
+ context_length += buffer
+ context_length = math.ceil(context_length / block_size) * block_size
+ return int(context_length)
+
+ def run_test(self):
+ fs = gcsfs.GCSFileSystem()
+ contexts = []
+ template = self.OURS_TEMPLATE
+
+ def _key_from_result(result):
+ return (result['context_length'], result['depth_percent'], result['seed'])
+
+ results = []
+ completed = set()
+ def exists(fname):
+ if fname.startswith('gs://'):
+ return fs.exists(fname)
+ else:
+ return os.path.exists(fname)
+ if exists(FLAGS.output_file):
+ with open_file(FLAGS.output_file, 'r') as f:
+ results = json.load(f)
+ completed = set([_key_from_result(result) for result in results])
+ print('completed', len(completed))
+
+ full_contexts = self.read_context_files(FLAGS.n_rounds)
+ full_tokens = [self.enc.encode(full_context) for full_context in tqdm(full_contexts)]
+
+ start = time.time()
+ for context_length in self.context_lengths:
+ trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in tqdm(full_tokens)]
+ max_input_length = self.compute_max_input_length(context_length)
+ contexts = []
+ for depth_percent in self.document_depth_percents:
+ for i in range(FLAGS.n_rounds):
+ if (int(context_length), float(depth_percent), i) in completed:
+ continue
+ random_city = random.choice(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES)
+ insert_needle = True
+ needle_rnd_number = str(self.generate_random_number(self.rnd_number_digits))
+ print("context length: " + str(context_length))
+ print("depth_percent : " + str(depth_percent))
+ context = self.create_contexts(needle_rnd_number, insert_needle, random_city, trim_contexts[i], context_length, depth_percent, i)
+ contexts.append(context)
+
+ if len(contexts) == 0:
+ continue
+
+ B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)
+ B = int(B / self.model.data_dim) * self.model.data_dim
+ if B < self.model.data_dim:
+ B = self.model.data_dim
+ elif B > len(contexts):
+ B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)
+ if len(contexts) % B == 0:
+ n_pad = 0
+ else:
+ n_pad = B - len(contexts) % B
+ for _ in range(n_pad):
+ contexts.insert(0, contexts[0])
+
+ pbar = tqdm(total=len(contexts))
+ for i in range(0, len(contexts), B):
+ contexts_i = contexts[i:i + B]
+ prompts = [
+ template.format(context=context['context'], question=context['question'])
+ for context in contexts_i
+ ]
+ outs = self.model(prompts, max_input_length)
+ for j, (context, out) in enumerate(zip(contexts_i, outs)):
+ if i + j < n_pad:
+ continue
+ results.append({
+ 'context_length': context['context_length'],
+ 'depth_percent': context['depth_percent'],
+ 'response': out,
+ 'answer': context['needle_rnd_number'],
+ 'correct': context['needle_rnd_number'] in out,
+ 'seed': context['seed'],
+ })
+ print(results[-1])
+ if jax.process_index() == 0:
+ with open_file(FLAGS.output_file, 'w') as f:
+ json.dump(results, f)
+ pbar.update(len(contexts_i))
+ pbar.close()
+ print('elapsed', time.time() - start)
+ print('done')
+
+
+ def print_start_test_summary(self):
+ print ("\n")
+ print ("Starting Needle In A Haystack Testing...")
+ print (f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}")
+ print (f"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%")
+ print (f"- Needle: {self.needle.strip()}")
+ print ("\n\n")
+
+ def start_test(self):
+ if self.print_ongoing_status:
+ self.print_start_test_summary()
+ self.run_test()
+
+
+
+class Sampler:
+ def __init__(self):
+ self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
+ self.prefix_tokenizer = LLaMAConfig.get_tokenizer(
+ FLAGS.tokenizer, truncation_side='left', padding_side='left'
+ )
+ self.tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
+ self.sharded_rng = next_rng()
+ self._load_model()
+
+ @property
+ def block_size(self):
+ # return 2 * max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size)
+ return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
+
+ @property
+ def data_dim(self):
+ return self.mesh.shape['dp'] * self.mesh.shape['fsdp']
+
+ def _load_model(self):
+ if FLAGS.load_llama_config != '':
+ llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
+ updates = LLaMAConfig(**FLAGS.llama)
+ llama_config.update(dict(
+ remat_block=updates.remat_block,
+ remat_attention=updates.remat_attention,
+ remat_mlp=updates.remat_mlp,
+ scan_attention=updates.scan_attention,
+ scan_mlp=updates.scan_mlp,
+ scan_query_chunk_size=updates.scan_query_chunk_size,
+ scan_key_chunk_size=updates.scan_key_chunk_size,
+ scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
+ scan_layers=updates.scan_layers,
+ param_scan_axis=updates.param_scan_axis,
+ ))
+ else:
+ llama_config = LLaMAConfig(**FLAGS.llama)
+
+ if FLAGS.update_llama_config != '':
+ llama_config.update(dict(eval(FLAGS.update_llama_config)))
+
+ llama_config.update(dict(
+ bos_token_id=self.tokenizer.bos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ ))
+ llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
+ self.config = llama_config
+ assert not self.config.use_flash_attention, f"Flash attention is not supported for inference"
+
+ with jax.default_device(jax.devices("cpu")[0]):
+ _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
+ FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
+ )
+ self.model = FlaxLLaMAForCausalLM(
+ llama_config,
+ input_shape=(512, self.block_size),
+ seed=FLAGS.seed,
+ _do_init=False,
+ dtype=get_float_dtype_by_name(FLAGS.dtype),
+ )
+ self.model_ps = match_partition_rules(
+ LLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
+ )
+ shard_fns, _ = make_shard_and_gather_fns(
+ self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
+ )
+
+ with self.mesh:
+ self.params = tree_apply(shard_fns, self.params)
+
+ @cached_property
+ def _forward_generate(self):
+ def fn(params, rng, batch):
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
+ rng_generator = JaxRNG(rng)
+ output = self.model.generate(
+ batch['input_ids'],
+ attention_mask=batch['attention_mask'],
+ params=params['params'],
+ prng_key=rng_generator(),
+ generation_config=GenerationConfig(
+ max_new_tokens=self.block_size,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ temperature=0.,
+ do_sample=False,
+ num_beams=1,
+ top_k=50,
+ top_p=1.0,
+ )
+ ).sequences[:, batch['input_ids'].shape[1]:]
+ return output, rng_generator()
+ return pjit(
+ fn,
+ in_shardings=(self.model_ps, PS(), PS()),
+ out_shardings=(PS(), PS())
+ )
+
+ def __call__(self, prompts, max_input_length):
+ inputs = self.prefix_tokenizer(
+ prompts,
+ padding='max_length',
+ truncation=True,
+ max_length=max_input_length,
+ return_tensors='np'
+ )
+ batch = dict(
+ input_ids=inputs.input_ids,
+ attention_mask=inputs.attention_mask
+ )
+ with self.mesh:
+ output, self.sharded_rng = self._forward_generate(
+ self.params, self.sharded_rng, batch
+ )
+ output = jax.device_get(output)
+ output_text = []
+ for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):
+ if self.tokenizer.eos_token in text:
+ text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
+ output_text.append(text)
+ return output_text
+
+
+def main(argv):
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
+ set_random_seed(FLAGS.seed)
+
+ ht = LLMNeedleHaystackTester(
+ haystack_file=FLAGS.haystack_file,
+ context_lengths_min=FLAGS.context_lengths_min,
+ context_lengths_max=FLAGS.context_lengths_max,
+ context_lengths_num_intervals=FLAGS.n_context_length_intervals,
+ document_depth_percent_intervals=FLAGS.n_document_depth_intervals,
+ )
+ ht.start_test()
+
+if __name__ == "__main__":
+ run(main)
diff --git a/scripts/eval_needle_multi.py b/scripts/eval_needle_multi.py
new file mode 100644
index 0000000..7145def
--- /dev/null
+++ b/scripts/eval_needle_multi.py
@@ -0,0 +1,455 @@
+from absl.app import run
+import glob
+import time
+import json
+import math
+import os
+from tqdm import tqdm
+import random
+from functools import cached_property
+import numpy as np
+import jax
+from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec as PS
+import gcsfs
+import tiktoken
+from transformers import GenerationConfig, LlamaTokenizer
+from tux import (
+ define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
+ set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
+ match_partition_rules, make_shard_and_gather_fns,
+ with_sharding_constraint, tree_apply, open_file
+)
+from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM
+
+
+FLAGS, FLAGS_DEF = define_flags_with_default(
+ haystack_file="",
+ max_tokens_per_batch=2000000,
+ output_file="results.json",
+ context_lengths_min=1000,
+ context_lengths_max=32000,
+ n_context_length_intervals=3,
+ n_document_depth_intervals=3,
+ n_rounds=2,
+ n_needles_total=4,
+ n_needles_retrieve=4,
+ seed=1234,
+ mesh_dim='1,-1,1,1',
+ dtype='fp32',
+ load_llama_config='',
+ update_llama_config='',
+ load_checkpoint='',
+ tokenizer=LLaMAConfig.get_tokenizer_config(),
+ checkpointer=StreamingCheckpointer.get_default_config(),
+ llama=LLaMAConfig.get_default_config(),
+ jax_distributed=JaxDistributedConfig.get_default_config(),
+)
+
+
+class LLMNeedleHaystackTester:
+ OURS_TEMPLATE = "You are a helpful assistant. USER: {context} {question} Don't give information outside the document. ASSISTANT: "
+ RANDOM_NEEDLE_CITIES = [
+ 'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',
+ 'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',
+ 'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',
+ 'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',
+ 'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',
+ 'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',
+ 'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',
+ 'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',
+ 'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',
+ 'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',
+ 'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',
+ 'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'
+ ]
+
+ def __init__(self,
+ needle="",
+ haystack_file="",
+ retrieval_question="What are the special magic numbers for {}?",
+ results_version = 1,
+ rnd_number_digits = 7,
+ context_lengths_min = 1000,
+ context_lengths_max = 126000,
+ context_lengths_num_intervals = 10,
+ document_depth_percent_min = 0,
+ document_depth_percent_max = 100,
+ document_depth_percent_intervals = 10,
+ document_depth_percent_interval_type = "linear",
+ save_results = False,
+ final_context_length_buffer = 200,
+ print_ongoing_status = True):
+ needle="\nThe special magic {city} number is: {rnd_number}\n"
+ self.needle = needle
+ if not needle or not haystack_file or not retrieval_question:
+ raise ValueError("Needle, haystack, and retrieval_question must be provided.")
+
+ self.rnd_number_digits = rnd_number_digits
+ self.context_lengths_num_intervals = context_lengths_num_intervals
+ self.document_depth_percent_intervals = document_depth_percent_intervals
+ self.haystack_file = haystack_file
+ self.retrieval_question = retrieval_question
+ self.results_version = results_version
+ self.save_results = save_results
+ self.final_context_length_buffer = final_context_length_buffer
+ self.print_ongoing_status = print_ongoing_status
+ self.testing_results = []
+
+ self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
+ self.context_lengths = self.context_lengths.tolist()
+ if document_depth_percent_interval_type == 'linear':
+ self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
+ elif document_depth_percent_interval_type == 'sigmoid':
+ self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]
+ else:
+ raise ValueError(f"Unsupported document_depth_percent_interval_type: {document_depth_percent_interval_type}")
+ self.document_depth_percents = self.document_depth_percents.tolist()
+
+ self.model = Sampler()
+
+ self.enc = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
+ self.enc_tiktoken = tiktoken.encoding_for_model("gpt-4-1106-preview")
+
+ def generate_random_number(self, num_digits):
+ lower_bound = 10**(num_digits - 1)
+ upper_bound = 10**num_digits - 1
+ return random.randint(lower_bound, upper_bound)
+
+ def logistic(self, x, L=100, x0=50, k=.1):
+ if x == 0:
+ return 0
+ if x == 100:
+ return 100
+ return np.round(L / (1 + np.exp(-k * (x - x0))), 3)
+
+ def read_context_files(self, n):
+ max_context_length = max(self.context_lengths)
+ contexts = []
+ f = open_file(self.haystack_file, 'r')
+ for i in range(n):
+ context = ""
+ while len(self.enc.encode(context)) < max_context_length:
+ context += json.loads(f.readline())['text']
+ contexts.append(context)
+ return contexts
+
+ def encode_and_trim(self, context, context_length):
+ tokens = self.enc.encode(context)
+ if len(tokens) > context_length:
+ context = self.enc.decode(tokens[:context_length])
+ return context
+
+ def create_contexts(self, needles_info, random_cities_retrieve, context, context_length, seed):
+ assert all([random_city in needles_info for random_city in random_cities_retrieve])
+ for random_city, (needle_rnd_number, depth_percent) in needles_info.items():
+ context = self.generate_context(
+ self.needle.format(city=random_city, rnd_number=needle_rnd_number),
+ context, context_length, depth_percent
+ )
+
+ if len(random_cities_retrieve) == 1:
+ question = f"What is the special magic number for {random_cities_retrieve[0]}?"
+ else:
+ q = ', '.join(random_cities_retrieve[:-1]) + ', and ' + random_cities_retrieve[-1]
+ question = self.retrieval_question.format(q)
+ results = {
+ 'context' : context,
+ 'context_length' : int(context_length),
+ 'needles_info': needles_info,
+ 'question' : question,
+ 'cities_to_retrieve' : random_cities_retrieve,
+ 'seed': seed,
+ }
+ return results
+
+ def insert_needle(self, needle, context, depth_percent, context_length):
+ tokens_needle = self.enc_tiktoken.encode(needle)
+ tokens_context = self.enc_tiktoken.encode(context)
+
+ # Reducing the context length by 150 buffer. This is to account for system message, the user question, and response.
+ context_length -= self.final_context_length_buffer
+
+ # If your context + needle are longer than the context length (which it will be), then reduce tokens from the context by the needle length
+ if len(tokens_context) + len(tokens_needle) > context_length:
+ tokens_context = tokens_context[:context_length - len(tokens_needle)]
+
+ if depth_percent == 100:
+ # If your depth percent is 100 (which means your needle is the last thing in the doc), throw it at the end
+ tokens_new_context = tokens_context + tokens_needle
+ else:
+ # Go get the position (in terms of tokens) to insert your needle
+ insertion_point = int(len(tokens_context) * (depth_percent / 100))
+
+ # tokens_new_context represents the tokens before the needle
+ tokens_new_context = tokens_context[:insertion_point]
+
+ # We want to make sure that we place our needle at a sentence break so we first see what token a '.' is
+ period_tokens = self.enc_tiktoken.encode('.')
+
+ # Then we iteration backwards until we find the first period
+ while tokens_new_context and tokens_new_context[-1] not in period_tokens:
+ insertion_point -= 1
+ tokens_new_context = tokens_context[:insertion_point]
+
+ # Once we get there, then add in your needle, and stick the rest of your context in on the other end.
+ # Now we have a needle in a haystack
+ tokens_new_context += tokens_needle + tokens_context[insertion_point:]
+
+ # Convert back to a string and return it
+ new_context = self.enc_tiktoken.decode(tokens_new_context)
+ return new_context
+
+ def generate_context(self, needle, trim_context, context_length, depth_percent):
+ context = self.insert_needle(needle, trim_context, depth_percent, context_length)
+ return context
+
+ def compute_max_input_length(self, context_length, buffer=1024):
+ block_size = self.model.block_size
+ context_length += buffer
+ # context_length = 2 ** math.ceil(math.log2(context_length))
+ context_length = math.ceil(context_length / block_size) * block_size
+ return int(context_length)
+
+ def run_test(self):
+ fs = gcsfs.GCSFileSystem()
+ contexts = []
+ template = self.OURS_TEMPLATE
+
+ def _key_from_result(result):
+ return (result['context_length'], result['depth_percent'], result['seed'])
+
+ results = []
+ completed = set()
+ def exists(fname):
+ if fname.startswith('gs://'):
+ return fs.exists(fname)
+ else:
+ return os.path.exists(fname)
+ if exists(FLAGS.output_file):
+ with open_file(FLAGS.output_file, 'r') as f:
+ results = json.load(f)
+ completed = set([_key_from_result(result) for result in results])
+ print('completed', len(completed))
+
+ full_contexts = self.read_context_files(FLAGS.n_rounds)
+ full_tokens = [self.enc.encode(full_context) for full_context in full_contexts]
+
+ start = time.time()
+ for context_length in self.context_lengths:
+ trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in full_tokens]
+ max_input_length = self.compute_max_input_length(context_length)
+ contexts = []
+ for i in range(FLAGS.n_rounds):
+ if (int(context_length), i) in completed:
+ continue
+ random_cities = random.sample(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES, FLAGS.n_needles_total)
+ document_depths = random.sample(self.document_depth_percents, FLAGS.n_needles_total)
+ random_cities_retrieve = random.sample(random_cities, FLAGS.n_needles_retrieve)
+ needles_info = {}
+ for random_city, depth_percent in zip(random_cities, document_depths):
+ needles_info[random_city] = (
+ str(self.generate_random_number(self.rnd_number_digits)),
+ depth_percent
+ )
+ context = self.create_contexts(needles_info, random_cities_retrieve, trim_contexts[i], context_length, i)
+ contexts.append(context)
+
+ if len(contexts) == 0:
+ continue
+
+ B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)
+ B = int(B / self.model.data_dim) * self.model.data_dim
+ if B < self.model.data_dim:
+ B = self.model.data_dim
+ elif B > len(contexts):
+ B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)
+ n_pad = B - len(contexts) % B
+ for _ in range(n_pad):
+ contexts.insert(0, contexts[0])
+
+ pbar = tqdm(total=len(contexts))
+ for i in range(0, len(contexts), B):
+ contexts_i = contexts[i:i + B]
+ prompts = [
+ template.format(context=context['context'], question=context['question'])
+ for context in contexts_i
+ ]
+ outs = self.model(prompts, max_input_length)
+ for j, (context, out) in enumerate(zip(contexts_i, outs)):
+ if i + j < n_pad:
+ continue
+ rnd_nums_to_retrieve = [
+ context['needles_info'][city][0] for city in context['cities_to_retrieve']
+ ]
+ results.append({
+ 'context_length': context['context_length'],
+ 'needles_info': context['needles_info'],
+ 'question': context['question'],
+ 'answer': rnd_nums_to_retrieve,
+ 'response': out,
+ 'correct': [rnd_num in out for rnd_num in rnd_nums_to_retrieve],
+ 'seed': context['seed'],
+ })
+ print(results[-1]['correct'], out, rnd_nums_to_retrieve)
+ if jax.process_index() == 0:
+ with open_file(FLAGS.output_file, 'w') as f:
+ json.dump(results, f)
+ pbar.update(len(contexts_i))
+ pbar.close()
+ print('elapsed', time.time() - start)
+ print('done')
+
+
+ def print_start_test_summary(self):
+ print ("\n")
+ print ("Starting Needle In A Haystack Testing...")
+ print (f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}")
+ print (f"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%")
+ print (f"- Needle: {self.needle.strip()}")
+ print ("\n\n")
+
+ def start_test(self):
+ if self.print_ongoing_status:
+ self.print_start_test_summary()
+ self.run_test()
+
+
+
+class Sampler:
+ def __init__(self):
+ self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
+ self.prefix_tokenizer = LLaMAConfig.get_tokenizer(
+ FLAGS.tokenizer, truncation_side='left', padding_side='left'
+ )
+ self.tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
+ self.sharded_rng = next_rng()
+ self._load_model()
+
+ @property
+ def block_size(self):
+ # return 2 * max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size)
+ return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
+
+ @property
+ def data_dim(self):
+ return self.mesh.shape['dp'] * self.mesh.shape['fsdp']
+
+ def _load_model(self):
+ if FLAGS.load_llama_config != '':
+ llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
+ updates = LLaMAConfig(**FLAGS.llama)
+ llama_config.update(dict(
+ remat_block=updates.remat_block,
+ remat_attention=updates.remat_attention,
+ remat_mlp=updates.remat_mlp,
+ scan_attention=updates.scan_attention,
+ scan_mlp=updates.scan_mlp,
+ scan_query_chunk_size=updates.scan_query_chunk_size,
+ scan_key_chunk_size=updates.scan_key_chunk_size,
+ scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
+ scan_layers=updates.scan_layers,
+ param_scan_axis=updates.param_scan_axis,
+ ))
+ else:
+ llama_config = LLaMAConfig(**FLAGS.llama)
+
+ if FLAGS.update_llama_config != '':
+ llama_config.update(dict(eval(FLAGS.update_llama_config)))
+
+ llama_config.update(dict(
+ bos_token_id=self.tokenizer.bos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ ))
+ llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
+ self.config = llama_config
+
+ with jax.default_device(jax.devices("cpu")[0]):
+ _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
+ FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
+ )
+ self.model = FlaxLLaMAForCausalLM(
+ llama_config,
+ input_shape=(512, self.block_size),
+ seed=FLAGS.seed,
+ _do_init=False,
+ dtype=get_float_dtype_by_name(FLAGS.dtype),
+ )
+ self.model_ps = match_partition_rules(
+ LLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
+ )
+ shard_fns, _ = make_shard_and_gather_fns(
+ self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
+ )
+
+ with self.mesh:
+ self.params = tree_apply(shard_fns, self.params)
+
+ @cached_property
+ def _forward_generate(self):
+ def fn(params, rng, batch):
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
+ rng_generator = JaxRNG(rng)
+ output = self.model.generate(
+ batch['input_ids'],
+ attention_mask=batch['attention_mask'],
+ params=params['params'],
+ prng_key=rng_generator(),
+ generation_config=GenerationConfig(
+ max_new_tokens=self.block_size,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ temperature=0.,
+ do_sample=False,
+ num_beams=1,
+ top_k=50,
+ top_p=1.0,
+ )
+ ).sequences[:, batch['input_ids'].shape[1]:]
+ return output, rng_generator()
+ return pjit(
+ fn,
+ in_shardings=(self.model_ps, PS(), PS()),
+ out_shardings=(PS(), PS())
+ )
+
+ def __call__(self, prompts, max_input_length):
+ inputs = self.prefix_tokenizer(
+ prompts,
+ padding='max_length',
+ truncation=True,
+ max_length=max_input_length,
+ return_tensors='np'
+ )
+ batch = dict(
+ input_ids=inputs.input_ids,
+ attention_mask=inputs.attention_mask
+ )
+ with self.mesh:
+ output, self.sharded_rng = self._forward_generate(
+ self.params, self.sharded_rng, batch
+ )
+ output = jax.device_get(output)
+ output_text = []
+ for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):
+ if self.tokenizer.eos_token in text:
+ text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
+ output_text.append(text)
+ return output_text
+
+
+def main(argv):
+ JaxDistributedConfig.initialize(FLAGS.jax_distributed)
+ set_random_seed(FLAGS.seed)
+
+ ht = LLMNeedleHaystackTester(
+ haystack_file=FLAGS.haystack_file,
+ context_lengths_min=FLAGS.context_lengths_min,
+ context_lengths_max=FLAGS.context_lengths_max,
+ context_lengths_num_intervals=FLAGS.n_context_length_intervals,
+ document_depth_percent_intervals=FLAGS.n_document_depth_intervals,
+ )
+ ht.start_test()
+
+if __name__ == "__main__":
+ run(main)
diff --git a/scripts/run_eval_needle.sh b/scripts/run_eval_needle.sh
new file mode 100755
index 0000000..24eae3b
--- /dev/null
+++ b/scripts/run_eval_needle.sh
@@ -0,0 +1,31 @@
+#! /bin/bash
+
+export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
+cd $PROJECT_DIR
+export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
+export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+
+export llama_tokenizer_path=""
+export lwm_text_checkpoint=""
+# jsonl file containing text for haystack. Each line should be a json
+# with a single key "text" containing the text.
+export haystack_file=""
+export output_file=""
+
+python3 -u scripts/eval_needle.py \
+ --mesh_dim='!1,-1,4,1' \
+ --dtype='fp32' \
+ --load_llama_config='7b' \
+ --update_llama_config="dict(theta=10000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
+ --load_checkpoint="params::$lwm_text_checkpoint" \
+ --tokenizer.vocab_file="$llama_tokenizer_path" \
+ --max_tokens_per_batch=5000 \
+ --output_file="$output_file" \
+ --haystack_file="$haystack_file" \
+ --context_lengths_min=1000 \
+ --context_lengths_max=10000 \
+ --n_context_length_intervals=20 \
+ --n_document_depth_intervals=20 \
+ --n_rounds=3
+read
diff --git a/scripts/run_eval_needle_multi.sh b/scripts/run_eval_needle_multi.sh
new file mode 100755
index 0000000..c87e2a6
--- /dev/null
+++ b/scripts/run_eval_needle_multi.sh
@@ -0,0 +1,33 @@
+#! /bin/bash
+
+export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
+cd $PROJECT_DIR
+export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
+export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+
+export llama_tokenizer_path=""
+export lwm_text_checkpoint=""
+# jsonl file containing text for haystack. Each line should be a json
+# with a single key "text" containing the text.
+export haystack_file=""
+export output_file=""
+
+python3 -u scripts/eval_needle_multi.py \
+ --mesh_dim='!1,-1,4,1' \
+ --dtype='fp32' \
+ --load_llama_config='7b' \
+ --update_llama_config="dict(theta=10000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
+ --load_checkpoint="params::$lwm_text_checkpoint" \
+ --tokenizer.vocab_file="$llama_tokenizer_path" \
+ --max_tokens_per_batch=5000 \
+ --output_file="$output_file" \
+ --haystack_file="$haystack_file" \
+ --context_lengths_min=1000 \
+ --context_lengths_max=10000 \
+ --n_context_length_intervals=10 \
+ --n_document_depth_intervals=10 \
+ --n_needles_total=4 \
+ --n_needles_retrieve=2 \
+ --n_rounds=10
+read
diff --git a/scripts/run_sample_image.sh b/scripts/run_sample_image.sh
new file mode 100755
index 0000000..2627ad1
--- /dev/null
+++ b/scripts/run_sample_image.sh
@@ -0,0 +1,27 @@
+#! /bin/bash
+
+export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
+cd $PROJECT_DIR
+export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
+export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+
+export llama_tokenizer_path=""
+export vqgan_checkpoint=""
+export lwm_checkpoint=""
+
+python3 -u -m lwm.vision_generation \
+ --prompt='Fireworks over the city' \
+ --output_file='fireworks.png' \
+ --temperature_image=1.0 \
+ --top_k_image=8192 \
+ --cfg_scale_image=5.0 \
+ --vqgan_checkpoint="$vqgan_checkpoint" \
+ --n_frames=1 \
+ --mesh_dim='!-1,1,8,1' \
+ --dtype='fp32' \
+ --load_llama_config='7b' \
+ --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
+ --load_checkpoint="params::$lwm_checkpoint" \
+ --tokenizer.vocab_file="$llama_tokenizer_path"
+read
diff --git a/scripts/run_sample_video.sh b/scripts/run_sample_video.sh
new file mode 100755
index 0000000..f4cbd12
--- /dev/null
+++ b/scripts/run_sample_video.sh
@@ -0,0 +1,30 @@
+#! /bin/bash
+
+export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
+cd $PROJECT_DIR
+export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
+export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+
+export llama_tokenizer_path=""
+export vqgan_checkpoint=""
+export lwm_checkpoint=""
+
+python3 -u -m lwm.vision_generation \
+ --prompt='Fireworks over the city' \
+ --output_file='fireworks.mp4' \
+ --temperature_image=1.0 \
+ --temperature_video=1.0 \
+ --top_k_image=8192 \
+ --top_k_video=1000 \
+ --cfg_scale_image=5.0 \
+ --cfg_scale_video=1.0 \
+ --vqgan_checkpoint="$vqgan_checkpoint" \
+ --n_frames=8 \
+ --mesh_dim='!-1,1,8,1' \
+ --dtype='fp32' \
+ --load_llama_config='7b' \
+ --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
+ --load_checkpoint="params::$lwm_checkpoint" \
+ --tokenizer.vocab_file="$llama_tokenizer_path"
+read
diff --git a/scripts/run_train_text.sh b/scripts/run_train_text.sh
new file mode 100755
index 0000000..c1b83e3
--- /dev/null
+++ b/scripts/run_train_text.sh
@@ -0,0 +1,55 @@
+#! /bin/bash
+
+export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
+cd $PROJECT_DIR
+export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
+export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+
+export llama_tokenizer_path=""
+export dataset_path=""
+export output_dir=""
+
+export project_id='lwm'
+export experiment_note=''
+export experiment_id='example-text-train'
+
+# mesh_dim: dp, fsdp, tp, sp
+python3 -u -m lwm.train \
+ --modality='text' \
+ --mesh_dim='!1,-1,2,2' \
+ --dtype='fp32' \
+ --total_steps=200\
+ --log_freq=1 \
+ --save_model_freq=0 \
+ --save_milestone_freq=10 \
+ --load_llama_config='debug' \
+ --update_llama_config="dict(theta=10000,max_sequence_length=4096,use_flash_attention=False,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
+ --tokenizer.vocab_file="$llama_tokenizer_path" \
+ --optimizer.type='adamw' \
+ --optimizer.accumulate_gradient_steps=1 \
+ --optimizer.adamw_optimizer.weight_decay=0.1 \
+ --optimizer.adamw_optimizer.lr=8e-5 \
+ --optimizer.adamw_optimizer.end_lr=8e-5 \
+ --optimizer.adamw_optimizer.lr_warmup_steps=5 \
+ --optimizer.adamw_optimizer.lr_decay_steps=200 \
+ --use_data_sharded_loader=True \
+ --train_dataset.type='json' \
+ --train_dataset.text_processor.fields='text' \
+ --train_dataset.json_dataset.path="$dataset_path" \
+ --train_dataset.json_dataset.seq_length=1024 \
+ --train_dataset.json_dataset.batch_size=8 \
+ --train_dataset.json_dataset.tokenizer_processes=4 \
+ --train_dataset.json_dataset.tokenizer_parallel_chunk_size=2 \
+ --train_dataset.json_dataset.tokenizer_parallel_batch_size=8 \
+ --train_dataset.json_dataset.use_data_sharded_loader=True \
+ --checkpointer.save_optimizer_state=True \
+ --autoresume=False \
+ --logger.append_uuid=False \
+ --logger.online=False \
+ --logger.project_id="$project_id" \
+ --logger.experiment_id="$experiment_id" \
+ --logger.experiment_note="$experiment_note" \
+ --logger.output_dir="$output_dir" \
+ --logger.wandb_dir="$HOME/experiment_output/$project_id"
+read
diff --git a/scripts/run_train_vision_text.sh b/scripts/run_train_vision_text.sh
new file mode 100755
index 0000000..6f5af6b
--- /dev/null
+++ b/scripts/run_train_vision_text.sh
@@ -0,0 +1,57 @@
+#! /bin/bash
+
+export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
+cd $PROJECT_DIR
+export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
+export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+
+export llama_tokenizer_path=""
+export dataset_path=""
+export output_dir=""
+
+export project_id='lwm'
+export experiment_note=''
+export experiment_id='example-vision-text-train'
+
+# mesh_dim: dp, fsdp, tp, sp
+python3 -u -m lwm.train \
+ --modality='vision,text' \
+ --mesh_dim='!1,-1,2,2' \
+ --dtype='fp32' \
+ --total_steps=200 \
+ --log_freq=1 \
+ --save_model_freq=0 \
+ --save_milestone_freq=10 \
+ --load_llama_config='debug' \
+ --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
+ --tokenizer.vocab_file="$llama_tokenizer_path" \
+ --optimizer.type='adamw' \
+ --optimizer.accumulate_gradient_steps=1 \
+ --optimizer.adamw_optimizer.weight_decay=0.1 \
+ --optimizer.adamw_optimizer.lr=8e-5 \
+ --optimizer.adamw_optimizer.end_lr=8e-5 \
+ --optimizer.adamw_optimizer.lr_warmup_steps=5 \
+ --optimizer.adamw_optimizer.lr_decay_steps=200 \
+ --use_data_sharded_loader=True \
+ --train_dataset.type='json_vision' \
+ --train_dataset.vision_text_processor.fields_from_example='fields' \
+ --train_dataset.vision_text_processor.max_n_frames=4 \
+ --train_dataset.json_vision_dataset.mode="no_pad" \
+ --train_dataset.json_vision_dataset.path="$dataset_path" \
+ --train_dataset.json_vision_dataset.seq_length=2048 \
+ --train_dataset.json_vision_dataset.batch_size=8 \
+ --train_dataset.json_vision_dataset.tokenizer_processes=4 \
+ --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=2 \
+ --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=8 \
+ --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
+ --checkpointer.save_optimizer_state=True \
+ --autoresume=False \
+ --logger.append_uuid=False \
+ --logger.online=False \
+ --logger.project_id="$project_id" \
+ --logger.experiment_id="$experiment_id" \
+ --logger.experiment_note="$experiment_note" \
+ --logger.output_dir="$output_dir" \
+ --logger.wandb_dir="$HOME/experiment_output/$project_id"
+read
diff --git a/scripts/run_vision_chat.sh b/scripts/run_vision_chat.sh
new file mode 100755
index 0000000..869ceb4
--- /dev/null
+++ b/scripts/run_vision_chat.sh
@@ -0,0 +1,25 @@
+#! /bin/bash
+
+export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
+cd $PROJECT_DIR
+export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
+
+export llama_tokenizer_path=""
+export vqgan_checkpoint=""
+export lwm_checkpoint=""
+export input_file=""
+
+python3 -u -m lwm.vision_chat \
+ --prompt="What is the video about?" \
+ --input_file="$input_file" \
+ --vqgan_checkpoint="$vqgan_checkpoint" \
+ --mesh_dim='!1,-1,32,1' \
+ --dtype='fp32' \
+ --load_llama_config='7b' \
+ --max_n_frames=8 \
+ --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
+ --load_checkpoint="params::$lwm_checkpoint" \
+ --tokenizer.vocab_file="$llama_tokenizer_path" \
+2>&1 | tee ~/output.log
+read
diff --git a/tpu_vm_setup.sh b/tpu_vm_setup.sh
new file mode 100755
index 0000000..6c4d1fb
--- /dev/null
+++ b/tpu_vm_setup.sh
@@ -0,0 +1,170 @@
+#! /bin/bash
+
+sudo umount /mnt/ramdisk
+sudo rm -rf /mnt/ramdisk
+sudo mkdir /mnt/ramdisk
+sudo mount -t tmpfs -o size=10G tmpfs /mnt/ramdisk
+
+sudo apt-get update && sudo apt-get install -y \
+ build-essential \
+ python-is-python3 \
+ tmux \
+ htop \
+ git \
+ nodejs \
+ bmon \
+ p7zip-full \
+ nfs-common \
+ ffmpeg
+
+# Update pip
+pip install --upgrade pip
+
+pip uninstall -y tux
+
+# Python dependencies
+cat > $HOME/tpu_requirements.txt <<- EndOfFile
+-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+jax[tpu]==0.4.23
+tensorflow-cpu==2.11.0
+tensorboard-plugin-profile
+flax==0.7.0
+optax==0.1.7
+chex==0.1.82
+einops
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.0.0
+torchvision==0.15.0
+transformers==4.29.2
+datasets==2.13.0
+tqdm
+ml_collections
+wandb
+requests
+gcsfs
+typing-extensions
+sentencepiece
+tux @ git+https://github.com/lhao499/tux.git
+Pillow
+ipdb
+imageio[ffmpeg]
+tiktoken
+EndOfFile
+
+pip install --upgrade -r $HOME/tpu_requirements.txt
+
+# vim configurations
+cat > $HOME/.vimrc <<- EndOfFile
+set tabstop=4
+set shiftwidth=4
+set softtabstop=4
+set expandtab
+set backspace=indent,eol,start
+syntax on
+EndOfFile
+
+# tmux configurations
+cat > $HOME/.tmux.conf <<- EndOfFile
+bind r source-file ~/.tmux.conf \; display-message "█▓░ ~/.tmux.conf reloaded."
+
+# Enable colors, https://github.com/tmux/tmux/wiki/FAQ
+set -g default-terminal "tmux-256color"
+
+# start with window 1 (instead of 0)
+set -g base-index 1
+setw -g pane-base-index 1
+
+set -g prefix C-a
+
+set -g set-titles on
+set -g set-titles-string '#(whoami)::#h::#(curl ipecho.net/plain;echo)'
+
+# Status bar customization
+set -g status-interval 5
+set -g status-left-length 90
+set -g status-right-length 60
+set -g status-justify left
+
+# send the prefix to client inside window (ala nested sessions)
+bind-key a send-prefix
+
+bind-key x kill-pane
+
+# auto reorder
+set-option -g renumber-windows on
+
+# default window name
+set -g status-left "#[fg=green,bg=colour236] #S "
+
+# default statusbar colors
+set-option -g status-style fg=yellow,dim,bg=colour235
+
+# default window title colors
+set-window-option -g window-status-style fg=yellow,bg=colour236,dim
+
+# active window title colors
+set-window-option -g window-status-current-style fg=brightred,bg=colour236
+
+# basename as window title https://stackoverflow.com/a/37136828
+set-window-option -g window-status-current-format '#{window_index} #{pane_current_command} #(echo "#{pane_current_path}" | rev | cut -d'/' -f-3 | rev)'
+set-window-option -g window-status-format '#{window_index} #{pane_current_command} #(echo "#{pane_current_path}" | rev | cut -d'/' -f-3 | rev)'
+
+# pane border
+set-option -g pane-border-style fg=white #base2
+set-option -g pane-active-border-style fg=brightcyan #base1
+
+# enable mouse click
+set -g mouse on
+
+# keep window on
+set -g remain-on-exit on
+
+# Longer scrollback history
+set -g history-limit 50000
+
+# Scroll position indicator
+set -g mode-style bg=colour235,fg=colour245
+
+# SSH agent forwarding
+# set-environment -g SSH_AUTH_SOCK $SSH_AUTH_SOCK
+if-shell '[ -n $SSH_AUTH_SOCK ]' " \
+ set-option -sg update-environment \"DISPLAY WINDOWID XAUTHORITY\"; \
+ setenv -g SSH_AUTH_SOCK /tmp/ssh_auth_sock_tmux; \
+ run-shell \"ln -sf $(find /tmp/ssh-* -type s -readable | head -n 1) /tmp/ssh_auth_sock_tmux\" \
+"
+
+# Drag windows on the status bar
+bind-key -n MouseDrag1Status swap-window -t=
+EndOfFile
+
+
+# htop Configurations
+mkdir -p $HOME/.config/htop
+cat > $HOME/.config/htop/htoprc <<- EndOfFile
+# Beware! This file is rewritten by htop when settings are changed in the interface.
+# The parser is also very primitive, and not human-friendly.
+fields=0 48 17 18 38 39 40 2 46 47 49 1
+sort_key=46
+sort_direction=1
+hide_threads=0
+hide_kernel_threads=1
+hide_userland_threads=1
+shadow_other_users=0
+show_thread_names=0
+show_program_path=1
+highlight_base_name=0
+highlight_megabytes=1
+highlight_threads=1
+tree_view=0
+header_margin=1
+detailed_cpu_time=0
+cpu_count_from_zero=0
+update_process_names=0
+account_guest_in_cpu_meter=0
+color_scheme=0
+delay=15
+left_meters=CPU Memory Swap
+left_meter_modes=1 1 1
+right_meters=Tasks LoadAverage Uptime
+right_meter_modes=2 2 2
+EndOfFile