Skip to content

Add Torch support -- demonstration for Felafax #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.DS_Store
.hypothesis/
3 changes: 3 additions & 0 deletions PyTorch2Jax/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# PyTorch to Jax Conversion

This example illustrates how Benchify can be used to test the equivalence of two austensibly equivalent functions. In this case, the customer translated some code from PyTorch to Jax, and wanted to make sure that the two implementations were functionally equivalent. Thank you to [Nithin Santi](https://www.linkedin.com/in/nithinsonti/) for contributing this example!
72 changes: 72 additions & 0 deletions PyTorch2Jax/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch
import numpy as np
import jax
import jax.numpy as jnp
from typing import Tuple

# No need to test this, assume it is correct
# Just a helper function for the apply_rotary_emb_torch function
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


"""
This file contains two functions: apply_rotary_emb_torch and apply_rotary_emb_jax.
The functions should be functionally equivalent.
"""

# This should be equivalent to the below function (apply_rotary_emb_jax)
def apply_rotary_emb_torch(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)

# This should be equivalent to the above function (apply_rotary_emb_torch)
def apply_rotary_emb_jax(
xq: jnp.ndarray,
xk: jnp.ndarray,
freqs_cis: jnp.ndarray,
dtype: jnp.dtype=jnp.float32, # This is the return type. Generally we will use jnp.float32.
) -> Tuple[jnp.ndarray, jnp.ndarray]:
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)

xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])

# add head dim
freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))

xq_out = xq_ * freqs_cis
xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)

xk_out = xk_ * freqs_cis
xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)

return xq_out.astype(dtype), xk_out.astype(dtype)

"""
---- HELPER FUNCTIONS ----

The following helper functions are meant to help with test writing.
Note that you can use torch.from_numpy(ndarray) → Tensor to get a torch
tensor from a numpy array.
"""
def jnp_ndarray_to_torch(x: jnp.ndarray) -> torch.Tensor:
return torch.from_numpy(x.astype(np.float32))

def torch_tensor_to_jnp(x: torch.Tensor) -> jnp.ndarray:
return x.cpu().numpy().astype(np.float32)


4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
# benchify-examples
# Examples of Benchify in Action 🥷🏼

This repository contains interesting examples of how customers use Benchify to analyze their code. If you have an interesting use-case for us, please let us know!
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
torch
numpy
jax