Add jax compatible api #1207
Conversation
Great work so far!
Could you please add jax to benchmarks/bench_processors.py as well?
Also, could you fill me in on how you intend to use a JAX-based model, just so I have context on how Outlines is being used here?
Thanks!
def is_jax_array_type(array_type):
    return array_type == jaxlib.xla_extension.ArrayImpl or isinstance(
        array_type, jaxlib.xla_extension.ArrayImpl
    )
Please follow the pattern of is_mlx_array_type, where we first check if it's importable, since it's possible jax isn't installed.
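For illustration, a sketch of what the import-guarded version might look like (assuming is_mlx_array_type guards its import the same way; the exact details in outlines may differ):

def is_jax_array_type(array_type):
    try:
        # jax/jaxlib may not be installed, so guard the import first
        import jaxlib.xla_extension
    except ImportError:
        return False
    return array_type == jaxlib.xla_extension.ArrayImpl or isinstance(
        array_type, jaxlib.xla_extension.ArrayImpl
    )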
Done
assert isinstance(torch_tensor, torch.Tensor)
assert torch.allclose(
    torch_tensor, torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
)
Glad to see you have good test coverage.
This could be made much simpler with:
arrays = {
    "list": get_list_array(),
    "np": get_np_array(),
    ...
}

@pytest.mark.parametrize("array_type", arrays.keys())
def test_to_torch(array_type, processor):
    array = arrays[array_type]
    ...
This way we can have full coverage for all array libraries while only requiring three test functions (test_from_torch(), test_to_torch(), test_call()).
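For reference, a runnable sketch of this parametrized pattern, using placeholder inputs and torch.tensor standing in for the processor's actual conversion method:

import numpy as np
import pytest
import torch

# Placeholder inputs; the real test module would build these from each
# array library, guarded by try/except ImportError where needed.
arrays = {
    "list": [[1.0, 2.0], [3.0, 4.0]],
    "np": np.array([[1, 2], [3, 4]], dtype=np.float32),
}

@pytest.mark.parametrize("array_type", arrays.keys())
def test_to_torch(array_type):
    # torch.tensor stands in for the conversion under test
    torch_tensor = torch.tensor(arrays[array_type])
    assert isinstance(torch_tensor, torch.Tensor)
    assert torch.allclose(
        torch_tensor, torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
    )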
Makes sense, done
Hey @lapp0, thanks for the feedback. Regarding intent to use with JAX, @borisdayma (author of the issue mentioned above) can help us understand JAX-based model usage with outlines. PS: I currently have no use case with JAX, but I was trying out outlines and found it very interesting, so wanted to contribute.
PS: I currently have no use case with JAX, but I was trying out outlines and found it very interesting, so wanted to contribute.
Thanks for clarifying, and for your interest in the project.
Have added to benchmarks (though I haven't run those yet, will possibly do)
I'll add the run-benchmarks label, which will trigger their run once the workflow is approved.
Regarding intent to use with JAX, @borisdayma (author of the issue mentioned above) can help us understand JAX-based model usage with outlines.
@borisdayma could you review and smoke test to ensure this fits your desired use case?
"jax": jnp.array([[1, 2], [3, 4]], dtype=jnp.float32), | ||
"torch": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32), | ||
"mlx": mx.array([[1, 2], [3, 4]], dtype=mx.float32), | ||
} |
try:
    import mlx.core as mx
    arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
except ImportError:
    pass
Same for jax.
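The analogous guard for jax would look like this (a sketch following the mlx example above, using the same array values as the original dict):

try:
    import jax.numpy as jnp
    arrays["jax"] = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
except ImportError:
    pass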
Done
    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False
No need for this if you use the suggestion from the comment below.
elif is_jax_array_type(target_type):
    import jax

    return jax.numpy.from_dlpack(tensor)
I'm wondering if we need to use jax.numpy? Is there a reason not to just use jax.dlpack.from_dlpack?
You are right, we can just use jax.dlpack.from_dlpack. I don't know why the autocomplete suggestion didn't show, and why it's there in line 117. I came across links that used jax.numpy.dlpack, though I will check what difference it really makes.
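For context, a minimal sketch of the conversion round trip being discussed, assuming both torch and jax are installed (CPU is fine; newer jax versions also accept any object implementing __dlpack__ directly):

import jax.dlpack
import jax.numpy as jnp
import torch

jax_array = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# jax -> torch, via a DLPack capsule (the pattern used in this PR)
torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(jax_array))

# torch -> jax, using jax.dlpack.from_dlpack rather than jax.numpy
roundtrip = jax.dlpack.from_dlpack(torch_tensor)
assert bool(jnp.allclose(roundtrip, jax_array))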
import jax

torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
return torch_tensor.cuda()
We shouldn't send to cuda.
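A sketch of the device-agnostic version (jax_to_torch is an illustrative name, not the PR's actual function; from_dlpack already preserves the producer's device placement, so no explicit transfer is needed):

import jax.dlpack
import torch

def jax_to_torch(tensor_like):
    # the resulting tensor lives on whatever device the jax array was on;
    # forcing .cuda() would fail on CPU-only machines
    return torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))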
Fixed
This looks great, but I'll still need some time to try it out. The goal is to use it with repos such as:
Since the basic tests seem to work, I expect it should be fine, and I could create a new issue if for some reason it prevents JAX compilation.
This PR adds a JAX-compatible API; refer to issue #1027.
Looking forward to review and open to feedback (especially on writing tests).