feat: Add DataClass Arguments to Activate Padding-Free and MultiPack Plugin and FastKernels #280

Merged
modifications to dataclasses to support fast kernels on full finetuning
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
achew010 committed Sep 20, 2024
commit 6d10f72357be4ae8d726b0b8e85b9b9e21f87696
@@ -92,7 +92,7 @@ class AccelerationFrameworkConfig:
     fast_kernels: Annotated[
         FastKernelsConfig,
         ConfigAnnotation(
-            path="peft.quantization",
+            path="training",
             key="fused_ops_and_kernels",
             experimental=True,
             required_packages=["foak"],
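The only change in this hunk is the ConfigAnnotation path: `fast_kernels` now lives under a training-level stanza instead of the quantized-PEFT one, which is what lets the fused-ops-and-kernels plugin be activated for full finetuning. As a rough illustration of how a dotted `path` plus `key` addresses a nested config stanza, here is a minimal sketch with a hypothetical `nest_config` helper (not code from this repository):

from typing import Any, Dict

def nest_config(path: str, key: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    # Illustrative helper (not part of fms-hf-tuning): place `payload` under the
    # dotted `path`, then under `key`.
    root: Dict[str, Any] = {}
    node = root
    for part in path.split("."):
        node = node.setdefault(part, {})
    node[key] = payload
    return root

# Before this commit the stanza sat under a quantized-PEFT path:
#   {"peft": {"quantization": {"fused_ops_and_kernels": {...}}}}
# After this commit it sits under a training-level path:
#   {"training": {"fused_ops_and_kernels": {...}}}
print(nest_config("training", "fused_ops_and_kernels", {"fast_loss": True}))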
@@ -127,6 +127,15 @@ def _verify_configured_dataclasses(self):
                 self.padding_free is not None
             ), "`--multipack` is currently only supported with `--padding_free`"
 
+        # Check that fused lora must be activated with either auto_gptq or bitsandbytes
+        if self.fused_lora is not None:
+            assert (
+                self.bitsandbytes is not None or self.auto_gptq is not None
+            ), "`--fused_lora` must be accompanied by a quantized base layer"\
+                " `--auto_gptq` or `--bitsandbytes`."
+
+
+
     @staticmethod
     def from_dataclasses(*dataclasses: Type):
         "Convert one or many FMS config dataclasses to a monolithic AccelerationConfig"
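The new check above ties `--fused_lora` to a quantized base layer (`--auto_gptq` or `--bitsandbytes`). A self-contained sketch of the same cross-field validation pattern, using simplified stand-in dataclasses rather than the real acceleration config classes:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ConfigSketch:
    # Stand-ins for the real plugin dataclasses; any non-None value means "enabled".
    fused_lora: Optional[Any] = None
    auto_gptq: Optional[Any] = None
    bitsandbytes: Optional[Any] = None

    def verify(self) -> None:
        # Mirrors the assertion added above: fused LoRA needs a quantized base layer.
        if self.fused_lora is not None:
            assert self.auto_gptq is not None or self.bitsandbytes is not None, (
                "`--fused_lora` must be accompanied by a quantized base layer "
                "`--auto_gptq` or `--bitsandbytes`."
            )

ConfigSketch(fused_lora={"base_layer": "auto_gptq"}, auto_gptq={}).verify()  # passes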
tuning/config/acceleration_configs/fused_ops_and_kernels.py (17 changes: 1 addition, 16 deletions)
@@ -54,18 +54,11 @@ class FastKernelsConfig(List):
     fast_loss: bool = False
 
     # fast rms norm triton kernels
-    fast_rsm_layernorm: bool = False
+    fast_rms_layernorm: bool = False
 
     # fast RoPE embedding triton kernels
     fast_rope_embeddings: bool = False
 
-    def __post_init__(self):
-
-        if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings:
-            raise ValueError(
-                "fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled "
-                "together. This restriction may be relaxed in the future."
-            )
 
 
 @dataclass
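Two things change above: the `fast_rsm_layernorm` typo becomes `fast_rms_layernorm`, and the `__post_init__` that forced all three kernel flags to be enabled together is dropped, so the flags can now be set independently. A usage sketch, assuming the dataclass accepts ordinary keyword construction:

from tuning.config.acceleration_configs.fused_ops_and_kernels import FastKernelsConfig

# Previously this mixed combination raised a ValueError; now each flag stands alone.
kernels = FastKernelsConfig(
    fast_loss=True,
    fast_rms_layernorm=True,
    fast_rope_embeddings=False,
)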
@@ -77,14 +70,6 @@ class FusedOpsAndKernelsConfig:
     # fast kernels
     fast_kernels: FastKernelsConfig = None
 
     def __post_init__(self):
-        if (self.fused_lora is not None and self.fast_kernels is None) or (
-            self.fused_lora is None and self.fast_kernels is not None
-        ):
-            raise ValueError(
-                "fused lora and fast_kernels must be used together. "
-                "This restriction may be relaxed in the future."
-            )
-
         # ensure nested dataclasses initialized
         ensure_nested_dataclasses_initialized(self)
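Similarly, dropping the pairing check above means `fast_kernels` no longer has to be accompanied by `fused_lora`; only the nested-dataclass initialization is kept in `__post_init__`. A sketch of the now-permitted configuration, again assuming plain keyword construction of these dataclasses:

from tuning.config.acceleration_configs.fused_ops_and_kernels import (
    FastKernelsConfig,
    FusedOpsAndKernelsConfig,
)

# fast_kernels without fused_lora was previously rejected; it is now allowed,
# e.g. to enable fused ops and kernels on a full-finetuning run.
config = FusedOpsAndKernelsConfig(
    fast_kernels=FastKernelsConfig(
        fast_loss=True,
        fast_rms_layernorm=True,
        fast_rope_embeddings=True,
    ),
)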