-
Notifications
You must be signed in to change notification settings - Fork 213
/
test_mixed_precision.py
136 lines (110 loc) · 4.5 KB
/
test_mixed_precision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from pathlib import Path
import pytest
import torch
from finetune.args import LoraArgs
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
downcast_mixed_precision,
prepare_mixed_precision,
upcast_mixed_precision,
)
from finetune.wrapped_model import load_model
from tests.test_utils import MODEL_PATH, get_dataloader, setup_mp_test_dist
from .test_utils import spawn_for_all_world_sizes
@pytest.mark.parametrize(
    ("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)]
)
def test_mixed_precision(world_size, enable_lora):
    """Spawn `_check_mixed_precision` across the requested world size.

    Parametrized over world sizes {1, 2} with LoRA both enabled and
    disabled; the actual assertions live in `_check_mixed_precision`,
    which runs once per rank.
    """
    spawn_config = dict(
        world_sizes=[world_size],
        args=[enable_lora],
        deterministic=True,
    )
    spawn_for_all_world_sizes(_check_mixed_precision, **spawn_config)
def _check_mixed_precision(
    rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool
) -> None:
    """Per-rank check of the prepare -> upcast -> step -> downcast mixed-precision cycle.

    Loads the model in bf16, attaches fp32 master copies via
    `prepare_mixed_precision`, then runs three training steps (each with two
    accumulated micro-batches) and asserts the parameter dtypes and values at
    every phase of the cycle.

    Args:
        rank: This process's rank within the spawned group.
        world_size: Total number of spawned processes.
        filename: Rendezvous file for distributed init (passed to setup).
        filename_rpc: Second rendezvous file supplied by the spawner; unused here.
        enable_lora: Whether to wrap the model with LoRA adapters.
    """
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
    seq_len = 100
    folder = Path(MODEL_PATH)
    # mixed precision: params live in bf16, optimizer math runs in fp32
    param_dtype = torch.bfloat16
    optim_dtype = torch.float32
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=param_dtype,
    )
    optimizer = torch.optim.AdamW(model.parameters())
    # initialize mixed precision training for TP
    # NOTE(review): presumably this attaches a fp32 master copy as `_mp_param`
    # on every trainable parameter — the assertions below rely on that.
    prepare_mixed_precision(
        model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
    )
    data_loader = get_dataloader(seq_len=seq_len)
    # ensure every parameter that requires a grad has a _mp_param of optim_dtype precision
    for param in model.parameters():
        assert param.dtype == param_dtype
        if param.requires_grad:
            assert param._mp_param.dtype == optim_dtype
            assert (
                param._mp_param.tolist() == param.data.tolist()
            ), "mp param has to match param in optim dtype precision"
        else:
            # frozen params (e.g. the base weights under LoRA) get no master copy
            assert not hasattr(param, "_mp_param")
    # test three train steps
    for _ in range(3):
        optimizer.zero_grad()
        # micro-batching: accumulate grads over two batches before stepping
        for _ in range(2):
            batch = next(data_loader)

            x = torch.from_numpy(batch.x).cuda(non_blocking=True)
            y = torch.from_numpy(batch.y).cuda(non_blocking=True)
            # y_mask is optional; when absent the loss is computed over all tokens
            y_mask = (
                torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
                if batch.y_mask is not None
                else None
            )

            output = model(
                input_ids=x,
                seqlens=batch.sizes,
            )
            mb_loss = compute_loss_with_mask(output, y, y_mask)
            mb_loss.backward()

        # swap params (and grads) to fp32 before the optimizer step
        upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)

        # ensure all params are upcasted correctly and mp param equals param
        param_sum = 0
        for param in model.parameters():
            if param.requires_grad:
                assert param.dtype == optim_dtype, param.dtype
                assert (
                    param._mp_param.tolist() == param.data.tolist()
                ), "mp param and param should point to the same data"
                assert param.grad.dtype == optim_dtype
                # NOTE(review): `_temp` looks like a stash of the bf16 copy kept
                # around during the upcast — confirm against upcast_mixed_precision
                assert param._temp.dtype == param_dtype
                # snapshot to verify the optimizer actually changes the weights
                param_sum += param.data.float().abs().sum()
            else:
                assert param.dtype == param_dtype

        optimizer.step()

        # ensure that after optimizer step params are still in optim dtype precision
        new_param_sum = 0
        for param in model.parameters():
            if param.requires_grad:
                assert param.dtype == optim_dtype
                assert param._mp_param.dtype == optim_dtype
                assert param.grad.dtype == optim_dtype
                new_param_sum += param.data.float().abs().sum()
            else:
                assert param.dtype == param_dtype

        assert new_param_sum != param_sum, "Make sure parameters are updated"

        # return params to bf16 for the next forward pass
        downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)

        # ensure that before new forward pass params are downcast to param dtype
        for param in model.parameters():
            assert param.dtype == param_dtype
            if param.requires_grad:
                # the fp32 master copy must survive the downcast
                assert param._mp_param.dtype == optim_dtype
                assert param.grad.dtype == param_dtype
                assert (
                    param._mp_param.to(param_dtype).tolist() == param.data.tolist()
                ), "mp param has to match param in optim dtype precision"