Insights: jax-ml/jax
Overview
95 Pull requests merged by 14 people
- Pass through the `use_shardy_partitioner` with `jax.config.jax_use_shardy_partitioner`. (#29511, merged Jun 17, 2025)
- Add test for programmatic tracing with options. (#29482, merged Jun 17, 2025)
- Removing Tensorflow references from the document. (#29231, merged Jun 17, 2025)
- Add custom-call ops to roofline. (#29473, merged Jun 17, 2025)
- [JAX] Relax the return type of `colocated_python` decorator (#29506, merged Jun 16, 2025)
- [Mosaic:TPU][NFC] Delete unused variable (#29502, merged Jun 16, 2025)
- Fix a missing bounds check in traceback code. (#29500, merged Jun 16, 2025)
- [doc] add missing axis_types documentation (#29497, merged Jun 16, 2025)
- Remove legacy CPU custom calls. (#29180, merged Jun 16, 2025)
- Install SciPy from its source (head) to test against Python 3.14.0b1 (#29496, merged Jun 16, 2025)
- Set heartbeat_timeout argument and flag. (#29443, merged Jun 16, 2025)
- [export] Add back-compat test for tridiagonal solve on GPU (#29488, merged Jun 16, 2025)
- Fix some more instances of unhashable jaxpr equation arguments. (#29486, merged Jun 16, 2025)
- [jax] Increase absolute test tolerance for lax_control_flow test (#29487, merged Jun 16, 2025)
- Removed unused `PyTreeDef::MakeFromNodeDataAndChildren` and its Python binding (#29491, merged Jun 16, 2025)
- Bump the libtpu check to 6/20 (#29492, merged Jun 16, 2025)
- Add version guards to testAutoPgle (#29490, merged Jun 16, 2025)
- Make mosaic_gpu equation params hashable. (#29464, merged Jun 14, 2025)
- [JAX] Fix the test names in colocated_python_test.py to following the standard snake case (#29472, merged Jun 14, 2025)
- Remove `with_spec` from NamedSharding and replace with `.update` (#29468, merged Jun 13, 2025)
- Replace `with_partitions` and `with_unreduced` with `.update` on Partitions (#29467, merged Jun 13, 2025)
- [Pallas] Add no_pipelining debugging option to emit_pipeline. (#29465, merged Jun 13, 2025)
- Load CUDA libraries up front with cdll.LoadLibrary(). (#29462, merged Jun 13, 2025)
- Internal refactor: move TPU lowering rules out of jax/_src/lax (#29420, merged Jun 13, 2025)
- Add extra cuBLAS/cuDNN version checks (#28931, merged Jun 13, 2025)
- Add `pjit_p` to `jax.extend.core.primitives` (#29456, merged Jun 13, 2025)
- Make params of assert_consumed_value_p hashable. (#29457, merged Jun 13, 2025)
- Improve reshape not supported error message (#29459, merged Jun 13, 2025)
- Make params of several pallas primitives hashable. (#29438, merged Jun 13, 2025)
- [Mosaic GPU] Parametrize the `test_subview` test. (#29454, merged Jun 13, 2025)
- Fix return type annotation for tree_util.tree_broadcast. (#29442, merged Jun 13, 2025)
- [Mosaic GPU] Reconcile the swizzle of the a and b operands for wgmma in the Mosaic GPU dialect. (#29437, merged Jun 13, 2025)
- Disable `too_slow` in data.draw() for test_ndindexer (#29453, merged Jun 13, 2025)
- [Mosaic GPU] Use warpgroup semantics for the ragged dot example kernel. (#29007, merged Jun 13, 2025)
- [Mosaic GPU] Resolve different tile transforms using the largest common divisor. (#29006, merged Jun 13, 2025)
- [Mosaic GPU] Add a Mosaic GPU op `with_transforms` for manually setting memref transforms. (#29440, merged Jun 13, 2025)
- [Mosaic GPU] Convert all memrefs with transforms to unrealized casts and check them. (#29122, merged Jun 13, 2025)
- fix-forward for pallas tpu memory spaces test (#29450, merged Jun 13, 2025)
- Remove unused internal optimization_barrier alias (#29334, merged Jun 13, 2025)
- Make the params of more jaxpr primitives hashable. (#29447, merged Jun 13, 2025)
- [jaxlib] Change Traceback to be a raw CPython class rather than a nanobind class. (#29191, merged Jun 13, 2025)
- Use a frozenset for unconstrained_dims in sharding_constraint_p. (#29435, merged Jun 13, 2025)
- [doc] fix some inaccuracies in jnp.bincount docs (#29441, merged Jun 12, 2025)
- Add colorama back into test-requirements (#29418, merged Jun 12, 2025)
- [Mosaic GPU] Enable transpose tests in mosaic_gpu. (#29427, merged Jun 12, 2025)
- [Pallas TPU] Small fix to memory space constraints on pallas_call inputs. (#29439, merged Jun 12, 2025)
- Add `all_gather_invariant` to lax. (#29408, merged Jun 12, 2025)
- Reland the C++ safe_zip implementation. (#29416, merged Jun 12, 2025)
- [JAX] Update the example to use jax.numpy rather than numpy. (#29374, merged Jun 12, 2025)
- [Pallas TPU] Support memory space constraints on pallas_call inputs. (#29391, merged Jun 12, 2025)
- add missing dtypes to jax.numpy.__init__.pyi (#29425, merged Jun 12, 2025)
- Add hermetic `nvshmem` dependencies to JAX targets. (#28892, merged Jun 12, 2025)
- Temporarily disable AVX512 in linalg_test_cpu. (#29430, merged Jun 12, 2025)
- [Mosaic GPU] Add conversion logic for `i4 -> f8e4m3fn`. (#29428, merged Jun 12, 2025)
- Move NamedSharding.__eq__ and NamedSharding.__hash__ into C++. (#29414, merged Jun 12, 2025)
- Improve batching for lax.platform_dependent (#29362, merged Jun 12, 2025)
- Fix GPU quantized paged attention tests for < sm89 (#29423, merged Jun 12, 2025)
- Add nightly linux jax wheel tests for python 3.14.0b1 (#29417, merged Jun 11, 2025)
- Add execution to unreduced tests now that it works end-to-end (#29411, merged Jun 11, 2025)
- Add a pytype disable around zstandard. (#29238, merged Jun 11, 2025)
- [XProf] Change tensorboard-plugin-profile to new xprof package (#29129, merged Jun 11, 2025)
- Extend pallas paged_attention with kv scales (#29354, merged Jun 11, 2025)
- Migrated to mypy 1.16.0 (#29405, merged Jun 11, 2025)
- [JAX] Extend `colocated_cpu_devices` to accept `Mesh` besides devices (#29387, merged Jun 11, 2025)
- Move materialization of NDIndexer out of draw() (#29379, merged Jun 11, 2025)
- //tests:scaled_matmul_stablehlo_test: fix for xla#27096 (#29294, merged Jun 11, 2025)
- Move jax/_src/export to its own build rule (#29385, merged Jun 11, 2025)
- Do not call update_weak_type on the result of get_aval(). (#29402, merged Jun 11, 2025)
- add doc comment to vma in ShapedArray (#29406, merged Jun 11, 2025)
- [Pallas] Fix shard_map + Megacore in TPU interpret mode. (#29350, merged Jun 11, 2025)
- Delete instantiate_const_abstracted. (#29404, merged Jun 11, 2025)
- Set explicit dot precision in the sparse solver test. (#29401, merged Jun 11, 2025)
- Don't recompute np.iinfo in _scalar_type_to_dtype. (#29393, merged Jun 11, 2025)
- Ensure that memory_kind is restored after pickling in SingleDeviceSharding and GSPMDSharding. (#29398, merged Jun 11, 2025)
- Propagate source_info in more places: (#29389, merged Jun 11, 2025)
- [Mosaic GPU] Remove unneeded code. (#29293, merged Jun 11, 2025)
- [pallas:mosaic] A few more primitives now have lowerings for all kernel types (#29397, merged Jun 11, 2025)
- [Mosaic GPU] Use _slice_smem also for barriers. (#29364, merged Jun 11, 2025)
- Ensure that all attributes are restored after pickling in `NamedSharding`. (#29365, merged Jun 11, 2025)
- Save a jaxpr equation in pl.cdiv if the rhs is an int. (#29390, merged Jun 11, 2025)
- Add basic mutable array tests with AOT (#29388, merged Jun 11, 2025)
- [JAX] Move the fallback of `colocated_cpu_devices` logic from the colocated Python test to the API (#29382, merged Jun 10, 2025)
- Rollback of #29353 due to downstream failures (#29383, merged Jun 10, 2025)
- Add is_leaf_with_path predicate. (#28300, merged Jun 10, 2025)
- Add support for output and input memory space colors in tpu custom calls via CustomCallConfig. (#29333, merged Jun 10, 2025)
- fix for a downstream breakage from #29353 (#29377, merged Jun 10, 2025)
- Skip NumPy's `isClose` test for NumPy 2.3.0 (#29376, merged Jun 10, 2025)
- [Mosaic GPU] Fix test after a previous PR changed the config params. (#29372, merged Jun 10, 2025)
- [jax2tf] fix jax2tf sharding tests for shardy (#29367, merged Jun 10, 2025)
- [contributing.md] add reference to pr-checklist (#29358, merged Jun 10, 2025)
- [mutable-arrays] upgrade scan to work with partial_eval_jaxpr_fwd (#29353, merged Jun 10, 2025)
- Pass source_info to custom_staging_rules and into jaxpr inlining. (#29347, merged Jun 10, 2025)
- fix type annotation for _IndexUpdateRef.get (#29257, merged Jun 10, 2025)
- [Pallas/Mosaic GPU] Fix the abstract eval rule for `load_p` in the presence of `RefUnion`s. (#29366, merged Jun 10, 2025)
51 Pull requests opened by 11 people
- [mutable-arrays] remat discharge rule (#29370, opened Jun 10, 2025)
- [JAX][numpy] dont change indices type if not needed (#29378, opened Jun 10, 2025)
- Add an API to overwrite the current execution_stream_id and respect it in XLA CPU dispatch. (#29380, opened Jun 10, 2025)
- Expose local/global `ExchangeTopologies` timeouts for PJRT CPU client. (#29384, opened Jun 10, 2025)
- [Mosaic] Use tree-based reduction for parallelism. (#29386, opened Jun 10, 2025)
- Precompute has_changed and will_change during pallas pipelines. (#29392, opened Jun 11, 2025)
- [XLA:GPU] Legalize dot precision into casts+algorithm. (#29403, opened Jun 11, 2025)
- Expose local/global topology exchange timeouts for CPU client with collectives. (#29409, opened Jun 11, 2025)
- add jax.nn module type hints (__init__.pyi) (#29410, opened Jun 11, 2025)
- Refactor `custom_call` to use common `FindCudaExecutable` method from XLA repository to find CUDA binaries. (#29412, opened Jun 11, 2025)
- [jax:benchmark] Add tracing benchmarks for some common operations. (#29413, opened Jun 11, 2025)
- [CI] Testing interaction with skipped tests (#29415, opened Jun 11, 2025)
- [Mosaic] Allow sublane rotation. (#29419, opened Jun 11, 2025)
- [mutable-arrays] re-land #29353 (#29421, opened Jun 12, 2025)
- Adjust quantized paged attention tests to skip float8_e4m3 on < sm89. (#29422, opened Jun 12, 2025)
- hacky dce (#29424, opened Jun 12, 2025)
- Update missed ml_dtypes version check to 0.5.0 (#29429, opened Jun 12, 2025)
- This is an internal change that does not affect any public-facing features of JAX... yet. =) (#29431, opened Jun 12, 2025)
- Add more models to auto tune and update tuned block. (#29432, opened Jun 12, 2025)
- Disallow aliased mutable array arguments to vmap. (#29433, opened Jun 12, 2025)
- Add `wrap_negative_indices` paramter to `jnp.ndarray.at[]` (#29434, opened Jun 12, 2025)
- Link to PR message about QuasiDynamicData (#29436, opened Jun 12, 2025)
- Use heartbeat_timeout argument. (#29444, opened Jun 12, 2025)
- Remove old heartbeat flags and arguments. (#29445, opened Jun 12, 2025)
- Remove old heartbeat options. (#29446, opened Jun 12, 2025)
- Add `split_p` to `jax.extend.core.primitives` (#29452, opened Jun 13, 2025)
- feat(scipy.special): add erfcx — scaled complementary error function (#29455, opened Jun 13, 2025)
- [pallas] `AbstractMemoryRef` now uses `update` for all functional updates (#29460, opened Jun 13, 2025)
- Refactor jax._src.lax imports (#29461, opened Jun 13, 2025)
- [Pallas][Mosaic GPU] Enable collective MMA from TMEM. (#29466, opened Jun 13, 2025)
- internal change (#29469, opened Jun 14, 2025)
- [Pallas][Mosaic GPU] Add support for TMEM Ref aliasing. (#29471, opened Jun 14, 2025)
- Migrate `scenic/common_lib/debug_utils` cost-analysis to `roofline`. (#29474, opened Jun 14, 2025)
- Colocated python: Allow serializing `jax.Device`. (#29477, opened Jun 15, 2025)
- [Mosaic:TPU] Extend support for ext to be symmetrical with trunc (#29484, opened Jun 16, 2025)
- Fix incorrect gradient in custom_gradient example (#29485, opened Jun 16, 2025)
- Reserve a few CPU cores when running CUDA tests to avoid getting CPU starved (#29493, opened Jun 16, 2025)
- Prefer binaries in NVIDIA `nvcc` wheel over system CUDA installation in Mosaic GPU implementation. (#29494, opened Jun 16, 2025)
- [Mosaic] Support collapsing multiple dimensions to load (#29495, opened Jun 16, 2025)
- Add a cache around abstract_eval rules. (#29498, opened Jun 16, 2025)
- Increase error tolerance for cudnn sdpa fp8 inference test (#29501, opened Jun 16, 2025)
- [Pallas TPU] Add flag to enable using registers to keep track of slot info (#29503, opened Jun 16, 2025)
- [TSAN CI] Removed fixed cpython suppressions (#29504, opened Jun 16, 2025)
- Add an option to enable GPU collective cancelling. (#29505, opened Jun 16, 2025)
- Use more concise tracer reprs by default. (#29507, opened Jun 16, 2025)
- Add very basic support for shard_map + unreduced. (#29508, opened Jun 16, 2025)
- [doc] Clarify Profiling Docs for XProf and Tensorboard integration (#29509, opened Jun 17, 2025)
- [Mosaic GPU] Implement canonicalization for `TiledLayout`s. (#29512, opened Jun 17, 2025)
5 Issues closed by 3 people
- doc: `axis_types` description missing from `jax.make_mesh` docstring (#29478, closed Jun 16, 2025)
- Error: branches platform index mapped for cond with linear transpose (#29329, closed Jun 12, 2025)
- Forward compatibility with 0.6.1 (#29407, closed Jun 11, 2025)
- Be able to consider subtree's key path when determining if is_leaf in tree operations (#27996, closed Jun 10, 2025)
- Unsupported int8 in mosaic transpose (#29278, closed Jun 10, 2025)
14 Issues opened by 14 people
- how to use platform dependent for function that requires kwargs (#29510, opened Jun 17, 2025)
- CPU Over-utilization and taskset (#29499, opened Jun 16, 2025)
- Unimplemented primitive in Pallas TPU lowering for KernelType.TC: dynamic_slice. (#29481, opened Jun 16, 2025)
- cuSolver internal error in `jnp.linalg.eigvalsh(jnp.eye(2))` (#29475, opened Jun 14, 2025)
- `jit` compiled function overhead, help for PyTree registration (#29470, opened Jun 14, 2025)
- Missing `cost_estimate` in MHA splash attention kernel (#29463, opened Jun 13, 2025)
- Show uv installation instructions in the docs (#29451, opened Jun 13, 2025)
- Transposed preconditioned GMRES (#29449, opened Jun 13, 2025)
- `jax.scipy.fft.idctn` brings different results with `scipy.fft.idctn` (#29426, opened Jun 12, 2025)
- Experimental support for AMD GPUs on WSL2 (#29400, opened Jun 11, 2025)
- cannot find gpu (#29399, opened Jun 11, 2025)
- jax.experimental.saved_input_vjp does not support has_aux (#29395, opened Jun 11, 2025)
- Segmentation fault with simple comparison (#29373, opened Jun 10, 2025)
26 Unresolved conversations
Sometimes conversations happen on old items that aren’t yet closed. Here is a list of all the Issues and Pull Requests with unresolved conversations.
- gumbel distribution implementation (#29343, commented on Jun 14, 2025 • 24 new comments)
- add psend and precv to jax/lax/parallel (#29135, commented on Jun 17, 2025 • 16 new comments)
- added solve_sylvester and accompanying tests (#28810, commented on Jun 16, 2025 • 2 new comments)
- Make canonicalization of reshapes more robust: (#29359, commented on Jun 16, 2025 • 0 new comments)
- [CI] Introduce action lint workflow (#29355, commented on Jun 12, 2025 • 0 new comments)
- [pallas:mosaic] Lower cond to balanced binary tree of if-else statements (#29346, commented on Jun 13, 2025 • 0 new comments)
- [CI] Run Mosaic H100 and B200 tests on all PRs that target mosaic subpaths (#29298, commented on Jun 12, 2025 • 0 new comments)
- Ensure all JAX benchmarks have `block_until_ready`. (#29289, commented on Jun 14, 2025 • 0 new comments)
- [ROCm] ROCm7 Plugin Updates (#29281, commented on Jun 16, 2025 • 0 new comments)
- #sdy #mixed_serialization don't make JAX export use `SdyRoundTripExportPipeline` to stringify attributes and convert ops to StableHLO `CustomCallOp`s and back. (#29272, commented on Jun 16, 2025 • 0 new comments)
- [Mosaic] Use BF16 ops for math::PowF on TPUv6+. (#29214, commented on Jun 17, 2025 • 0 new comments)
- Add _XlaShardingV2 to tf.XlaShardOp and use it for tf2xla lowering. (#29172, commented on Jun 16, 2025 • 0 new comments)
- Parametrize build system on CUDA major version (#28968, commented on Jun 16, 2025 • 0 new comments)
- [Mosaic:TPU] Byte-granularity dynamic gathers (#28952, commented on Jun 17, 2025 • 0 new comments)
- [CI] Add additional hardware to continuous non-rbe testing (#28688, commented on Jun 16, 2025 • 0 new comments)
- Major deps update: (#28497, commented on Jun 16, 2025 • 0 new comments)
- Fix overloaded type signature for jax.numpy.where. (#28314, commented on Jun 10, 2025 • 0 new comments)
- jax.scipy.linalg.eigh_tridiagonal() doesn't implement calculation of eigenvectors (#14019, commented on Jun 16, 2025 • 0 new comments)
- Binary op compare with different element types in sharded, jitted function call (#19691, commented on Jun 16, 2025 • 0 new comments)
- lax.cond crashes on Windows (#29049, commented on Jun 14, 2025 • 0 new comments)
- jax.nn.dot_product_attention(...implementation='cudnn') fails due to incorrect loading of libnvrtc (#29260, commented on Jun 13, 2025 • 0 new comments)
- jax_jit.cc support for Tracer: don't cache_miss (#10976, commented on Jun 13, 2025 • 0 new comments)
- NotImplementedError: sparse rule for reduce_max is not implemented (#28749, commented on Jun 13, 2025 • 0 new comments)
- `XlaRuntimeError` on latest jax-metal (#27062, commented on Jun 12, 2025 • 0 new comments)
- Implement scipy.stats.multivariate_normal.cdf (#10562, commented on Jun 10, 2025 • 0 new comments)
- Reverse mode through JVP of vmapped function returns "TypeError: [...] <class 'jax._src.ad_util.Zero'> is not a valid JAX type" (#29342, commented on Jun 10, 2025 • 0 new comments)