Insights: jax-ml/jax
54 Pull requests merged by 15 people
-
Avoid adding conflicting `--repo_env=HERMETIC_PYTHON_VERSION=` to bazel command
#25527 merged
Jan 8, 2025 -
[Mosaic GPU] Allow multiple indexing on refs
#25480 merged
Jan 8, 2025 -
[Pallas] Add empty/empty_like helper functions
#25722 merged
Jan 8, 2025 -
Add a unittest test extension that runs test cases in parallel using threads.
#25772 merged
Jan 8, 2025 -
Port tests away from setUpClass and setUpModule to setUp alone.
#25774 merged
Jan 8, 2025 -
[pallas:mosaic_gpu] Fix the tests following the changes to `pl.core_map`
#25773 merged
Jan 8, 2025 -
[JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
#25765 merged
Jan 8, 2025 -
[Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6
#25605 merged
Jan 8, 2025 -
[Mosaic TPU] Add support for second minor broadcasts with packed types
#25636 merged
Jan 8, 2025 -
[Mosaic TPU] Add support for true divide in bf16 on TPUv6
#25608 merged
Jan 8, 2025 -
[Mosaic] Use tpu::CreateMask for getX32VmaskByPaddingEnd.
#25737 merged
Jan 8, 2025 -
[Mosaic GPU] Use `num_q_heads=2` in `flash_attention.py`
#25754 merged
Jan 8, 2025 -
Removed leftover mentions of xmap from the code
#25752 merged
Jan 8, 2025 -
Update the tutorial for jax.checkpoint activation offloading policies by add examples
#25594 merged
Jan 8, 2025 -
Add JAX events that have time spans, not only durations.
#25747 merged
Jan 8, 2025 -
[array api] update test suite to latest commit
#25761 merged
Jan 8, 2025 -
[Pallas] Improvements to core_map
#25721 merged
Jan 8, 2025 -
Increase the minimum SciPy version to 1.11.1.
#25744 merged
Jan 8, 2025 -
Add an example demonstrating input-output aliasing with the FFI
#25042 merged
Jan 7, 2025 -
[Pallas] Fix pallas_call lowering mutating compiler params during Triton lowering.
#25735 merged
Jan 7, 2025 -
[AutoPGLE] Fix PGLE kokoro test failures.
#25746 merged
Jan 7, 2025 -
Update the advanced autodiff tutorial and replace some vmap with grad
#25441 merged
Jan 7, 2025 -
Clean up `backend_or_name` vs. `platforms` in lowering code.
#25493 merged
Jan 7, 2025 -
Add a register_custom_type_id function to the GPU plugins.
#25413 merged
Jan 7, 2025 -
Fix type signature for __divmod__
#25748 merged
Jan 7, 2025 -
Move `jax.extend.ffi` module to top level `jax.ffi`
#25633 merged
Jan 7, 2025 -
[shape_poly] Improve threefry with symbolic shapes
#25731 merged
Jan 7, 2025 -
Fix the doc error: module 'scipy.misc' has no attribute 'face'
#25728 merged
Jan 7, 2025 -
Make inspect_array_sharding work without mesh context manager too.
#25743 merged
Jan 7, 2025 -
Remove deprecated jax.experimental.array_api
#25740 merged
Jan 7, 2025 -
[Mosaic TPU] Add some elementwise canonicalizations
#25720 merged
Jan 6, 2025 -
Remove the need for check_rep for with_sharding_constraint.
#25739 merged
Jan 6, 2025 -
Rename pybind_extension to nanobind_extension.
#25736 merged
Jan 6, 2025 -
jnp.einsum: default to optimize='auto'
#25214 merged
Jan 6, 2025 -
[Mosaic TPU] Validate inserted layout in relayout-insertion pass.
#25645 merged
Jan 6, 2025 -
[shmap/partial-auto] Fixes lowering for jax.lax.axis_index in shard_map for degenerated shmaps.
#25699 merged
Jan 6, 2025 -
bazel: export serialization.fbs for downstream usage
#25697 merged
Jan 6, 2025 -
Deprecate scipy.special.lpmn & lpmn_values
#25675 merged
Jan 6, 2025 -
[shape_poly] Remove old non_negative support.
#25462 merged
Jan 6, 2025 -
Disable `avxvnniint8` when building with Clang version < 19, or GCC < 13.
#25712 merged
Jan 6, 2025 -
[ROCm] Update package indentation fix
#25715 merged
Jan 6, 2025 -
Add SMEM as a supported Pallas output memory space.
#25689 merged
Jan 5, 2025 -
jax.debug.print: respect trace-time np.printoptions
#25713 merged
Jan 3, 2025 -
[Mosaic] NFC: Pull out vreg related functions to util.
#25600 merged
Jan 2, 2025 -
Compute cost-analysis on only one HLO module.
#25537 merged
Jan 2, 2025 -
Fix formatting in docs for transposing pytrees
#25686 merged
Jan 2, 2025 -
Fix OSS build for the Mosaic GPU dialect
#25710 merged
Jan 2, 2025 -
[Mosaic:TPU][NFC] Clean up unused variable
#25680 merged
Jan 2, 2025 -
Fix log10 and log2 for large inputs.
#25706 merged
Jan 2, 2025 -
Tensorboard profiling plugin nightly instructions
#25661 merged
Jan 2, 2025 -
Don't use x32 mode for pallas_test
#25709 merged
Jan 2, 2025
26 Pull requests opened by 8 people
-
Add GitHub action workflow for Bazel CUDA continuous tests
#25717 opened
Jan 3, 2025 -
[Mosaic] Create a stub for TPUExtDialect.
#25719 opened
Jan 3, 2025 -
[Mosaic TPU] Enable unaligned bf16 2D load/stores for earlier TPU gens
#25726 opened
Jan 6, 2025 -
[Mosaic TPU] Enable non-sublane-aligned 2D int8 load/stores
#25727 opened
Jan 6, 2025 -
[Mosaic TPU][NFC] Remove redundant num_subelems attribute from CreateSubelementMaskOp
#25729 opened
Jan 6, 2025 -
jnp.linalg.solve: finalize deprecation of batched 1D solves
#25741 opened
Jan 6, 2025 -
[shape_poly] Remove the deprecated PolyShape object for specifying symbolic dimensions
#25751 opened
Jan 7, 2025 -
Add `sph_harm_y` to `jax.scipy.special` and deprecate `sph_harm`
#25753 opened
Jan 7, 2025 -
[ROCm] Implement RNN support
#25755 opened
Jan 7, 2025 -
Introduce jax.shard_map, without requiring mesh arg
#25757 opened
Jan 7, 2025 -
[Mosaic] Fix inferMemRefLayout to error out if shape not aligned to tiling.
#25758 opened
Jan 7, 2025 -
[Mosaic GPU] Enable loop carries in the pipeline emitter.
#25762 opened
Jan 7, 2025 -
[Mosaic GPU] Allow multiple gmem indexers on copies.
#25763 opened
Jan 7, 2025 -
Add mode='fan_geo_avg' to nn.initializers.variance_scaling.
#25766 opened
Jan 8, 2025 -
Add a discussion of sharding to the FFI tutorial
#25771 opened
Jan 8, 2025 -
[pallas] DMA start discharge.
#25775 opened
Jan 8, 2025 -
[NFC] Refactor conversion lowering for Mosaic TPU
#25776 opened
Jan 8, 2025 -
[Pallas] Improve testing for lowering of dtype conversions + fix uncovered bugs
#25777 opened
Jan 8, 2025 -
Make api_test.py work when test cases are run using multiple threads.
#25780 opened
Jan 8, 2025 -
#sdy add repr for Sdy ArraySharding and DimSharding
#25781 opened
Jan 8, 2025 -
[pallas:mosaic_gpu] Tests now pass with x64 enabled
#25782 opened
Jan 8, 2025 -
New Experimental JAX Check Numerics API.
#25785 opened
Jan 8, 2025 -
Added 3.13 ft requirements lock file and updated WORKSPACE
#25786 opened
Jan 8, 2025 -
Add a JVP rule for lax.linalg.tridiagonal_solve, fixing some bugs along the way
#25787 opened
Jan 8, 2025 -
Remove warning suppressions for array API tests.
#25788 opened
Jan 8, 2025
10 Issues closed by 7 people
-
ordered_effects and passing tokens to custom_call using FFI v4 on CPU
#25756 closed
Jan 8, 2025 -
FP32 `jax.random.categorical` inconsistent with the FP64 version
#25749 closed
Jan 8, 2025 -
`jax.lax.composite` and `jax.nn.softmax` composes strangely
#25767 closed
Jan 8, 2025 -
caching not sensitive to bound axis names not present on inputs (args or closed over)
#9187 closed
Jan 7, 2025 -
Automatically treat dataclasses as pytrees
#2371 closed
Jan 7, 2025 -
An example for `ffi` input/output aliasing
#24986 closed
Jan 7, 2025 -
⚠️ Nightly upstream-dev CI failed ⚠️
#25664 closed
Jan 6, 2025 -
Small Einsum is hanging
#24929 closed
Jan 6, 2025 -
Missing annotations
#24888 closed
Jan 3, 2025 -
pure_callback is broken with multiple vmap
#23624 closed
Jan 2, 2025
11 Issues opened by 9 people
-
multihost_utils.process_allgather acts counterintuitively on single process
#25783 opened
Jan 8, 2025 -
`PositionalSharding`'s `reshape`, `transpose` and `replicate` methods reset memory kind
#25769 opened
Jan 8, 2025 -
Cache initialization fails when a JAX Array is created before enabling local cache
#25768 opened
Jan 8, 2025 -
Registering generic pytree nodes via getattr() and __setattr__
#25760 opened
Jan 7, 2025 -
XLA runtime error when taking grad + vmap + scan on GPU
#25759 opened
Jan 7, 2025 -
vmap(jnp.asarray)(numpy_array) does not return a JAX array
#25745 opened
Jan 7, 2025 -
Bfloat16 jnp.allclose raises Illegal Instruction Exception
#25730 opened
Jan 6, 2025 -
`vmap(custom_jvp)` does not strip zeros from nondifferentiable return values, leading to AD crashes
#25724 opened
Jan 5, 2025 -
If CUDA 12.1 is installed, pip-installed ptxas binary is not used and jax throws an error
#25718 opened
Jan 3, 2025 -
GPU `pallas_call` loses compiler params during second call when double jit-wrapped
#25714 opened
Jan 3, 2025 -
JAX can create str-dtyped tracers under `eval_shape` with numpy 2
#25707 opened
Jan 2, 2025
32 Unresolved conversations
Sometimes conversations happen on old items that aren’t yet closed. Here is a list of all the Issues and Pull Requests with unresolved conversations.
-
Implement column-pivoted QR via geqp3 (CPU lowering only)
#20282 commented on
Jan 6, 2025 • 10 new comments -
Add jax.random.multinomial.
#25688 commented on
Jan 7, 2025 • 4 new comments -
Update `jnp.floor_divide` to make it consistent with `np.floor_divide` for division by zero
#25032 commented on
Jan 3, 2025 • 3 new comments -
Add JAX_COMPILATION_CACHE_EXPECT_PGLE option
#24910 commented on
Jan 8, 2025 • 2 new comments -
Support Hessian of gamma-distributed samples
#21432 commented on
Jan 7, 2025 • 0 new comments -
[Pallas] Fix integer array indexing
#23758 commented on
Jan 3, 2025 • 0 new comments -
RFC: add pytree-compatible dataclass decorator
#24664 commented on
Jan 7, 2025 • 0 new comments -
Implement the extension to the custom_partitioning API.
#24719 commented on
Jan 8, 2025 • 0 new comments -
Added CI job with TSAN and free-threading
#24898 commented on
Jan 8, 2025 • 0 new comments -
Start a new TPU interpret mode for Pallas
#25097 commented on
Jan 8, 2025 • 0 new comments -
[Pallas TPU] Add vector support to `pl.debug_print`
#25099 commented on
Jan 8, 2025 • 0 new comments -
Add float8_e8m0fnu type support
#25116 commented on
Jan 8, 2025 • 0 new comments -
Add Github action workflows for running continuous tests with Pytest
#25238 commented on
Jan 8, 2025 • 0 new comments -
A test for using symbolic shapes for minformer.
#25317 commented on
Jan 8, 2025 • 0 new comments -
[Mosaic][TPU] Add a compatibility mode to Mosaic's canonicalization pass, skipping over elementwise and matmul op insertions and/or type compat casts.
#25556 commented on
Jan 8, 2025 • 0 new comments -
In progress experimentation. Add StringDType to JAX's supported types.
#25592 commented on
Jan 7, 2025 • 0 new comments -
[pallas/pallas_mgpu] Discharging run_scoped should not be discharging the intermediates
#25639 commented on
Jan 8, 2025 • 0 new comments -
Unexpected NaN gradient of `jnp.abs` at `±inf + 0j`
#25681 commented on
Jan 8, 2025 • 0 new comments -
Differentiation rule for tridiagonal_solve
#25693 commented on
Jan 8, 2025 • 0 new comments -
Add mode='fan_geo_avg' to nn.initializers.variance_scaling
#25649 commented on
Jan 8, 2025 • 0 new comments -
feature request: sparse jacobian and sparse hessians
#1032 commented on
Jan 8, 2025 • 0 new comments -
Generating random numbers in pallas on gpu
#25188 commented on
Jan 7, 2025 • 0 new comments -
Caching bug for jax.experimental.callback.rewrite for `jit` call inside `custom_gradient`
#6685 commented on
Jan 7, 2025 • 0 new comments -
Slicing a CPU-placed jax array results in unnecessary host-to-device transfers
#16002 commented on
Jan 6, 2025 • 0 new comments -
Unexpected NaN when signing `±inf + 0j`
#25679 commented on
Jan 4, 2025 • 0 new comments -
Improvements to debug_nans
#25643 commented on
Jan 3, 2025 • 0 new comments -
Get "invalid value (nan) encountered in jit" even when jit disabled globally
#25701 commented on
Jan 3, 2025 • 0 new comments -
Counterintuitive speed of einsums vs equivalent matmuls
#20952 commented on
Jan 3, 2025 • 0 new comments -
Apple Silicon: error: failed to legalize operation 'mhlo.cholesky'
#16321 commented on
Jan 3, 2025 • 0 new comments -
bfloat16/float32 memory requirements seem off
#3302 commented on
Jan 3, 2025 • 0 new comments -
Sharding is much slower than pmap for while loops of varying length while loops
#20968 commented on
Jan 2, 2025 • 0 new comments -
jax.experimental.jet.jet series argument: unintuitive ordering; at least documentation should explain it better
#25700 commented on
Jan 2, 2025 • 0 new comments