
Add jax compatible api #1207

Open · wants to merge 8 commits into main
Conversation

sky-2002

This PR adds a JAX-compatible API; see issue #1027.

Looking forward to review, and open to feedback (especially on writing tests).

@lapp0 (Contributor) left a comment

Great work so far!

Could you please add jax to benchmarks/bench_processors.py as well?

Also, could you fill me in on how you intend to use a JAX-based model, just so I have context on how Outlines is being used here?

Thanks!

def is_jax_array_type(array_type):
    return array_type == jaxlib.xla_extension.ArrayImpl or isinstance(
        array_type, jaxlib.xla_extension.ArrayImpl
    )
Contributor
Please follow the pattern of is_mlx_array_type, where we first check if it's importable, since it's possible jax isn't installed.
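A minimal sketch of that pattern, assuming the same lazy-import structure as is_mlx_array_type and the jaxlib import path from the snippet above:

```python
def is_jax_array_type(array_type):
    # Mirror is_mlx_array_type: if jax/jaxlib isn't installed, the type
    # can't possibly be a jax array, so return False instead of raising
    # ImportError at call time.
    try:
        import jaxlib.xla_extension
    except ImportError:
        return False
    return array_type == jaxlib.xla_extension.ArrayImpl or isinstance(
        array_type, jaxlib.xla_extension.ArrayImpl
    )
```

This keeps jax a purely optional dependency: callers can probe any type without guarding the call themselves.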

Author

Done

assert isinstance(torch_tensor, torch.Tensor)
assert torch.allclose(
    torch_tensor, torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
)
@lapp0 (Contributor) · Oct 13, 2024

Glad to see you have good test coverage.

This could be made much simpler with

arrays = {
    "list": get_list_array(),
    "np": get_np_array(),
    ...
}

@pytest.mark.parametrize("array_type", arrays.keys())
def test_to_torch(array_type, processor):
    array = arrays[array_type]
    ...

This way we can have full coverage for all array libraries while only requiring three test functions (test_from_torch(), test_to_torch(), test_call()).
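A runnable sketch of this parametrized layout, using numpy stand-ins so it stays self-contained (the real registry would hold torch, mlx, and jax arrays, and the helper names above are placeholders):

```python
import numpy as np
import pytest

# Hypothetical registry mapping array-library names to equivalent inputs;
# optional libraries would be added to it under try/except ImportError.
arrays = {
    "list": [[1.0, 2.0], [3.0, 4.0]],
    "np": np.array([[1, 2], [3, 4]], dtype=np.float32),
}

@pytest.mark.parametrize("array_type", arrays.keys())
def test_to_np(array_type):
    # Stand-in for test_to_torch: convert and compare against a reference.
    converted = np.asarray(arrays[array_type], dtype=np.float32)
    assert np.allclose(converted, np.array([[1, 2], [3, 4]], dtype=np.float32))
```

Each library then contributes one registry entry rather than a whole new set of test functions.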

Author

Makes sense, done

@sky-2002
Author

Hey @lapp0, thanks for the feedback. I've added jax to the benchmarks (though I haven't run them yet; I will).

Regarding the intent to use with JAX, @borisdayma (author of the issue mentioned above) can help us understand JAX-based model usage with Outlines.

PS: I currently have no use case for JAX, but I was trying out Outlines and found it very interesting, so I wanted to contribute.

@lapp0 (Contributor) left a comment

> PS: I currently have no use case for JAX, but I was trying out Outlines and found it very interesting, so I wanted to contribute.

Thanks for clarifying and your interest in the project.

> I've added jax to the benchmarks (though I haven't run them yet; I will).

I'll add the run-benchmarks label, which will trigger their run once the workflow is approved.

> Regarding the intent to use with JAX, @borisdayma (author of the issue mentioned above) can help us understand JAX-based model usage with Outlines.

@borisdayma could you review and smoke test to ensure this fits your desired use case?

    "jax": jnp.array([[1, 2], [3, 4]], dtype=jnp.float32),
    "torch": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
    "mlx": mx.array([[1, 2], [3, 4]], dtype=mx.float32),
}
Contributor

try:
    import mlx.core as mx
    arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
except ImportError:
    pass

Same for jax.
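Applying the same guard to jax might look like this (a sketch, assuming jax is an optional extra; the numpy entry stands in for the rest of the registry):

```python
import numpy as np

arrays = {"np": np.array([[1, 2], [3, 4]], dtype=np.float32)}

try:
    import jax.numpy as jnp
    arrays["jax"] = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
except ImportError:
    # jax not installed: skip its entry, so the parametrized tests only
    # run over the libraries actually available in this environment.
    pass
```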

Author

Done

    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False

Contributor

No need for this if you use the suggestion in the comment below.

elif is_jax_array_type(target_type):
    import jax

    return jax.numpy.from_dlpack(tensor)
Contributor

I'm wondering if we need to use jax.numpy? Is there a reason to not just use jax.dlpack.from_dlpack?

@sky-2002 (Author) · Oct 14, 2024

You're right, we can just use jax.dlpack.from_dlpack. I don't know why the autocomplete suggestions didn't show it, or why jax.numpy is used there in line 117; I came across links that used jax.numpy.dlpack, though I will check what difference it really makes.
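For context on why either spelling works: DLPack is a zero-copy tensor-exchange protocol that jax, torch, mlx, and numpy all implement, and the from_dlpack entry points ingest the same capsule. The exchange can be illustrated with numpy alone (assuming numpy >= 1.22, which added protocol support):

```python
import numpy as np

# Producer side: any object exposing __dlpack__ (numpy, torch, jax, ...).
a = np.arange(4, dtype=np.float32)

# Consumer side: from_dlpack ingests the exchanged tensor without copying,
# so the result shares memory with the source array.
b = np.from_dlpack(a)
```

The same shape applies across libraries, which is why the conversion helpers under review only differ in which library's from_dlpack they call.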

import jax

torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
return torch_tensor.cuda()
Contributor

We shouldn't send to cuda.

Author

Fixed

@sky-2002 sky-2002 marked this pull request as ready for review October 15, 2024 05:19
@borisdayma

borisdayma commented Oct 16, 2024

This looks great but I'll still need some time to try it out.

The goal is to use it with repos such as:

Since the basic tests seem to work, I expect it should be fine, and I could create a new issue if for some reason it prevents JAX compilation.
