forked from ServiceNow/Fast-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_simple.py
81 lines (66 loc) · 2.22 KB
/
test_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import pytest
from tests.common import CONFIG_COMMON, CONFIG_FAST_LLM, TEST_MODEL, run_test_script
@pytest.mark.depends(
    depends=[f"tests/test_match_megatron.py::test_{TEST_MODEL}_match_meg"],
)
def test_model_safe():
    """Run the safest possible config, identical to the one in test_match_megatron
    except for the initialization, and record its results as the reference run."""
    # Disable every optimization that could perturb numerics.
    safety_overrides = [
        "run.torch_dynamo_enable=False",
        "schedule.data_overlap=False",
        "model.base_model.transformer.dropless_moe=False",
    ]
    run_test_script(f"test_{TEST_MODEL}_safe", CONFIG_FAST_LLM + safety_overrides)
@pytest.mark.depends(on=["test_model_safe"])
def test_model():
    """Run a baseline config (single-gpu, bf16, flash-attn) and compare against
    the safe reference run. Also exercises multiple data loader workers."""
    baseline_name = f"test_{TEST_MODEL}"
    run_test_script(
        baseline_name,
        CONFIG_COMMON + ["training.num_workers=2"],
        compare=f"test_{TEST_MODEL}_safe",
    )
@pytest.mark.slow
@pytest.mark.depends(on=["test_model"])
def test_model_dp2():
    """Simple data-parallel run on two GPUs, compared against the baseline."""
    run_test_script(
        f"test_{TEST_MODEL}_dp2",
        CONFIG_COMMON,
        num_gpus=2,
        compare=f"test_{TEST_MODEL}",
    )
@pytest.mark.slow
@pytest.mark.depends(on=["test_model"])
def test_model_tp2():
    """Simple tensor-parallel run on two GPUs, compared against the baseline."""
    tp_config = CONFIG_COMMON + ["model.distributed.tensor_parallel=2"]
    run_test_script(f"test_{TEST_MODEL}_tp2", tp_config, num_gpus=2, compare=f"test_{TEST_MODEL}")
@pytest.mark.depends(on=["test_model"])
def test_model_ce4():
    """Cross-entropy computed in four splits, compared against the baseline."""
    ce_config = CONFIG_COMMON + ["model.base_model.cross_entropy_splits=4"]
    run_test_script(f"test_{TEST_MODEL}_ce4", ce_config, compare=f"test_{TEST_MODEL}")
@pytest.mark.slow
@pytest.mark.depends(on=["test_model"])
def test_model_dp2_z2():
    """Data-parallel run with ZeRO stage 2, compared against the baseline."""
    z2_config = CONFIG_COMMON + ["model.multi_stage.zero_stage=2"]
    run_test_script(f"test_{TEST_MODEL}_dp2_z2", z2_config, num_gpus=2, compare=f"test_{TEST_MODEL}")
@pytest.mark.slow
@pytest.mark.depends(on=["test_model"])
def test_model_dp2_z3():
    """Data-parallel run with ZeRO stage 3, compared against the baseline."""
    z3_config = CONFIG_COMMON + ["model.multi_stage.zero_stage=3"]
    run_test_script(f"test_{TEST_MODEL}_dp2_z3", z3_config, num_gpus=2, compare=f"test_{TEST_MODEL}")