-
Notifications
You must be signed in to change notification settings - Fork 48
/
acceleration_framework_config.py
318 lines (265 loc) · 11.4 KB
/
acceleration_framework_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Annotated, Dict, List, Type
import warnings

# Third Party
import yaml

# Local
from .attention_and_distributed_packing import MultiPack, PaddingFree
from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig
from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig
from tuning.utils.import_utils import is_fms_accelerate_available
if is_fms_accelerate_available():
# Third Party
from fms_acceleration import AccelerationFramework # pylint: disable=import-error
from fms_acceleration.framework import KEY_PLUGINS # pylint: disable=import-error
# these are optional annotations that describe different behavior
@dataclass
class ConfigAnnotation:
    """Optional annotation describing how a plugin config dataclass is placed
    into the AccelerationFramework configuration tree.
    """

    # AccelerationFramework configuration path
    path: str

    # if omitted, will take the field name
    key: str = None

    # only one config that has standalone=True may exist under its path
    # - this is used to indicate conflicting configurations
    # - we do not allow two configurations that load the model to be
    #   activated at the same time
    standalone: bool = False

    # set to true to throw a user warning
    experimental: bool = False

    # set to indicate what acceleration packages are needed
    required_packages: List[str] = field(default_factory=list)

    def __post_init__(self):
        # normalize an explicitly-passed None to an empty list so that
        # required_packages can always be iterated safely
        if self.required_packages is None:
            self.required_packages = []
