Skip to content

Commit

Permalink
fix readme
Browse files Browse the repository at this point in the history
  • Loading branch information
zinccat committed Oct 22, 2024
1 parent 62298b5 commit 9487828
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

Porting [FlexAttention](https://github.com/pytorch-labs/attention-gym) to pure JAX.

Example usage:
Please install Jax nightly: pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

Example usage:

```python
import jax
Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ homepage = "https://github.com/zinccat/flaxattention"
python = "^3.10"
jax = "^0.4.34"

[tool.poetry.extras]
viz = [
"matplotlib",
"numpy",
"jsonargparse",
"docstring-parser"
]

[tool.mypy]
disable_error_code = "import-untyped"

Expand Down

0 comments on commit 9487828

Please sign in to comment.