Skip to content

Add Torch support -- demonstration for Felafax #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 15, 2024
Merged

Conversation

maxvonhippel
Copy link
Collaborator

@maxvonhippel maxvonhippel commented Aug 15, 2024

Customer: Felafax

Demonstrates new feature: Torch support

Copy link

benchify bot commented Aug 15, 2024

🧪 Benchify Analysis of PR 20

AI Generated Summary:
The property-based testing for the apply_rotary_emb_torch and apply_rotary_emb_jax functions in the PyTorch2Jax/script.py file yielded mixed results.

  1. Some tests (denoted by ❌) indicated that the functions produced different outputs for the same inputs, suggesting discrepancies between the two implementations.
  2. Another batch of tests (denoted by 💥) encountered an unexpected exception when comparing the output of these functions with specific inputs, indicating potential robustness or edge-case handling issues in the code.

No timeouts (⏰) were encountered during the tests.

File Inputs Description
PyTorch2Jax/script.py xq=array([[[[0.0000000e+00, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.5000000e+00, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]],, , [[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]],, , [[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]],, , [[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]],, , [[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]]],, , , [[[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]],, , [[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]],, , [[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]],, , [[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]],, , [[1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12],, [1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12,, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12, 1.0992432e+12]]]],, dtype=float32), xk=array([[[[-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]],, , [[-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]],, , [[-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]],, , [[-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]],, , [[-5.0000000e-01, 1.5000000e+00],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]]],, , , [[[-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]],, , [[-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]],, , [[-5.0000000e-01, -5.0000000e-01],, [-3.8717056e+16, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]],, , [[-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]],, , [[-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, 0.0000000e+00],, [-5.0000000e-01, inf],, [-5.0000000e-01, -5.0000000e-01],, [-5.0000000e-01, -5.0000000e-01]]]], dtype=float32), freqs_cis=array([[ nan, nan, nan, nan,, nan, nan],, [ nan, nan, nan, inf,, inf, nan],, [ nan, inf, nan, nan,, nan, nan],, [ nan, nan, nan, -5.9604645e-08,, 1.1920929e-07, nan],, [ nan, 0.0000000e+00, nan, -5.0000000e-01,, nan, nan]], dtype=float32) ... and 3 more. The apply_rotary_emb_torch and apply_rotary_emb_jax functions produced different outputs for the same inputs.
💥 PyTorch2Jax/script.py xq=array([[[0., 0.],, [0., 0.]],, , [[0., 0.],, [0., 0.]]], dtype=float32), xk=array([[[0., 0.],, [0., 0.]],, , [[0., 0.],, [0., 0.]]], dtype=float32), freqs_cis=array([[inf, inf, inf, inf, inf, inf],, [inf, inf, inf, inf, inf, inf],, [inf, inf, inf, inf, inf, inf],, [inf, inf, inf, inf, inf, inf],, [inf, inf, inf, inf, inf, inf]], dtype=float32) ... and 571 more. Testing whether the apply_rotary_emb_torch and apply_rotary_emb_jax functions produce equivalent outputs for the same inputs caused an unexpected exception.
Reproducible Unit Tests
import unittest
import time

def benchify_test_apply_rotary_emb_equivalence(xq, xk, freqs_cis):
    xq_jax, xk_jax, freqs_jax = map(jnp.array, [xq, xk, freqs_cis])
    xq_torch, xk_torch, freqs_torch = map(torch.from_numpy, [xq, xk, freqs_cis])
    
    output_jax = apply_rotary_emb_jax(xq_jax, xk_jax, freqs_jax)
    output_torch = apply_rotary_emb_torch(xq_torch, xk_torch, freqs_torch)
    
    np.testing.assert_allclose(torch_tensor_to_jnp(output_torch[0]), output_jax[0], atol=1e-5)
    np.testing.assert_allclose(torch_tensor_to_jnp(output_torch[1]), output_jax[1], atol=1e-5)


# Property: The apply_rotary_emb_torch and apply_rotary_emb_jax functions produce equivalent outputs for the same inputs.


# The first batch of tests are failing.
def test_apply_rotary_emb_equivalence_failure_0():
    xq=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    xk=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    freqs_cis=array([[0., 0.],
           [0., 0.]], dtype=float32)
    benchify_test_apply_rotary_emb_equivalence(xq, xk, freqs_cis)

def test_apply_rotary_emb_equivalence_failure_1():
    xq=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    xk=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    freqs_cis=array([[inf, inf],
           [inf, inf]], dtype=float32)
    benchify_test_apply_rotary_emb_equivalence(xq, xk, freqs_cis)

def test_apply_rotary_emb_equivalence_failure_2():
    xq=array([[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],
    
            [[ 0., nan,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]],
    
    
           [[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],
    
            [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]],
    
    
           [[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],
    
            [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]],
    
    
           [[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],
    
            [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]],
    
    
           [[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],
    
            [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]],
    
    
           [[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],
    
            [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]]], dtype=float32)
    xk=array([[[5.2039762e+11, 5.2039762e+11],
            [5.2039762e+11, 5.2039762e+11],
            [5.2039762e+11, 5.2039762e+11]],
    
           [[5.2039762e+11, 5.2039762e+11],
            [5.2039762e+11, 5.2039762e+11],
            [5.2039762e+11, 5.2039762e+11]]], dtype=float32)
    freqs_cis=array([[-2.7482141e+16,  1.4766473e+16,           -inf],
           [ 0.0000000e+00,  1.1754944e-38,            nan]], dtype=float32)
    benchify_test_apply_rotary_emb_equivalence(xq, xk, freqs_cis)

# Found 1 more failing tests.



# The next batch of tests all throw unexpected exceptions.
def test_apply_rotary_emb_equivalence_exception_0():
    xq=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    xk=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    freqs_cis=array([[inf, inf, inf],
           [inf, inf, inf]], dtype=float32)
    benchify_test_apply_rotary_emb_equivalence(xq, xk, freqs_cis)

def test_apply_rotary_emb_equivalence_exception_1():
    xq=array([[[0., 0.],
            [0., 0.],
            [0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.],
            [0., 0.],
            [0., 0.]]], dtype=float32)
    xk=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    freqs_cis=array([[inf, inf, inf],
           [inf, inf, inf]], dtype=float32)
    benchify_test_apply_rotary_emb_equivalence(xq, xk, freqs_cis)

def test_apply_rotary_emb_equivalence_exception_2():
    xq=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    xk=array([[[0., 0.],
            [0., 0.]],
    
           [[0., 0.],
            [0., 0.]]], dtype=float32)
    freqs_cis=array([[inf, inf, inf],
           [inf, inf, inf],
           [inf, inf, inf],
           [inf, inf, inf],
           [inf, inf, inf],
           [inf, inf, inf],
           [inf, inf, inf],
           [inf, inf, -0.]], dtype=float32)
    benchify_test_apply_rotary_emb_equivalence(xq, xk, freqs_cis)

# Found 569 more tests that throw unexpected exceptions.

@maxvonhippel maxvonhippel merged commit 03607ea into main Aug 15, 2024
@maxvonhippel maxvonhippel deleted the first-example-v21 branch August 15, 2024 17:39
@maxvonhippel maxvonhippel changed the title First example v21 Add Torch support -- demonstration for Felafax Aug 21, 2024
maxvonhippel added a commit that referenced this pull request Mar 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant