Insights: jax-ml/jax
Overview
2 Releases published by 1 person
- jax-v0.4.37: JAX v0.4.37, published Dec 10, 2024
- jax-v0.4.38: JAX v0.4.38, published Dec 17, 2024
287 Pull requests merged by 37 people
- bazel visibility change (#25796, merged Jan 9, 2025)
- Make sharding_in_types work with Shardy (#25795, merged Jan 9, 2025)
- #sdy add repr for Sdy ArraySharding and DimSharding (#25781, merged Jan 8, 2025)
- Added 3.13 ft requirements lock file and updated WORKSPACE (#25786, merged Jan 8, 2025)
- Make api_test.py work when test cases are run using multiple threads (#25780, merged Jan 8, 2025)
- Implement the extension to the custom_partitioning API (#24719, merged Jan 8, 2025)
- Avoid adding conflicting `--repo_env=HERMETIC_PYTHON_VERSION=` to bazel command (#25527, merged Jan 8, 2025)
- [Mosaic GPU] Allow multiple indexing on refs (#25480, merged Jan 8, 2025)
- [Pallas] Add empty/empty_like helper functions (#25722, merged Jan 8, 2025)
- Add a unittest test extension that runs test cases in parallel using threads (#25772, merged Jan 8, 2025)
- Port tests away from setUpClass and setUpModule to setUp alone (#25774, merged Jan 8, 2025)
- [pallas:mosaic_gpu] Fix the tests following the changes to `pl.core_map` (#25773, merged Jan 8, 2025)
- [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU devices directly (#25765, merged Jan 8, 2025)
- [Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6 (#25605, merged Jan 8, 2025)
- [Mosaic TPU] Add support for second minor broadcasts with packed types (#25636, merged Jan 8, 2025)
- [Mosaic TPU] Add support for true divide in bf16 on TPUv6 (#25608, merged Jan 8, 2025)
- [Mosaic] Use tpu::CreateMask for getX32VmaskByPaddingEnd (#25737, merged Jan 8, 2025)
- [Mosaic GPU] Use `num_q_heads=2` in `flash_attention.py` (#25754, merged Jan 8, 2025)
- Removed leftover mentions of xmap from the code (#25752, merged Jan 8, 2025)
- Update the tutorial for jax.checkpoint activation offloading policies by adding examples (#25594, merged Jan 8, 2025)
- Add JAX events that have time spans, not only durations (#25747, merged Jan 8, 2025)
- [array api] update test suite to latest commit (#25761, merged Jan 8, 2025)
- [Pallas] Improvements to core_map (#25721, merged Jan 8, 2025)
- Increase the minimum SciPy version to 1.11.1 (#25744, merged Jan 8, 2025)
- Add an example demonstrating input-output aliasing with the FFI (#25042, merged Jan 7, 2025)
- [Pallas] Fix pallas_call lowering mutating compiler params during Triton lowering (#25735, merged Jan 7, 2025)
- [AutoPGLE] Fix PGLE kokoro test failures (#25746, merged Jan 7, 2025)
- Update the advanced autodiff tutorial and replace some vmap with grad (#25441, merged Jan 7, 2025)
- Clean up `backend_or_name` vs. `platforms` in lowering code (#25493, merged Jan 7, 2025)
- Add a register_custom_type_id function to the GPU plugins (#25413, merged Jan 7, 2025)
- Fix type signature for __divmod__ (#25748, merged Jan 7, 2025)
- Move `jax.extend.ffi` module to top level `jax.ffi` (#25633, merged Jan 7, 2025)
- [shape_poly] Improve threefry with symbolic shapes (#25731, merged Jan 7, 2025)
- Fix the doc error: module 'scipy.misc' has no attribute 'face' (#25728, merged Jan 7, 2025)
- Make inspect_array_sharding work without mesh context manager too (#25743, merged Jan 7, 2025)
- Remove deprecated jax.experimental.array_api (#25740, merged Jan 7, 2025)
- [Mosaic TPU] Add some elementwise canonicalizations (#25720, merged Jan 6, 2025)
- Remove the need for check_rep for with_sharding_constraint (#25739, merged Jan 6, 2025)
- Rename pybind_extension to nanobind_extension (#25736, merged Jan 6, 2025)
- jnp.einsum: default to optimize='auto' (#25214, merged Jan 6, 2025)
- [Mosaic TPU] Validate inserted layout in relayout-insertion pass (#25645, merged Jan 6, 2025)
- [shmap/partial-auto] Fixes lowering for jax.lax.axis_index in shard_map for degenerated shmaps (#25699, merged Jan 6, 2025)
- bazel: export serialization.fbs for downstream usage (#25697, merged Jan 6, 2025)
- Deprecate scipy.special.lpmn & lpmn_values (#25675, merged Jan 6, 2025)
- [shape_poly] Remove old non_negative support (#25462, merged Jan 6, 2025)
- Disable `avxvnniint8` when building with Clang version < 19, or GCC < 13 (#25712, merged Jan 6, 2025)
- [ROCm] Update package indentation fix (#25715, merged Jan 6, 2025)
- Add SMEM as a supported Pallas output memory space (#25689, merged Jan 5, 2025)
- jax.debug.print: respect trace-time np.printoptions (#25713, merged Jan 3, 2025)
- [Mosaic] NFC: Pull out vreg related functions to util (#25600, merged Jan 2, 2025)
- Compute cost-analysis on only one HLO module (#25537, merged Jan 2, 2025)
- Fix formatting in docs for transposing pytrees (#25686, merged Jan 2, 2025)
- Fix OSS build for the Mosaic GPU dialect (#25710, merged Jan 2, 2025)
- [Mosaic:TPU][NFC] Clean up unused variable (#25680, merged Jan 2, 2025)
- Fix log10 and log2 for large inputs (#25706, merged Jan 2, 2025)
- Tensorboard profiling plugin nightly instructions (#25661, merged Jan 2, 2025)
- Don't use x32 mode for pallas_test (#25709, merged Jan 2, 2025)
- Fixes incorrect sharding lowering for key<fry> in shard_map partial-auto (#25692, merged Dec 30, 2024)
- [pallas:mosaic_gpu] Addressed a todo in `broadcasted_iota` lowering (#25683, merged Dec 24, 2024)
- Finalize deprecation of some symbols from jax.lib.xla_client (#25672, merged Dec 24, 2024)
- Bump actions/upload-artifact from 4.4.3 to 4.5.0 (#25674, merged Dec 24, 2024)
- [Mosaic:TPU] In infer ext rule, avoid assigning offsets outside of dst first tile (#25560, merged Dec 23, 2024)
- Switch `mlir` bindings from `pybind11` to `nanobind` (#25677, merged Dec 23, 2024)
- [Mosaic:TPU] Vreg-slice-aligned offset changes with scratch retiling (#25243, merged Dec 23, 2024)
- Remove casting from jax.nn.one_hot (#25669, merged Dec 23, 2024)
- changelog: link to api compatibility & python version docs (#25673, merged Dec 23, 2024)
- Internal: use a single registry for abstractify APIs (#25651, merged Dec 23, 2024)
- Fix a typo in documentation for `pinv` function (#25662, merged Dec 23, 2024)
- [NVIDIA] Cudnn dot_product_attention supports head size up to 256 (#24607, merged Dec 23, 2024)
- Use the right include for gmock and gtest (#25668, merged Dec 23, 2024)
- [mosaic_gpu] Include Mosaic GPU dialect files into jaxlib (#25667, merged Dec 23, 2024)
- [pallas:triton] Add support for lowering `int4` load (#25533, merged Dec 23, 2024)
- [pallas:mosaic_gpu] Reduced duplication between `_ensure_fa` and `_ensure_ir_value` (#25665, merged Dec 23, 2024)
- [pallas:mosaic_gpu] Updated the lowering following the changes in Mosaic GPU internals (#25663, merged Dec 23, 2024)
- Move ragged tests under a new class (#25563, merged Dec 22, 2024)
- Support int4/uint4 in jnp.ndarray.view (#25650, merged Dec 21, 2024)
- Move output path to be inside the wheel build command execution loop (#25653, merged Dec 21, 2024)
- [JAX] Add a test using inputs with different device orders for a single colocated Python call (#25652, merged Dec 21, 2024)
- Change the namespace name to avoid using `export` c++ keyword on namespace (#25644, merged Dec 21, 2024)
- Accelerate deprecation of legacy JAX FFI calling convention (#25638, merged Dec 20, 2024)
- DOC: clarify API compatibility discussion (#25416, merged Dec 20, 2024)
- Avoid calls to warnings.catch_warnings in JAX core code (#25648, merged Dec 20, 2024)
- Add int4/uint4 support to bitcast_convert_type (#25646, merged Dec 20, 2024)
- [Mosaic:TPU] Roll forward of cl/708011538 (expanded trunc support), minus changes in infer-vector-layout (#25629, merged Dec 20, 2024)
- Migrate _mlir Python binding target to nanobind (#25619, merged Dec 20, 2024)
- [Mosaic GPU] Add a lowering for simple `async_load` and `async_store` ops (#25635, merged Dec 20, 2024)
- Unify abstractify & shaped_abstractify rules (#25616, merged Dec 20, 2024)
- Clarify documentation for output_offsets operand of ragged_all_to_all (#25567, merged Dec 20, 2024)
- Make jax.Arrays a necessary part of the cycle in the GC guard test (#25632, merged Dec 20, 2024)
- [pallas:mosaic_gpu] Change the fori tests to also take the while_p path and fix the bug (#25613, merged Dec 20, 2024)
- Relax some test tolerances in N-D FFT tests (#25631, merged Dec 20, 2024)
- add mutable array ref error checks to cond and custom_vjp (#25625, merged Dec 20, 2024)
- [Mosaic:TPU] Roll back cl/708011538 and cl/708112341 (#25628, merged Dec 20, 2024)
- [Mosaic TPU] Add relayout-insertion pass and support bitwidth change for i1 vector relayout (#25457, merged Dec 20, 2024)
- [Mosaic:TPU] Fix trunc infer rule after cl/708011538 (#25624, merged Dec 20, 2024)
- Add more input validation to jax.distributed.initialize (#25489, merged Dec 20, 2024)
- [Mosaic] Extend macros to handle parentheses (#25523, merged Dec 19, 2024)
- [Mosaic] Remove TODOs that are already addressed or obsolete (#25536, merged Dec 19, 2024)
- Ensure that the two offsets of a dynamic_slice have the same dtype regardless (#25388, merged Dec 19, 2024)
- [Mosaic TPU] Add support for sqrt and rsqrt in bf16 on TPUv6 (#25572, merged Dec 19, 2024)
- [Mosaic:TPU] For trunc, expand supported tilings, offsets and bitwidths (#25340, merged Dec 19, 2024)
- [Mosaic GPU] Prototype of a warp-specialized pipeline emitter for Mosaic GPU (#25599, merged Dec 19, 2024)
- add test for partial-auto ppermute (#25557, merged Dec 19, 2024)
- Migrate shardy dialect extension to nanobind (#25584, merged Dec 19, 2024)
- Remove internal uses of api_util.shaped_abstractify (#25614, merged Dec 19, 2024)
- Add support for N-D FFTs with N>3 (#25606, merged Dec 19, 2024)
- [ROCm] Create PyPI wheel upload script and setup.py bindings (#25440, merged Dec 19, 2024)
- [ROCm] ci build and dockerfile changes (#25148, merged Dec 19, 2024)
- Add experimental support for building JAX CPU and GPU wheels with GCC (#25531, merged Dec 19, 2024)
- JEP: effort-based versioning (EffVer) (#25516, merged Dec 19, 2024)
- jax.nn.one_hot: deprecate non-integer inputs (#25590, merged Dec 19, 2024)
- Always suppress the differing_executors Hypothesis health check (#25615, merged Dec 19, 2024)
- [Mosaic GPU] Commit to using `Vector`s everywhere (and no `Tensor`s) (#25609, merged Dec 19, 2024)
- Refactor: move shaped_abstractify to core (#25595, merged Dec 19, 2024)
- [Mosaic TPU] Guard tests for new features by the libtpu version (#25604, merged Dec 19, 2024)
- [Pallas:TPU] Use self.pallas_call to properly handle interpret mode (#25603, merged Dec 19, 2024)
- [Mosaic GPU] Add a new tiled layout, optimized for upcasting before WGMMA (#25470, merged Dec 19, 2024)
- [Mosaic GPU] Replace the dialect's layout enum with layouts holding the proper (#25578, merged Dec 19, 2024)
- [Mosaic:TPU][NFC] Small cleanup of extui rule in apply-vector-layout (#25587, merged Dec 19, 2024)
- [Mosaic GPU][NFC] Move up the definition of `ThreadSemantics` (#25573, merged Dec 19, 2024)
- Add lax.composite primitive (#25104, merged Dec 19, 2024)
- [Pallas] Fix lowering tests for reduction ops (#25580, merged Dec 19, 2024)
- add mutable array ref error checks to scan (#25593, merged Dec 19, 2024)
- remove whitespace from distributed arrays doc (#25596, merged Dec 19, 2024)
- [Mosaic TPU] Support i32 vector multi reduction except cross lane (#25492, merged Dec 19, 2024)
- [Mosaic] Improve error verbosity of tpu.memref_slice verification (#25476, merged Dec 19, 2024)
- Re-land changes from https://github.com/jax-ml/jax/pull/25555 (#25581, merged Dec 18, 2024)
- [Mosaic] Verify that the target IDs are provided in remote DMAs (#25478, merged Dec 18, 2024)
- [Mosaic:TPU][NFC] Delete unused functions (#25546, merged Dec 18, 2024)
- [Pallas] Add version guard for non-32-bit selection in test and fix github build failure (#25589, merged Dec 18, 2024)
- [Mosaic GPU][NFC] Split LaunchContext into a separate file (#25574, merged Dec 18, 2024)
- [Mosaic TPU] Support direct cast i8 vector to mask (#25455, merged Dec 18, 2024)
- fix bug with jax.remat static_argnums not supporting int (#25582, merged Dec 18, 2024)
- More linearize fixes (#25490, merged Dec 18, 2024)
- [JAX] Fix a small bug if shardings is tuple (#25549, merged Dec 18, 2024)
- Increase the minimum NumPy version to v1.25 (#25569, merged Dec 18, 2024)
- Migrate mhlo dialect extension to nanobind (#25579, merged Dec 18, 2024)
- Partial discharge for scan_p ops (#24013, merged Dec 18, 2024)
- Use capture_stderr instead of packing sys.stderr (#25576, merged Dec 18, 2024)
- [Mosaic TPU] Add support for bf16 second minor reductions in TPUv6 (#25466, merged Dec 18, 2024)
- Enable --config=clang only on newer Clang versions (#25543, merged Dec 18, 2024)
- [Mosaic GPU] Clean up imports in `gpu_dialect_test.py` (#25577, merged Dec 18, 2024)
- Migrate StableHLO Python extension to nanobind (#25575, merged Dec 18, 2024)
- [Pallas-Triton] Fix squeeze lowering required sharding argument (#25412, merged Dec 18, 2024)
- Use the right health check suppression (#25568, merged Dec 18, 2024)
- [pallas:mosaic_gpu] Added a lowering rule for the general `lax.while_loop_p` (#25310, merged Dec 18, 2024)
- [mosaic_gpu] Error on static out of bounds indices in `utils.parse_indices` (#25311, merged Dec 18, 2024)
- Use StableHLO acos. Update complex acos accuracy tests (#23830, merged Dec 18, 2024)
- Increase the timeout in profiler tests, but use threading.Event to exit as soon as possible (#25570, merged Dec 18, 2024)
- [Mosaic TPU] Add support for exp, exp2 and log in bf16 on TPUv6 (#25467, merged Dec 18, 2024)
- Bump xla_extension_version after jaxlib release (#25559, merged Dec 18, 2024)
- Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564 (#25566, merged Dec 18, 2024)
- add error checks for refs, behind a flag (#25449, merged Dec 18, 2024)
- [Mosaic:TPU] Allow null parts for tpu.pack_subelements, meaning "don't care" (#25333, merged Dec 18, 2024)
- improve checkpoint / remat concreteness error with static_argnums (#24516, merged Dec 18, 2024)
- [Mosaic:TPU] Fix bug after cl/707025084 (#25540, merged Dec 18, 2024)
- Remove core.concrete_aval and replace with abstractify (#25555, merged Dec 18, 2024)
- Fix the breakage caused by deleted enable_memories config (#25558, merged Dec 18, 2024)
- Delete enable_memories code in C++ since that flag is always True and cannot be turned off now (#25554, merged Dec 18, 2024)
- [Mosaic GPU] Enable 64-bit types in test_scalar_argument (#25548, merged Dec 18, 2024)
- Merge release/0.4.38 branch and update version numbers (#25550, merged Dec 17, 2024)
- Add Pallas Philox implementation (#25545, merged Dec 17, 2024)
- Always use the same code for array avals (#25544, merged Dec 17, 2024)
- Streamline some core.concrete_aval compute paths (#25534, merged Dec 17, 2024)
- Enable PJRT compatibility in cloud TPU CI (#25539, merged Dec 17, 2024)
- Raise rather than return error (#25538, merged Dec 17, 2024)
- Cleanup: toward merging core.concrete_aval & xla.abstractify (#25456, merged Dec 17, 2024)
- get_githash: fix support for missing git (#25473, merged Dec 17, 2024)
- Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0 (#25036, merged Dec 17, 2024)
- [Mosaic TPU] Add support for the interleaved pack format to tpu.unpack_subelements (#25465, merged Dec 17, 2024)
- [Mosaic GPU] Add end-to-end lowering example for a pointwise kernel using the dialect and layout inference (#25501, merged Dec 17, 2024)
- Put abstract_mesh on every eqn so that we can preserve it during `eval_jaxpr` and `check_jaxpr` roundtrip (#25522, merged Dec 17, 2024)
- Make gmm TPU kernel tests significantly cheaper (#25529, merged Dec 17, 2024)
- [Mosaic TPU] Add support for bf16 abs (#25464, merged Dec 17, 2024)
- [Mosaic:TPU][NFC] Clean up local variable (#25500, merged Dec 17, 2024)
- Allow `lax.ragged_all_to_all` input and output operands to have different ragged dimension sizes (#25518, merged Dec 17, 2024)
- [Pallas TPU] Use vector.broadcast instead of vector.BroadcastOp to fix type check failure (#25520, merged Dec 17, 2024)
- Fix `save_from_both_policies` in presence of `save_and_offload_only_these_names` by comparing the enum (#25509, merged Dec 17, 2024)
- Re-deprecate a number of symbols from jax.core (#25508, merged Dec 16, 2024)
- Add utility script and env for running the CI scripts under Docker (#25356, merged Dec 16, 2024)
- Fix incorrect capitalization in scan error message (#25511, merged Dec 16, 2024)
- Minor typo fix in doc (#25479, merged Dec 16, 2024)
- Relax tolerance for LAX reduction test in float16 (#25514, merged Dec 16, 2024)
- Raise the timeout for Cloud TPU nightly CI (#25506, merged Dec 16, 2024)
- Fix some flaky LAX autodiff tests (#25513, merged Dec 16, 2024)
- [export] Expand exporting to work with AbstractMesh (#25435, merged Dec 16, 2024)
- Avoid assuming that jnp.sin will be traced in abstract mesh tests (#25505, merged Dec 16, 2024)
- Add Scaled Dot Product Attention for FP8 (#22670, merged Dec 16, 2024)
- Bump the oldest supported libtpu to conform to the 12 week window (#25502, merged Dec 16, 2024)
- Drop the frequency of Cloud TPU tests (#25503, merged Dec 16, 2024)
- [mosaic_gpu] Allow calling `reduce_sum` on a fragmented array in splat layout (#25343, merged Dec 16, 2024)
- Log DeprecationWarnings once per method/class (#25463, merged Dec 14, 2024)
- [Mosaic GPU] Fix layout inference traversal to traverse ops recursively (#25494, merged Dec 14, 2024)
- fix `gamma_p` in vmap-based impl rule mode (#25487, merged Dec 14, 2024)
- Add an experimental Cloud TPU presubmit job (#25482, merged Dec 14, 2024)
- Implement `process_call` for LinearizeTrace (#25481, merged Dec 14, 2024)
- Finalize deprecation of jnp.round_ (#25483, merged Dec 13, 2024)
- Limit self-hosted jobs to JAX main repo (#24536, merged Dec 13, 2024)
- Remove deprecated XLA GPU flags from docs (#25444, merged Dec 13, 2024)
- temporarily un-deprecate several jax.core APIs (#25451, merged Dec 13, 2024)
- Migrate JAX MLIR Python dialect extensions to nanobind (#25448, merged Dec 13, 2024)
- Use a broadcasted gather in the sort JVP, rather than forming explici… (#25459, merged Dec 13, 2024)
- Add flag desc to gpu_performance_tips.md (#25454, merged Dec 13, 2024)
- Remove autotune sharing (#25439, merged Dec 13, 2024)
- Implement flatten one level with keys in C++ and use it for the prefix/equality error printing (#25220, merged Dec 13, 2024)
- Cleanup: replace lax._abstractify with core.get_aval (#25452, merged Dec 12, 2024)
- Add jax.tree shortcuts for .*_with_path calls, for convenience of users (#25420, merged Dec 12, 2024)
- Disable failing test cases when `JAX_ENABLE_X64=1` in the Bazel CPU build (#25443, merged Dec 12, 2024)
- [mosaic] Migrated the serialization pass from codegen to `pass_boilerplate.h` (#25432, merged Dec 12, 2024)
- Relax test tolerance for complex128 pow in lax_test.py (#25438, merged Dec 12, 2024)
- Cleanup: remove uses of no-op raise_to_shaped (#25442, merged Dec 12, 2024)
- [shape_poly] Improve handling of mod(e, k) == 0 constraints (#25409, merged Dec 12, 2024)
- Don't monkey-patch functions in test_utils to count events for tests (#25426, merged Dec 12, 2024)
- [Mosaic GPU] Use events as the default profiling method (#25431, merged Dec 12, 2024)
- Fixes to direct linearize (#25425, merged Dec 12, 2024)
- [Mosaic] Pad trailing transposes chunks with zeros (#25384, merged Dec 12, 2024)
- internal: dedupe lax broadcasting logic (#25424, merged Dec 11, 2024)
- Use boolean values for partial mask blocks in the splash attention kernel (#25323, merged Dec 11, 2024)
- Run mypy with latest NumPy (#25381, merged Dec 11, 2024)
- internal: simplify broadcast_shapes logic (#25422, merged Dec 11, 2024)
- jax.lax: raise TypeError for mismatched dtypes (#25419, merged Dec 11, 2024)
- Finalize some deprecations in jax.core, jax.lib.xla_bridge, and jax.lib.xla_client (#25414, merged Dec 11, 2024)
- aarch64: add aarch64 mkldnn+acl build config (#23225, merged Dec 11, 2024)
- [ROCm] Remove cuda include from gpu plugin extension (#25407, merged Dec 11, 2024)
- jax.core: remove private API (#25415, merged Dec 11, 2024)
- Add Rotation return type hint to Rotation.__mul__() (#25401, merged Dec 11, 2024)
- Increase shard count after adding more tests (#25411, merged Dec 11, 2024)
- [shape_poly] Improve reasoning for >= in presence of == constraints (#25395, merged Dec 11, 2024)
- Minor change in the README, remove "expect bugs" (#25403, merged Dec 11, 2024)
- [Mosaic GPU] Split layout inference and dialect lowering files and tests (#25405, merged Dec 11, 2024)
- jax.core: more API deprecations (#25387, merged Dec 11, 2024)
- [Mosaic GPU] Add an initial skeleton for a layout inference pass (#25258, merged Dec 11, 2024)
- jax.numpy: implement matvec & vecmat (#25390, merged Dec 11, 2024)
- [Pallas] Remove `grid=1` in tests (#25402, merged Dec 11, 2024)
- [Mosaic GPU] Add CUPTI profiler alongside events-based implementation (#24805, merged Dec 11, 2024)
- Activate Triangular Solve to XLA's FFI (#25316, merged Dec 11, 2024)
- [Mosaic GPU] Add WGMMA to the Mosaic GPU MLIR Dialect (#25397, merged Dec 11, 2024)
- Add test of relu grad at zero. Update paper links (#25392, merged Dec 11, 2024)
- [sharding_in_types] Enforce AxisTypes to always exist if `set_mesh` is used (#25391, merged Dec 11, 2024)
- [Pallas TPU] Add `WeirdOp` to TPU dialect and add lowering for `lax.is_finite` (#25217, merged Dec 11, 2024)
- Adding more tests for multi-head attention (#25361, merged Dec 11, 2024)
- Enable New bazel presubmits for pull requests (#25300, merged Dec 10, 2024)
- Delete non-public API jax.lib.xla_bridge._backends (#24976, merged Dec 10, 2024)
- [jax:custom_partitioning] Make SdyShardingRule a user facing class (#25350, merged Dec 10, 2024)
- jax.core: deprecate a number of APIs (#25357, merged Dec 10, 2024)
- [Pallas] Add non-square pl.dot test cases (#25364, merged Dec 10, 2024)
- Deduplicate some GPU plugin definition code (#25383, merged Dec 10, 2024)
- Remove code in jax2tf for compatibility with TF 2.10 or earlier (#25378, merged Dec 10, 2024)
- Reverts bdadc53ebcd40a5091d66d2586deba82fe5e01ca (#25341, merged Dec 10, 2024)
- [AutoPGLE] Cleanup compiler code (#25375, merged Dec 10, 2024)
- Add a `freeze` primitive to delimit ref lifetimes for AD (#25369, merged Dec 10, 2024)
- Update conda-forge installation docs after drop of support for CUDA 11 (#25374, merged Dec 10, 2024)
- [export] Improved the documentation (#25335, merged Dec 10, 2024)
- Introduce `lax.ragged_all_to_all` primitive (#25370, merged Dec 10, 2024)
- Add a no-op batching rule for optimization_barrier_p (#25367, merged Dec 10, 2024)
- CI: temporarily pin numpy version for mypy check (#25371, merged Dec 10, 2024)
- Merge release/0.4.37 into main (#25368, merged Dec 10, 2024)
- Port symmetric tridiagonal reduction GPU kernel to FFI (#25358, merged Dec 9, 2024)
- Disable pjit ArrayPjitTest.test_device_put_grad test on TPU v5e (#25360, merged Dec 9, 2024)
- [Pallas] Update TPU documentation (#25272, merged Dec 9, 2024)
- Avoid index out of range error in carry structure check (#25355, merged Dec 9, 2024)
- array API: improve test coverage (#25294, merged Dec 9, 2024)
- Bump actions/cache from 4.1.2 to 4.2.0 (#25351, merged Dec 9, 2024)
- Use private names for args in api_util to avoid shadowing kwargs keys (#25349, merged Dec 9, 2024)
- Reenable for_loop_test on TPU v5p (#25348, merged Dec 9, 2024)
- [Mosaic:TPU][NFC] In ext and trunc rules, avoid vreg array reshape by always using implicit shapes (#25339, merged Dec 9, 2024)
- Remove dead code after minimum jaxlib version bump to v0.4.36 (#25334, merged Dec 9, 2024)
- [ROCm] Fix kernel build (#25320, merged Dec 9, 2024)
- Ensured that JAX type checks under pytype on Python 3.12 (#25342, merged Dec 9, 2024)
- Fix type annotation for numpy.linalg.matrix_norm argument 'ord' (#25338, merged Dec 9, 2024)
- Activate Tridiagonal Reduction to XLA's FFI (#23447, merged Dec 9, 2024)
- [Mosaic TPU] Allow downgrading the IR during serialization for forward compat (#25223, merged Dec 9, 2024)
67 Pull requests opened by 19 people
- Bump hypothesis from 6.102.4 to 6.122.3 (#25354, opened Dec 9, 2024)
- enable partitionable threefry by default (#25363, opened Dec 9, 2024)
- Add matrix logm (#25377, opened Dec 10, 2024)
- Try to fix a mypy failure (#25380, opened Dec 10, 2024)
- Add a workflow to run benchmarks (#25386, opened Dec 10, 2024)
- Plumb some more arguments through emit_pipeline (#25394, opened Dec 11, 2024)
- Update references to JAX's GitHub repo (#25406, opened Dec 11, 2024)
- Add a simple ragged all-to-all Pallas TPU kernel (#25421, opened Dec 11, 2024)
- Use consistent dtype for forward and backwards in jax.nn.dot_product_attention (#25433, opened Dec 12, 2024)
- update gpu_performance_tips doc to remove deleted and redundant flags (#25447, opened Dec 12, 2024)
- [Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16-bit upcast (#25471, opened Dec 13, 2024)
- Update compilation docs with a note on mocking GPUs (#25474, opened Dec 13, 2024)
- Remove CUDA dependencies from jaxlib wheel (#25485, opened Dec 13, 2024)
- Bump fonttools from 4.51.0 to 4.55.3 (#25512, opened Dec 16, 2024)
- Fix debug_nans regressions (#25519, opened Dec 16, 2024)
- [export] Add back-compat test for tridiagonal on GPU (#25525, opened Dec 17, 2024)
- Jax: Stop returning a list of cost-analyses (#25542, opened Dec 17, 2024)
- [Mosaic] Allow strided layouts in tpu.memref_slice (#25551, opened Dec 17, 2024)
- [pallas:triton] Removed unused `serialized_metadata` field from `TritonCompilerParams` (#25552, opened Dec 17, 2024)
- Reverts 32627702270e8642d2c02d1921b12b055c599242 (#25562, opened Dec 18, 2024)
- WIP: no special treatment for ShapeDtypeStruct (#25583, opened Dec 18, 2024)
- #sdy support JAX export tests when Shardy is enabled (#25585, opened Dec 18, 2024)
- #sdy enable pure callbacks and debug prints in JAX (#25586, opened Dec 18, 2024)
- In-progress experimentation. Add StringDType to JAX's supported types (#25592, opened Dec 18, 2024)
- Remove dead codepaths now that MemorySpaceDescription works in OSS (#25621, opened Dec 19, 2024)
- Switch to a new thread-safe utility for catching warnings (#25626, opened Dec 20, 2024)
- [Mosaic:TPU] Less urgent fixes to trunc infer rule after cl/708011538 (#25627, opened Dec 20, 2024)
- [Pallas] Add a test to ensure consistent rounding during float-to-int casts (#25637, opened Dec 20, 2024)
- [pallas/pallas_mgpu] Discharging run_scoped should not be discharging the intermediates (#25639, opened Dec 20, 2024)
- [pallas_mgpu] For loops can have accumulators for carries (#25640, opened Dec 20, 2024)
- Fix backwards attention test for larger head dims (#25642, opened Dec 20, 2024)
- Fix attention backwards pass (#25647, opened Dec 20, 2024)
- Simplify implementation of random.orthogonal (#25655, opened Dec 22, 2024)
- Add support for `axis_name` and `axis_index_groups` to `lax.ragged_all_to_all` (#25660, opened Dec 23, 2024)
- Expand test case to repro a crash (#25670, opened Dec 23, 2024)
- Add jax.random.multinomial (#25688, opened Dec 27, 2024)
- [Shmap/PartialAuto] Temporary solution for `debug.print` inside a partial-auto shard map (#25705, opened Jan 1, 2025)
- Add GitHub action workflow for Bazel CUDA continuous tests (#25717, opened Jan 3, 2025)
- [Mosaic] Create a stub for TPUExtDialect (#25719, opened Jan 3, 2025)
- [Mosaic TPU] Enable unaligned bf16 2D load/stores for earlier TPU gens (#25726, opened Jan 6, 2025)
- [Mosaic TPU] Enable non-sublane-aligned 2D int8 load/stores (#25727, opened Jan 6, 2025)
- [Mosaic TPU][NFC] Remove redundant num_subelems attribute from CreateSubelementMaskOp (#25729, opened Jan 6, 2025)
- jnp.linalg.solve: finalize deprecation of batched 1D solves (#25741, opened Jan 6, 2025)
- [shape_poly] Remove the deprecated PolyShape object for specifying symbolic dimensions (#25751, opened Jan 7, 2025)
- Add `sph_harm_y` to `jax.scipy.special` and deprecate `sph_harm` (#25753, opened Jan 7, 2025)
- [ROCm] Implement RNN support (#25755, opened Jan 7, 2025)
- Introduce jax.shard_map, without requiring mesh arg (#25757, opened Jan 7, 2025)
- [Mosaic] Fix inferMemRefLayout to error out if shape not aligned to tiling (#25758, opened Jan 7, 2025)
- [Mosaic GPU] Enable loop carries in the pipeline emitter (#25762, opened Jan 7, 2025)
- [Mosaic GPU] Allow multiple gmem indexers on copies (#25763, opened Jan 7, 2025)
- Add mode='fan_geo_avg' to nn.initializers.variance_scaling (#25766, opened Jan 8, 2025)
- Add a discussion of sharding to the FFI tutorial (#25771, opened Jan 8, 2025)
- [pallas] Support DMA start partial discharge and run_scoped() does its own partial discharge (#25775, opened Jan 8, 2025)
- [NFC] Refactor conversion lowering for Mosaic TPU (#25776, opened Jan 8, 2025)
- [Pallas] Improve testing for lowering of dtype conversions + fix uncovered bugs (#25777, opened Jan 8, 2025)
- [pallas:mosaic_gpu] Tests now pass with x64 enabled (#25782, opened Jan 8, 2025)
- New Experimental JAX Check Numerics API (#25785, opened Jan 8, 2025)
- Add a JVP rule for lax.linalg.tridiagonal_solve, fixing some bugs along the way (#25787, opened Jan 8, 2025)
- Remove warning suppressions for array API tests (#25788, opened Jan 8, 2025)
- [Mosaic TPU] Support unaligned minormost size in memref slice and DMA (#25790, opened Jan 8, 2025)
- fix grad(logsumexp) to produce 0s where `where` is False (#25791, opened Jan 8, 2025)
- Clarify documentation of composites (#25792, opened Jan 9, 2025)
- [Mosaic TPU] Append dump id to timestamp to make dump list ordered (#25793, opened Jan 9, 2025)
- Collective IDs *do* guarantee that we get the same semaphore every time (#25794, opened Jan 9, 2025)
- [aot] Add support for as_text(debug_info=True) (#25797, opened Jan 9, 2025)
38 Issues closed by 24 people
- `PositionalSharding`'s `reshape`, `transpose` and `replicate` methods reset memory kind (#25769, closed Jan 9, 2025)
- multihost_utils.process_allgather acts counterintuitively on single process (#25783, closed Jan 8, 2025)
- ordered_effects and passing tokens to custom_call using FFI v4 on CPU (#25756, closed Jan 8, 2025)
- FP32 `jax.random.categorical` inconsistent with the FP64 version (#25749, closed Jan 8, 2025)
- `jax.lax.composite` and `jax.nn.softmax` compose strangely (#25767, closed Jan 8, 2025)
- caching not sensitive to bound axis names not present on inputs (args or closed over) (#9187, closed Jan 7, 2025)
- Automatically treat dataclasses as pytrees (#2371, closed Jan 7, 2025)
- An example for `ffi` input/output aliasing (#24986, closed Jan 7, 2025)
- ⚠️ Nightly upstream-dev CI failed ⚠️ (#25664, closed Jan 6, 2025)
- Small Einsum is hanging (#24929, closed Jan 6, 2025)
- Missing annotations (#24888, closed Jan 3, 2025)
- pure_callback is broken with multiple vmap (#23624, closed Jan 2, 2025)
- Non-symmetric eigenvalue decomposition: unimplemented error on jax>=0.4.36 (#25687, closed Dec 31, 2024)
- Failure to build jaxlib, AMD GPU (#25204, closed Dec 29, 2024)
- Getting different results from CPU vs. CUDA backend (#22382, closed Dec 21, 2024)
- arr.view fails with int4 (#25620, closed Dec 21, 2024)
- gcc: error: unrecognized command-line option '-Qunused-arguments' (#25488, closed Dec 19, 2024)
- jax.nn.one_hot should not allow float inputs (#25484, closed Dec 19, 2024)
- Is `jax.scipy.stats.norm.logcdf` twice differentiable? (#25564, closed Dec 19, 2024)
- `jnp.finfo(x).eps` is not hashable (#25571, closed Dec 18, 2024)
- scan of device_put carry raises TypeError during the backward pass (#22045, closed Dec 17, 2024)
- jax.lax.scan transforms dict keys to lower case when reporting mismatch in pytree structures (#25507, closed Dec 16, 2024)
- Customizable reduction in jax.lax.scatter (#6265, closed Dec 14, 2024)
- `jax.random.beta` 3 orders of magnitude slower from 0.4.36 on GPU (#25469, closed Dec 14, 2024)
- Unclear documentation errors (#25430, closed Dec 13, 2024)
- std::bad_cast (#25468, closed Dec 13, 2024)
- XLA-introduced copies supersede `lax.optimization_barrier` (#25399, closed Dec 11, 2024)
- all_to_all with axis_index_group argument broken (#5861, closed Dec 10, 2024)
- "IndexError: list index out of range" raised by _check_carry_type in lax/control_flow/loops.py (#25332, closed Dec 10, 2024)
- Using a device function inside the host function of host_callback fails confusingly (#5934, closed Dec 10, 2024)
- vmap support for optimization barrier (#25365, closed Dec 10, 2024)
- About FFI failure: Failed to destroy GPU Graph (#25141, closed Dec 10, 2024)
- Jitted functions can't take keyword arguments named `f` in v0.4.36 (#25329, closed Dec 9, 2024)
- ptxas : Unsupported .version 8.4; current version is '8.2' with jaxlib 0.4.34 (#25344, closed Dec 9, 2024)
- potential jax 0.4.35 release issue? (#24826, closed Dec 9, 2024)
- unset JAX_PLATFORMS finds cuda, but JAX_PLATFORMS=gpu tries to use rocm (and fails) (#25315, closed Dec 9, 2024)
42 Issues opened by 34 people
-
Cache initialization fails when a JAX Array is created before enabling local cache
#25768 opened
Jan 8, 2025 -
Registering generic pytree nodes via getattr() and __setattr__
#25760 opened
Jan 7, 2025 -
XLA runtime error when taking grad + vmap + scan on GPU
#25759 opened
Jan 7, 2025 -
vmap(jnp.asarray)(numpy_array) does not return a JAX array
#25745 opened
Jan 7, 2025 -
Bfloat16 jnp.allclose raises Illegal Instruction Exception
#25730 opened
Jan 6, 2025 -
`vmap(custom_jvp)` does not strip zeros from nondifferentiable return values, leading to AD crashes
#25724 opened
Jan 5, 2025 -
If CUDA 12.1 is installed, pip-installed ptxas binary is not used and jax throws an error
#25718 opened
Jan 3, 2025 -
GPU `pallas_call` loses compiler params during second call when double jit-wrapped
#25714 opened
Jan 3, 2025 -
JAX can create str-dtyped tracers under `eval_shape` with numpy 2
#25707 opened
Jan 2, 2025 -
Get "invalid value (nan) encountered in jit" even when jit disabled globally
#25701 opened
Dec 31, 2024 -
Batched `meshgrid` or alternative
#25696 opened
Dec 30, 2024 -
Differentiation rule for tridiagonal_solve
#25693 opened
Dec 29, 2024 -
Unexpected NaN gradient of `jnp.abs` at `±inf + 0j`
#25681 opened
Dec 24, 2024 -
Unexpected NaN when signing `±inf + 0j`
#25679 opened
Dec 23, 2024 -
Crash with jit of ordered io_callback under shard_map and then no shard_map on CPU
#25671 opened
Dec 23, 2024 -
tree_util error handling throws error
#25659 opened
Dec 22, 2024 -
Got "Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR" with TF 2.18 and jax with cuda_local
#25658 opened
Dec 22, 2024 -
Add mode='fan_geo_avg' to nn.initializers.variance_scaling
#25649 opened
Dec 20, 2024 -
Improvements to debug_nans
#25643 opened
Dec 20, 2024 -
Add `mask` argument to `lax.argmax`
#25623 opened
Dec 20, 2024 -
Add `replace: bool` argument to `random.categorical` to sample without replacement using Gumbel-top-k trick
#25617 opened
Dec 19, 2024 -
Surprising difference of output between NumPy's `float32` and JAX's `float32`
#25601 opened
Dec 19, 2024 -
error: 'string' in namespace 'std' does not name a type
#25598 opened
Dec 19, 2024 -
Error when following the JAX tutorial on "Introduction to parallel programming"
#25541 opened
Dec 17, 2024 -
Make `xla_bridge.is_gpu` more extensible
#25521 opened
Dec 17, 2024 -
Clarification re: supported data types in `jax.linearize` and `jax.linear_transpose`
#25517 opened
Dec 16, 2024 -
Public API for mesh axis index size
#25515 opened
Dec 16, 2024 -
jax.random.choice(replace=True) samples 0 probability index
#25498 opened
Dec 16, 2024 -
jax.scipy.spatial.transform.Rotation.from_quat is missing scalar_first flag
#25491 opened
Dec 14, 2024 -
`register_dataclass` does not handle `__init__` methods
#25486 opened
Dec 13, 2024 -
Slow convolution, many memory warnings
#25461 opened
Dec 13, 2024 -
Segmentation fault in jaxlib 0.4.37 and Debian 11 ARM64
#25436 opened
Dec 12, 2024 -
Replace np.array() with np.fromfile() to improve performance
#25418 opened
Dec 11, 2024 -
Latency Hiding Scheduler not working with jax 0.4.35
#25404 opened
Dec 11, 2024 -
jax.full allocates memory on the wrong device
#25396 opened
Dec 11, 2024 -
Pallas Kernel Expected Output Shape Error Using Grids On TPU
#25379 opened
Dec 10, 2024 -
Add syntax highlighting for Lowered and Compiled objects
#25366 opened
Dec 10, 2024 -
Error when donating buffer with under-specified sharding
#25346 opened
Dec 9, 2024
62 Unresolved conversations
Sometimes conversations happen on old items that aren’t yet closed. Here is a list of all the Issues and Pull Requests with unresolved conversations.
-
Start a new TPU interpret mode for Pallas
#25097 commented on
Jan 8, 2025 • 12 new comments -
Implement column-pivoted QR via geqp3 (CPU lowering only)
#20282 commented on
Jan 6, 2025 • 10 new comments -
Update `jnp.floor_divide` to make it consistent with `np.floor_divide` for division by zero
#25032 commented on
Jan 3, 2025 • 10 new comments -
Implement SVD algorithm based on QR for CPU targets
#25053 commented on
Dec 18, 2024 • 7 new comments -
Add JAX_COMPILATION_CACHE_EXPECT_PGLE option
#24910 commented on
Jan 8, 2025 • 3 new comments -
[Pallas] Fix integer array indexing
#23758 commented on
Jan 3, 2025 • 0 new comments -
Support Hessian of gamma-distributed samples
#21432 commented on
Jan 7, 2025 • 0 new comments -
when a tracer error happens in for_loop, should point to the user's body function
#23637 commented on
Jan 9, 2025 • 0 new comments -
jax-metal: reduce with multiple operands failed to legalize
#21384 commented on
Jan 9, 2025 • 0 new comments -
jax.pjit non-partitionable tensor error delayed
#15788 commented on
Jan 9, 2025 • 0 new comments -
jax.scipy.linalg.eigh_tridiagonal() doesn't implement calculation of eigenvectors
#14019 commented on
Jan 9, 2025 • 0 new comments -
feature request: sparse jacobian and sparse hessians
#1032 commented on
Jan 8, 2025 • 0 new comments -
Generating random numbers in pallas on gpu
#25188 commented on
Jan 7, 2025 • 0 new comments -
Caching bug for jax.experimental.callback.rewrite for `jit` call inside `custom_gradient`
#6685 commented on
Jan 7, 2025 • 0 new comments -
Slicing a CPU-placed jax array results in unnecessary host-to-device transfers
#16002 commented on
Jan 6, 2025 • 0 new comments -
Counterintuitive speed of einsums vs equivalent matmuls
#20952 commented on
Jan 3, 2025 • 0 new comments -
Apple Silicon: error: failed to legalize operation 'mhlo.cholesky'
#16321 commented on
Jan 3, 2025 • 0 new comments -
bfloat16/float32 memory requirements seem off
#3302 commented on
Jan 3, 2025 • 0 new comments -
improve jax.checkpoint / jax.remat prevent_cse docstring
#24062 commented on
Dec 17, 2024 • 0 new comments -
improve readthedocs behavior for jax.remat / jax.checkpoint
#24064 commented on
Dec 18, 2024 • 0 new comments -
RFC: add pytree-compatible dataclass decorator
#24664 commented on
Jan 7, 2025 • 0 new comments -
Added CI job with TSAN and free-threading
#24898 commented on
Jan 9, 2025 • 0 new comments -
[Doc] Rename gpu_performance_tips.md to performance_tips.md with new CPU performance tips session
#24961 commented on
Dec 12, 2024 • 0 new comments -
[Pallas TPU] Add vector support to `pl.debug_print`
#25099 commented on
Jan 8, 2025 • 0 new comments -
Add float8_e8m0fnu type support
#25116 commented on
Jan 8, 2025 • 0 new comments -
improve jnp.mean / jnp.sum / ... error message for out-of-bounds axis index
#25155 commented on
Dec 18, 2024 • 0 new comments -
[Mosaic:TPU] 32-bit sublane broadcast for non-native tilings
#25160 commented on
Dec 23, 2024 • 0 new comments -
Support dynamic masks in splash attention
#25213 commented on
Dec 10, 2024 • 0 new comments -
Add Github action workflows for running continuous tests with Pytest
#25238 commented on
Jan 8, 2025 • 0 new comments -
A test for using symbolic shapes for minformer.
#25317 commented on
Jan 8, 2025 • 0 new comments -
Simplify implementation of nn.relu.
#25331 commented on
Dec 11, 2024 • 0 new comments -
Allow Pallas Triton to be serialized as PTX
#25196 commented on
Dec 9, 2024 • 0 new comments -
Get wrong structure when using jax2tf to convert nnx.module into tflite file
#24497 commented on
Dec 9, 2024 • 0 new comments -
JAX running in CPU only mode only uses a single core
#5022 commented on
Dec 10, 2024 • 0 new comments -
JAX does not recognise my NVIDIA GPU when installed via conda
#24604 commented on
Dec 10, 2024 • 0 new comments -
[Colab] jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure
#14893 commented on
Dec 10, 2024 • 0 new comments -
Host callback attempts to perform computation on GPU
#13046 commented on
Dec 10, 2024 • 0 new comments -
Clear Cuda Cache
#13330 commented on
Dec 10, 2024 • 0 new comments -
Memory leaked over time when calling the same computation
#25184 commented on
Dec 10, 2024 • 0 new comments -
Surprisingly slow jax.lax.top_k
#9940 commented on
Dec 11, 2024 • 0 new comments -
Lowering rule for `lax.scan` unnecessarily emits `while` loop for `unroll=1` when `length` or the size of `xs` is 1
#25330 commented on
Dec 11, 2024 • 0 new comments -
Multi-machine multi-card support
#16172 commented on
Dec 12, 2024 • 0 new comments -
add bounds in optimizers
#2164 commented on
Dec 12, 2024 • 0 new comments -
`pmap` with `out_axes != 0` errors when within `jit`
#5756 commented on
Dec 12, 2024 • 0 new comments -
Difference in output between jitted and non-jitted call
#20371 commented on
Dec 12, 2024 • 0 new comments -
Latency Hiding Scheduler leads to x5 memory usage if used without jax.lax.scan
#20763 commented on
Dec 16, 2024 • 0 new comments -
jax.numpy.arcsinh not has accurate for complex64 dtype
#19398 commented on
Dec 18, 2024 • 0 new comments -
int4 reshape: Reshape should have supported layout before reaching the emitter
#22121 commented on
Dec 19, 2024 • 0 new comments -
JAX code is extremely slow on GPUs
#24411 commented on
Dec 19, 2024 • 0 new comments -
TraceAnnotation not showing inside jax.jit
#12381 commented on
Dec 19, 2024 • 0 new comments -
Implement JVP for SVD when full_matrices=True
#508 commented on
Dec 19, 2024 • 0 new comments -
`jax.profiler.trace` repeatedly fails to display entire trace
#21295 commented on
Dec 19, 2024 • 0 new comments -
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 11
#21621 commented on
Dec 19, 2024 • 0 new comments -
`jax.debug.breakpoint(num_frames=1)` crashes in colab
#16187 commented on
Dec 20, 2024 • 0 new comments -
Fast JAX compilation when network architecture has per-dataset components.
#2316 commented on
Dec 20, 2024 • 0 new comments -
Add copy method on jax.Array
#13552 commented on
Dec 24, 2024 • 0 new comments -
Add random.binomial and random.multinomial
#13327 commented on
Dec 27, 2024 • 0 new comments -
support batched matrix multiplication in pallas
#21618 commented on
Dec 27, 2024 • 0 new comments -
Metal: missing functionality
#20375 commented on
Dec 30, 2024 • 0 new comments -
Build issues with local CUDA installation
#23689 commented on
Dec 31, 2024 • 0 new comments -
`jax` and `xarray` integration for automatic differentiation?
#17107 commented on
Jan 1, 2025 • 0 new comments -
Sharding is much slower than pmap for while loops of varying length while loops
#20968 commented on
Jan 2, 2025 • 0 new comments