Commit

Update some outdated syntax in FFI tutorial.
dfm committed Nov 12, 2024
1 parent 9bb6366 commit f757054
Showing 6 changed files with 83 additions and 104 deletions.
35 changes: 15 additions & 20 deletions docs/ffi.ipynb
@@ -26,10 +26,7 @@
 "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n",
 "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n",
 "\n",
-"This tutorial comes with two supplementary files:\n",
-"\n",
-"* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and\n",
-"* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.\n",
+"The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n",
 "\n",
 "## A simple example\n",
 "\n",
@@ -101,7 +98,7 @@
 "\n",
 "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n",
 "For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).\n",
-"The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:\n",
+"The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:\n",
 "\n",
 "```c++\n",
 "#include <functional>\n",
@@ -129,12 +126,11 @@
 "// A wrapper function providing the interface between the XLA FFI call and our\n",
 "// library function `ComputeRmsNorm` above. This function handles the batch\n",
 "// dimensions by calling `ComputeRmsNorm` within a loop.\n",
-"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,\n",
-"                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {\n",
+"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,\n",
+"                       ffi::ResultBuffer<ffi::F32> y) {\n",
 "  auto [totalSize, lastDim] = GetDims(x);\n",
 "  if (lastDim == 0) {\n",
-"    return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n",
-"                      \"RmsNorm input must be an array\");\n",
+"    return ffi::Error::InvalidArgument(\"RmsNorm input must be an array\");\n",
 "  }\n",
 "  for (int64_t n = 0; n < totalSize; n += lastDim) {\n",
 "    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));\n",
@@ -149,8 +145,8 @@
 "    RmsNorm, RmsNormImpl,\n",
 "    ffi::Ffi::Bind()\n",
 "        .Attr<float>(\"eps\")\n",
-"        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x\n",
-"        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y\n",
+"        .Arg<ffi::Buffer<ffi::F32>>()  // x\n",
+"        .Ret<ffi::Buffer<ffi::F32>>()  // y\n",
 ");\n",
 "```\n",
 "\n",
@@ -173,8 +169,7 @@
 "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n",
 "In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.\n",
 "\n",
-"To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.\n",
-"The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt)."
+"To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble."
 ]
 },
 {
@@ -433,7 +428,7 @@
 "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n",
 "2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n",
 "\n",
-"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.\n",
+"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n",
 "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
 "\n",
 "This custom derivative rule can be wired in as follows:"
@@ -508,16 +503,16 @@
 "When defining our FFI wrapper for CPU, the function signature that we used was:\n",
 "\n",
 "```c++\n",
-"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,\n",
-"                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)\n",
+"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,\n",
+"                       ffi::ResultBuffer<ffi::F32> y)\n",
 "```\n",
 "\n",
 "To update this to interface with a CUDA kernel, this signature becomes:\n",
 "\n",
 "```c++\n",
 "ffi::Error RmsNormImpl(cudaStream_t stream, float eps,\n",
-"                       ffi::Buffer<ffi::DataType::F32> x,\n",
-"                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)\n",
+"                       ffi::Buffer<ffi::F32> x,\n",
+"                       ffi::ResultBuffer<ffi::F32> y)\n",
 "```\n",
 "\n",
 "And the handler definition is updated to include a `Ctx` in its binding:\n",
@@ -528,8 +523,8 @@
 "    ffi::Ffi::Bind()\n",
 "        .Ctx<ffi::PlatformStream<cudaStream_t>>()\n",
 "        .Attr<float>(\"eps\")\n",
-"        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x\n",
-"        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y\n",
+"        .Arg<ffi::Buffer<ffi::F32>>()  // x\n",
+"        .Ret<ffi::Buffer<ffi::F32>>()  // y\n",
 ");\n",
 "```\n",
 "\n",
33 changes: 14 additions & 19 deletions docs/ffi.md
@@ -34,10 +34,7 @@ JAX's FFI support is provided in two parts:
 In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.
 We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.
 
-This tutorial comes with two supplementary files:
-
-* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and
-* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.
+The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).
 
 ## A simple example
 
@@ -96,7 +93,7 @@ and, for our example, this is the function that we want to expose to JAX via the
 To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).
 For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).
-The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:
+The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:
 
 ```c++
 #include <functional>
@@ -124,12 +121,11 @@ std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
 // A wrapper function providing the interface between the XLA FFI call and our
 // library function `ComputeRmsNorm` above. This function handles the batch
 // dimensions by calling `ComputeRmsNorm` within a loop.
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
+ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
+                       ffi::ResultBuffer<ffi::F32> y) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNorm input must be an array");
+    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
   }
   for (int64_t n = 0; n < totalSize; n += lastDim) {
     ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
@@ -144,8 +140,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     RmsNorm, RmsNormImpl,
     ffi::Ffi::Bind()
         .Attr<float>("eps")
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
+        .Arg<ffi::Buffer<ffi::F32>>()  // x
+        .Ret<ffi::Buffer<ffi::F32>>()  // y
 );
 ```
 
@@ -166,7 +162,6 @@ Now that we have our minimal FFI wrapper implemented, we need to expose this fun
 In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.
 
 To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.
-The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt).
 
 ```{code-cell} ipython3
 :tags: [hide-output]
@@ -357,7 +352,7 @@ In this case, we actually define two new FFI calls:
 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass.
 2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.
 
-We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.
+We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.
 The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.
 
 This custom derivative rule can be wired in as follows:
@@ -422,16 +417,16 @@ Since this documentation page is automatically generated on a machine without ac
 When defining our FFI wrapper for CPU, the function signature that we used was:
 
 ```c++
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)
+ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
+                       ffi::ResultBuffer<ffi::F32> y)
 ```
 
 To update this to interface with a CUDA kernel, this signature becomes:
 
 ```c++
 ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
-                       ffi::Buffer<ffi::DataType::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)
+                       ffi::Buffer<ffi::F32> x,
+                       ffi::ResultBuffer<ffi::F32> y)
 ```
 
 And the handler definition is updated to include a `Ctx` in its binding:
@@ -442,8 +437,8 @@ XLA_FFI_DEFINE_HANDLER(
     ffi::Ffi::Bind()
         .Ctx<ffi::PlatformStream<cudaStream_t>>()
         .Attr<float>("eps")
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
+        .Arg<ffi::Buffer<ffi::F32>>()  // x
+        .Ret<ffi::Buffer<ffi::F32>>()  // y
 );
 ```
 
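As a plain-Python mental model of the wrapper being updated in this diff, the sketch below mirrors the batch loop in `RmsNormImpl`: the input is a flat buffer whose trailing dimension has size `last_dim`, and each row is normalized independently. The function names here are illustrative, not part of the tutorial, and we assume the conventional definition of RMS normalization, y = x / sqrt(eps + mean(x^2)):

```python
import math

def compute_rms_norm(eps, x):
    # Mirrors `ComputeRmsNorm` for a single row: divide each element by the
    # root mean square of the row (conventional definition, assumed here).
    scale = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / scale for v in x]

def rms_norm_ref(eps, x, last_dim):
    # Mirrors `RmsNormImpl`: treat `x` as a flat buffer and step through the
    # batch dimensions in strides of the trailing dimension `last_dim`.
    if last_dim == 0:
        raise ValueError("RmsNorm input must be an array")
    y = []
    for n in range(0, len(x), last_dim):
        y.extend(compute_rms_norm(eps, x[n:n + last_dim]))
    return y
```

The `last_dim == 0` branch corresponds to the `ffi::Error::InvalidArgument` path in the C++ wrapper, which is how FFI handlers report failures back to XLA.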
56 changes: 25 additions & 31 deletions docs/ffi/rms_norm.cc
@@ -56,12 +56,11 @@ std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
 // A wrapper function providing the interface between the XLA FFI call and our
 // library function `ComputeRmsNorm` above. This function handles the batch
 // dimensions by calling `ComputeRmsNorm` within a loop.
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
+ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
+                       ffi::ResultBuffer<ffi::F32> y) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNorm input must be an array");
+    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
   }
   for (int64_t n = 0; n < totalSize; n += lastDim) {
     ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
Expand All @@ -75,17 +74,16 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);

ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> res) {
ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y,
ffi::ResultBuffer<ffi::F32> res) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormFwd input must be an array");
return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
@@ -94,13 +92,12 @@ ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
   return ffi::Error::Success();
 }
 
-XLA_FFI_DEFINE_HANDLER_SYMBOL(
-    RmsNormFwd, RmsNormFwdImpl,
-    ffi::Ffi::Bind()
-        .Attr<float>("eps")
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // res
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
+                              ffi::Ffi::Bind()
+                                  .Attr<float>("eps")
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // res
 );
 
 void ComputeRmsNormBwd(int64_t size, float res, const float *x,
@@ -115,14 +112,12 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x,
   }
 }
 
-ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res,
-                          ffi::Buffer<ffi::DataType::F32> x,
-                          ffi::Buffer<ffi::DataType::F32> ct_y,
-                          ffi::Result<ffi::Buffer<ffi::DataType::F32>> ct_x) {
+ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
+                          ffi::Buffer<ffi::F32> ct_y,
+                          ffi::ResultBuffer<ffi::F32> ct_x) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNormBwd inputs must be arrays");
+    return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
   }
   for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
     ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
@@ -131,11 +126,10 @@ ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res,
   return ffi::Error::Success();
 }
 
-XLA_FFI_DEFINE_HANDLER_SYMBOL(
-    RmsNormBwd, RmsNormBwdImpl,
-    ffi::Ffi::Bind()
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // res
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // ct_y
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // ct_x
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
+                              ffi::Ffi::Bind()
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // res
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // ct_y
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // ct_x
 );
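The tutorial deliberately skips the derivation behind `RmsNormBwdImpl`, but a rough per-row model helps when reading the diff: `RmsNormFwdImpl` stores the value returned by `ComputeRmsNorm` as the residual, and the backward pass consumes it together with the output cotangent. The Python below is an illustrative sketch, not the repository's code; it assumes the residual is the per-row scale r = sqrt(eps + mean(x^2)), in which case the chain rule for y_i = x_i / r gives ct_x_j = ct_y_j / r - x_j * <ct_y, x> / (n * r^3):

```python
import math

def rms_norm_fwd(eps, x):
    # Sketch of the forward pass for one row: return the primal output and
    # the residual, assumed here to be the scale r = sqrt(eps + mean(x^2)).
    r = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / r for v in x], r

def rms_norm_bwd(r, x, ct_y):
    # Sketch of the backward pass for one row: with y_i = x_i / r, the
    # chain rule yields ct_x_j = ct_y_j / r - x_j * <ct_y, x> / (n * r**3).
    n = len(x)
    dot = sum(c * v for c, v in zip(ct_y, x))
    return [c / r - v * dot / (n * r**3) for c, v in zip(ct_y, x)]
```

A quick way to validate any candidate backward pass like this one is to compare it against a central finite-difference estimate of the vector-Jacobian product.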
6 changes: 3 additions & 3 deletions examples/ffi/src/jax_ffi_example/attrs.cc
@@ -22,7 +22,7 @@ namespace nb = nanobind;
 namespace ffi = xla::ffi;
 
 ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
-                         ffi::Result<ffi::BufferR0<ffi::S32>> res) {
+                         ffi::ResultBufferR0<ffi::S32> res) {
   int64_t total = 0;
   for (int32_t x : array) {
     total += x;
@@ -37,8 +37,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl,
                                   .Ret<ffi::BufferR0<ffi::S32>>());
 
 ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs,
-                              ffi::Result<ffi::BufferR0<ffi::S32>> secret,
-                              ffi::Result<ffi::BufferR0<ffi::S32>> count) {
+                              ffi::ResultBufferR0<ffi::S32> secret,
+                              ffi::ResultBufferR0<ffi::S32> count) {
   auto maybe_secret = attrs.get<int64_t>("secret");
   if (maybe_secret.has_error()) {
     return maybe_secret.error();
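Note that the `attrs.cc` changes only respell the result types (`ffi::Result<ffi::BufferR0<...>>` becomes `ffi::ResultBufferR0<...>`); the handler logic is untouched. For instance, `ArrayAttrImpl` still just reduces an array-valued attribute to a scalar, which in illustrative Python terms (hypothetical function name, scalar return in place of the rank-0 result buffer) is:

```python
def array_attr(array):
    # Mirrors `ArrayAttrImpl`: sum the int32 array attribute into a scalar.
    total = 0
    for x in array:
        total += x
    return total
```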
