Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Llama 3.2 RoPE to CI #391

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 116 additions & 35 deletions ch05/07_gpt_to_llama/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,45 @@

@pytest.fixture(scope="module")
def notebook():
def import_definitions_from_notebook(fullname, names):
# Get the directory of the current test file
current_dir = os.path.dirname(__file__)
path = os.path.join(current_dir, "..", fullname + ".ipynb")
path = os.path.normpath(path)
def import_definitions_from_notebook(notebooks):
imported_modules = {}

# Load the notebook
if not os.path.exists(path):
raise FileNotFoundError(f"Notebook file not found at: {path}")
for fullname, names in notebooks.items():
# Get the directory of the current test file
current_dir = os.path.dirname(__file__)
path = os.path.join(current_dir, "..", fullname + ".ipynb")
path = os.path.normpath(path)

with io.open(path, "r", encoding="utf-8") as f:
nb = nbformat.read(f, as_version=4)
# Load the notebook
if not os.path.exists(path):
raise FileNotFoundError(f"Notebook file not found at: {path}")

# Create a module to store the imported functions and classes
mod = types.ModuleType(fullname)
sys.modules[fullname] = mod
with io.open(path, "r", encoding="utf-8") as f:
nb = nbformat.read(f, as_version=4)

# Go through the notebook cells and only execute function or class definitions
for cell in nb.cells:
if cell.cell_type == "code":
cell_code = cell.source
for name in names:
# Check for function or class definitions
if f"def {name}" in cell_code or f"class {name}" in cell_code:
exec(cell_code, mod.__dict__)
return mod
# Create a module to store the imported functions and classes
mod = types.ModuleType(fullname)
sys.modules[fullname] = mod

# Specify the notebook name and functions/classes to import
fullname = "converting-gpt-to-llama2"
names = ["precompute_rope_params", "compute_rope", "SiLU", "RMSNorm"]
# Go through the notebook cells and only execute function or class definitions
for cell in nb.cells:
if cell.cell_type == "code":
cell_code = cell.source
for name in names:
# Check for function or class definitions
if f"def {name}" in cell_code or f"class {name}" in cell_code:
exec(cell_code, mod.__dict__)

# Import the required functions and classes from the notebook
return import_definitions_from_notebook(fullname, names)
imported_modules[fullname] = mod

return imported_modules

notebooks = {
"converting-gpt-to-llama2": ["SiLU", "RMSNorm", "precompute_rope_params", "compute_rope"],
"converting-llama2-to-llama3": ["precompute_rope_params"]
}

return import_definitions_from_notebook(notebooks)


@pytest.fixture(autouse=True)
Expand All @@ -59,22 +65,25 @@ def set_seed():


def test_rope_llama2(notebook):

this_nb = notebook["converting-gpt-to-llama2"]

# Settings
batch_size = 1
context_len = 4096
num_heads = 4
head_dim = 16

# Instantiate RoPE parameters
cos, sin = notebook.precompute_rope_params(head_dim=head_dim, context_length=context_len)
cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len)

# Dummy query and key tensors
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

# Apply rotary position embeddings
queries_rot = notebook.compute_rope(queries, cos, sin)
keys_rot = notebook.compute_rope(keys, cos, sin)
queries_rot = this_nb.compute_rope(queries, cos, sin)
keys_rot = this_nb.compute_rope(keys, cos, sin)

rot_emb = LlamaRotaryEmbedding(
dim=head_dim,
Expand All @@ -93,6 +102,10 @@ def test_rope_llama2(notebook):


def test_rope_llama3(notebook):

nb1 = notebook["converting-gpt-to-llama2"]
nb2 = notebook["converting-llama2-to-llama3"]

# Settings
batch_size = 1
context_len = 8192
Expand All @@ -101,19 +114,20 @@ def test_rope_llama3(notebook):
theta_base = 50_000

# Instantiate RoPE parameters
cos, sin = notebook.precompute_rope_params(
cos, sin = nb2.precompute_rope_params(
head_dim=head_dim,
context_length=context_len,
theta_base=theta_base
)

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

# Apply rotary position embeddings
queries_rot = notebook.compute_rope(queries, cos, sin)
keys_rot = notebook.compute_rope(keys, cos, sin)
queries_rot = nb1.compute_rope(queries, cos, sin)
keys_rot = nb1.compute_rope(keys, cos, sin)

rot_emb = LlamaRotaryEmbedding(
dim=head_dim,
Expand All @@ -131,16 +145,83 @@ def test_rope_llama3(notebook):
torch.testing.assert_close(queries_rot, ref_queries_rot)


def test_rope_llama3_12(notebook):

nb1 = notebook["converting-gpt-to-llama2"]
nb2 = notebook["converting-llama2-to-llama3"]

# Settings
batch_size = 1
context_len = 8192
num_heads = 4
head_dim = 16
rope_theta = 50_000

rope_config = {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_context_length": 8192,
}

# Instantiate RoPE parameters
cos, sin = nb2.precompute_rope_params(
head_dim=head_dim,
theta_base=rope_theta,
context_length=context_len,
freq_config=rope_config,
)

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

# Apply rotary position embeddings
queries_rot = nb1.compute_rope(queries, cos, sin)
keys_rot = nb1.compute_rope(keys, cos, sin)

hf_rope_params = {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
}

class RoPEConfig:
rope_type = "llama3"
rope_scaling = hf_rope_params
factor = 1.0
dim: int = head_dim
rope_theta = 50_000
max_position_embeddings: int = 8192
hidden_size = head_dim * num_heads
num_attention_heads = num_heads

config = RoPEConfig()

rot_emb = LlamaRotaryEmbedding(config=config)
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
ref_cos, ref_sin = rot_emb(queries, position_ids)
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

torch.testing.assert_close(sin, ref_sin.squeeze(0))
torch.testing.assert_close(cos, ref_cos.squeeze(0))
torch.testing.assert_close(keys_rot, ref_keys_rot)
torch.testing.assert_close(queries_rot, ref_queries_rot)


def test_silu(notebook):
example_batch = torch.randn(2, 3, 4)
silu = notebook.SiLU()
silu = notebook["converting-gpt-to-llama2"].SiLU()
assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))


@pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
def test_rmsnorm(notebook):
example_batch = torch.randn(2, 3, 4)
rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
rms_norm = notebook["converting-gpt-to-llama2"].RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)

assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))