@dataclass
class AccelerationFrameworkConfig:
    "Dataclass that manages configuration of AccelerationFramework"

    PACKAGE_PREFIX = "fms_acceleration_"

    # each field will be a single-level use case dataclass
    auto_gptq: Annotated[
        AutoGPTQLoraConfig,
        ConfigAnnotation(
            path="peft.quantization", standalone=True, required_packages=["peft"]
        ),
    ] = None

    bitsandbytes: Annotated[
        BNBQLoraConfig,
        ConfigAnnotation(
            path="peft.quantization", standalone=True, required_packages=["peft"]
        ),
    ] = None

    fused_lora: Annotated[
        FusedLoraConfig,
        ConfigAnnotation(
            path="peft.quantization",
            key="fused_ops_and_kernels",
            experimental=False,
            required_packages=["foak"],
        ),
    ] = None

    fast_kernels: Annotated[
        FastKernelsConfig,
        ConfigAnnotation(
            path="training",
            key="fused_ops_and_kernels",
            experimental=False,
            required_packages=["foak"],
        ),
    ] = None

    padding_free: Annotated[
        PaddingFree,
        ConfigAnnotation(
            path="training.attention",
            experimental=False,
            required_packages=["aadp"],
        ),
    ] = None

    multipack: Annotated[
        MultiPack,
        ConfigAnnotation(
            path="training.dataloader",
            experimental=False,
            required_packages=["aadp"],
        ),
    ] = None

    def _verify_configured_dataclasses(self):
        """Cross-field validation of the configured plugin dataclasses.

        Raises:
            ValueError: if an unsupported combination of plugins is set.
        """
        if self.multipack is not None:
            # ensure if multipack is set, padding free is also turned on as well
            # this also ensures that the attention implementation for multipack
            # will be flash attention as sfttrainer will enforce flash attn to be
            # set for padding free
            if self.padding_free is None:
                raise ValueError(
                    "`--multipack` is currently only supported with `--padding_free`"
                )

        # Check that fused lora must be activated with either auto_gptq or bitsandbytes
        if self.fused_lora is not None:
            if self.bitsandbytes is None and self.auto_gptq is None:
                raise ValueError(
                    "`--fused_lora` must be accompanied by a quantized base layer"
                    " `--auto_gptq` or `--bitsandbytes`."
                )

    @staticmethod
    def from_dataclasses(*dataclasses: Type):
        """Convert one or many FMS config dataclasses to a monolithic
        AccelerationFrameworkConfig.

        Args:
            *dataclasses: nested dataclasses whose (non-None) fields are the
                single-level plugin config dataclasses to install.

        Returns:
            AccelerationFrameworkConfig: the populated, verified config.

        Raises:
            ValueError: if no dataclass is given, a field is not a dataclass,
                or a nested dataclass does not match any config field.
        """
        # Assumption: AccelerationFrameworkConfig only has fields that are
        # single level dataclasses
        # Assumption: dataclasses is a list of nested dataclasses
        # - each dc in dataclasses is a nested dataclass.
        # - each dc.field in dc is a non-nested dataclass.
        if len(dataclasses) == 0:
            raise ValueError(
                "AccelerationFrameworkConfig construction requires at least one dataclass."
            )

        # first unroll all the dataclasses into a single level
        nested_dataclasses = []
        for dc in dataclasses:
            if dc is None:
                continue

            # make sure that every field is a dataclass
            for fi in fields(dc):
                attr = getattr(dc, fi.name)
                if attr is None:
                    continue  # skip the None attributes

                if not is_dataclass(attr):
                    raise ValueError(
                        f"field '{fi.name}' is specified but not a dataclass"
                    )

                # NOTE: should we also check that these are non-nested
                # dataclasses?
                nested_dataclasses.append(attr)

        config = AccelerationFrameworkConfig()
        rem_fields = {fi.name: fi for fi in fields(config)}  # these need to be parsed

        # process the dataclasses that were nested
        # by assumption these are non-nested dataclasses
        for dc in nested_dataclasses:

            # match against the first yet-unpopulated field whose annotated
            # type is an instance match; fi.type is Annotated[...], so the
            # underlying dataclass type is fi.type.__origin__
            matched = next(
                (
                    fi
                    for fi in rem_fields.values()
                    if isinstance(dc, fi.type.__origin__)
                ),
                None,
            )
            if matched is None:
                raise ValueError(
                    f"dataclass '{dc}' cannot be placed into AccelerationFrameworkConfig."
                )

            # assign the dataclass
            setattr(config, matched.name, dc)
            del rem_fields[matched.name]  # remove the field

        # perform some checks on the configured dataclasses
        config._verify_configured_dataclasses()
        return config

    def get_framework(self):
        """Instantiate an AccelerationFramework from this configuration.

        Returns:
            AccelerationFramework when fms_acceleration is installed and at
            least one plugin was configured; None when the configuration is
            empty (and implicitly None when acceleration is unavailable and
            nothing was requested).

        Raises:
            ValueError: if acceleration was requested but the framework
                package is missing, or plugin configuration fails for a
                non-empty configuration.
        """
        if is_fms_accelerate_available():

            # to be eventually be made to be passed as a dict to Acceleration
            # Framework
            # Standard
            from tempfile import (  # pylint: disable=import-outside-toplevel
                NamedTemporaryFile,
            )

            try:
                # NOTE(review): NamedTemporaryFile("w") cannot be reopened by
                # name on Windows — this assumes a POSIX platform; confirm.
                with NamedTemporaryFile("w") as f:
                    self.to_yaml(f.name)
                    return AccelerationFramework(f.name)
            except ValueError as e:
                # AccelerationFramework raises ValueError if it
                # fails to configure any plugin; read the message defensively
                # in case args is empty or has extra elements
                msg = e.args[0] if e.args else ""
                if self.is_empty() and msg.startswith("No plugins could be configured"):
                    # in the case when the error was thrown when
                    # the acceleration framework config was empty
                    # then this is expected.
                    return None

                raise  # preserve the original traceback
        else:
            if not self.is_empty():
                raise ValueError(
                    "No acceleration framework package found. To use, first "
                    "ensure that 'pip install fms-hf-tuning[fms-accel]' is done to "
                    "obtain the acceleration framework dependency. Additional "
                    "acceleration plugins may be required depending on the requested "
                    "acceleration. See README.md for instructions."
                )

    def is_empty(self):
        "check if the configuration is empty"
        for fi in fields(self):
            if getattr(self, fi.name) is not None:
                return False
        return True

    def to_dict(self):
        """convert a valid AccelerationFrameworkConfig dataclass into a schema-less dictionary
        as dictated by the header annotations.

        Raises:
            ValueError: if two standalone configs share a path, or a required
                acceleration package is not installed.
        """

        # populate a dictionary
        configuration_contents = {}

        # helper function to merge d into configuration_contents at the
        # nested location given by path, creating branches as needed
        def _descend_and_set(path: List[str], d: Dict):
            r = configuration_contents
            for p in path[:-1]:
                if p not in r:
                    r[p] = {}  # new branch
                r = r[p]

            p = path[-1]
            r[p] = {**r.get(p, {}), **d}  # merge dict if exists

        # parse each field
        already_set = set()
        for fi in fields(self):
            datacls = getattr(self, fi.name)
            if datacls is not None:
                # this is the documented way to get annotations
                # https://docs.python.org/3/library/typing.html#typing.Annotated
                annotate: ConfigAnnotation
                (annotate,) = fi.type.__metadata__
                prefix_path = tuple(annotate.path.split("."))
                if annotate.standalone and prefix_path in already_set:
                    raise ValueError(
                        f"Configuration path '{'.'.join(prefix_path)}' "
                        "already has one standalone config."
                    )

                if annotate.experimental:
                    warnings.warn(
                        "An experimental acceleration feature is requested by specifying the "
                        f"'--{fi.name}' argument. Please note this feature may not support certain "
                        "edge cases at this juncture. When the feature matures this "
                        "message will be turned off."
                    )

                if not all(
                    is_fms_accelerate_available(x) for x in annotate.required_packages
                ):
                    raise ValueError(
                        "An acceleration feature is requested by specifying the "
                        f"'--{fi.name}' argument, but this requires acceleration packages "
                        "to be installed. Please do:\n"
                        + "\n".join(
                            [
                                "- python -m fms_acceleration.cli install "
                                f"{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}"
                                for x in annotate.required_packages
                            ]
                        )
                    )

                key = annotate.key if annotate.key is not None else fi.name
                path = prefix_path + (key,)
                already_set.add(prefix_path)
                _descend_and_set(path, asdict(datacls))

        return configuration_contents

    def to_yaml(self, filename: str):
        "convert a valid AccelerationConfig dataclass into a yaml"
        configuration_contents = self.to_dict()
        with open(filename, "w", encoding="utf-8") as f:
            yaml.dump({KEY_PLUGINS: configuration_contents}, f)