lax.cond is much slower than calling true_fn directly when the predicate is always true #24259

Open · hanzhi713 opened this issue on Oct 11, 2024 · 0 comments
Labels: enhancement (New feature or request)


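Reproducing example (per the comments in `host_fn`: 4.7 s per iteration when calling `true_fn` directly, 11.22 s per iteration through `jax.lax.cond`):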

```python
import timeit

import jax
import jax.numpy as jnp
import jax.sharding
from jax.experimental.compute_on import compute_on

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
p_sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")

# Set to False to time the direct call to true_fn instead of lax.cond.
USE_COND = True


@compute_on("device_host")
@jax.jit
def host_fn(is_valid_step, gradient, opt_state):
    def true_fn(gradient, opt_state):
        opt_state = opt_state + jnp.sin(gradient)
        delta = opt_state * gradient
        return delta, opt_state

    def false_fn(gradient, opt_state):
        return gradient, opt_state

    if USE_COND:
        # 11.22s per it. Much slower!
        return jax.lax.cond(is_valid_step, true_fn, false_fn, gradient, opt_state)
    # 4.7s per it
    return true_fn(gradient, opt_state)


def test_fn(gradient, opt_state):
    gradient = jnp.cos(gradient)
    # In this benchmark the predicate is always True, so lax.cond always takes true_fn.
    return host_fn(jnp.abs(jnp.sum(gradient)) > 0, gradient, opt_state)


x = jnp.arange(0, 1024 * 1024 * 100, dtype=jnp.float32)
y = jnp.arange(0, 1024 * 1024 * 100, dtype=jnp.float32)
y = jax.device_put(y, p_sharding)  # opt_state lives in pinned host memory

jit_fn = jax.jit(
    test_fn,
    in_shardings=(sharding, p_sharding),
    out_shardings=(sharding, p_sharding),
    donate_argnums=(0, 1),
)
jit_fn = jit_fn.lower(x, y).compile()


def fn():
    global x, y
    x, y = jit_fn(x, y)
    jax.block_until_ready((x, y))


t = timeit.Timer(fn)
print(t.timeit(10))  # total seconds for 10 calls
```
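A smaller harness may help narrow down where the cost comes from. This is a sketch under stated assumptions, not part of the original report: it drops `compute_on("device_host")` and the pinned-host sharding so it runs on any backend (which means it may not reproduce the host-side slowdown at all), shrinks the arrays to 2^20 elements, and adds a hypothetical `with_select` variant that evaluates both branches and picks the result with `jnp.where`. The names `direct`, `with_cond`, `with_select`, and `bench` are illustrative only.

```python
import timeit

import jax
import jax.numpy as jnp


def true_fn(gradient, opt_state):
    # Same update as in host_fn above.
    opt_state = opt_state + jnp.sin(gradient)
    delta = opt_state * gradient
    return delta, opt_state


def false_fn(gradient, opt_state):
    return gradient, opt_state


@jax.jit
def direct(pred, gradient, opt_state):
    # Ignores pred and always runs the true branch.
    return true_fn(gradient, opt_state)


@jax.jit
def with_cond(pred, gradient, opt_state):
    return jax.lax.cond(pred, true_fn, false_fn, gradient, opt_state)


@jax.jit
def with_select(pred, gradient, opt_state):
    # Hypothetical workaround: compute both branches, then choose elementwise.
    taken = true_fn(gradient, opt_state)
    skipped = false_fn(gradient, opt_state)
    return jax.tree_util.tree_map(lambda a, b: jnp.where(pred, a, b), taken, skipped)


def bench(fn, *args, number=10):
    jax.block_until_ready(fn(*args))  # warm up so compilation is not timed
    return timeit.timeit(lambda: jax.block_until_ready(fn(*args)), number=number)


n = 1 << 20  # far smaller than the 100M elements in the report
gradient = jnp.cos(jnp.arange(n, dtype=jnp.float32))
opt_state = jnp.arange(n, dtype=jnp.float32)
pred = jnp.abs(jnp.sum(gradient)) > 0  # data-dependent predicate, as in test_fn

for name, fn in [("direct", direct), ("cond", with_cond), ("select", with_select)]:
    print(f"{name:>6}: {bench(fn, pred, gradient, opt_state):.4f} s for 10 calls")

# Diagnostic: does the optimized HLO still contain an XLA `conditional` op?
compiled = with_cond.lower(pred, gradient, opt_state).compile()
print("conditional op present:", "conditional" in compiled.as_text())
```

If `with_cond` only falls behind `direct` once the function is wrapped in `compute_on("device_host")`, that would suggest the host-offloaded conditional is the issue rather than `lax.cond` in general. The select variant trades the branch for always computing `true_fn`, which costs little here because the true branch is always taken, but it would waste work if the predicate were often false.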

hanzhi713 added the enhancement (New feature or request) label on Oct 11, 2024.