Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update references to JAX's GitHub repo #25406

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/about.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly,
and elsewhere.

At the heart of the project is the [JAX
core](http://github.com/google/jax) library, which focuses on the
core](http://github.com/jax-ml/jax) library, which focuses on the
fundamentals of machine learning and numerical computing, at scale.

When [developing](#development) the core, we want to maintain agility
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ def rev(objective_fn, res, g):
def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
# See https://github.com/google/jax/issues/6415 for motivation.
# See https://github.com/jax-ml/jax/issues/6415 for motivation.
if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed.
return False
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/colocated_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Colocated Python API."""

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570

# pylint: disable=useless-import-alias
from jax.experimental.colocated_python.api import (
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2108,7 +2108,7 @@ def apply_carry(x, i):
jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash

def testIssue804(self):
# https://github.com/google/jax/issues/804
# https://github.com/jax-ml/jax/issues/804
num_devices = jax.device_count()
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash
Expand Down
Loading