[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system#11958
Conversation
I had to change `bypass.py` to:

```python
if device is not None:
    # For quantized models (e.g., INT8), weights are int8 but computations are in float.
    # Don't cast adapter weights to int8 as they need to remain in float for proper computation.
    if dtype is not None and dtype in [torch.int8, torch.uint8]:
        dtype = None
    self._move_adapter_weights_to_device(device, dtype)
```

So it works with my INT8 node (https://github.com/BobJohnson24/ComfyUI-Flux2-INT8). I hope this can be added!
@BobJohnson24 I can add a check to ensure dtype is fp16/bf16/fp32, or pass None.
That would be amazing!
In this pull request we propose a different Weight Adapter system: bypass forward.
Idea
In LoRA and similar PEFT methods, the adapter is usually written as a weight difference that is merged into the base weight: W' = W + ΔW (e.g. ΔW = BA for LoRA).
This gives us less compute in training (modifying the weight is cheaper than running an extra full linear forward on a large hidden state) and lower latency in inference (same model + same multiplier = no need to re-modify the weight).
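The merged-weight path described above can be sketched in plain PyTorch. This is an illustrative example, not the PR's actual code; names like `up`/`down` stand in for the LoRA matrices:

```python
import torch

# Base linear weight and a rank-r LoRA pair (illustrative shapes).
d_out, d_in, r = 64, 32, 4
W = torch.randn(d_out, d_in)
up = torch.randn(d_out, r) * 0.01   # "lora_up"  (B)
down = torch.randn(r, d_in) * 0.01  # "lora_down" (A)
alpha = 1.0

# Merged-weight mode: fold the low-rank update into W once,
# then every forward is a single matmul -- no extra latency.
W_merged = W + alpha * (up @ down)

x = torch.randn(8, d_in)            # a batch of hidden states
y_merged = x @ W_merged.T

# Equivalent unmerged computation for comparison: the extra
# low-rank forward on the hidden state is what merging avoids.
y_ref = x @ W.T + alpha * (x @ down.T @ up.T)
assert torch.allclose(y_merged, y_ref, atol=1e-5)
```

Merging pays the update cost once per weight rather than once per token, which is why it is the default for full-precision models.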
But this has a limitation: it cannot be applied to a model with quantized weights.
You must either merge the model before quantization, or dequantize and then quantize again, which is not ideal: the weight adapter is not trained with QAT, so only PTQ methods work.
Therefore, here is another formulation that works well with quantized models for both inference and training:
For example, LoRA/LoHa/LoKr have h(X), the "bypass path" that computes the weight-diff forward, while OFT/BOFT/GLoRA have g() for "output modification".
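The h(X) bypass path for a LoRA-style adapter can be sketched as a wrapper module. This is a minimal sketch with hypothetical names, not the PR's actual classes; the key point is that the base (possibly quantized) weight is never touched and the adapter branch runs in float:

```python
import torch
import torch.nn as nn

class BypassLoRALinear(nn.Module):
    """Wraps a frozen (possibly quantized) linear layer; the LoRA
    branch h(X) runs in float and is added to the base output."""

    def __init__(self, base: nn.Module, d_in: int, d_out: int,
                 rank: int, scale: float = 1.0):
        super().__init__()
        self.base = base                   # frozen; may hold int8 weights
        self.down = nn.Linear(d_in, rank, bias=False)
        self.up = nn.Linear(rank, d_out, bias=False)
        nn.init.zeros_(self.up.weight)     # start as an identity adapter
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = base(x) + scale * h(x). The base weight is never modified,
        # so quantized storage stays valid and h() trains in full precision.
        return self.base(x) + self.scale * self.up(self.down(x))

base = nn.Linear(32, 64)
for p in base.parameters():
    p.requires_grad_(False)

layer = BypassLoRALinear(base, d_in=32, d_out=64, rank=4)
x = torch.randn(8, 32)

# With up initialized to zero, the bypass output equals the base output.
assert torch.allclose(layer(x), base(x))
```

An OFT-style g() would instead transform the base output (e.g. multiply it by an orthogonal matrix) rather than add a parallel branch, but the bypass principle is the same: only activations are modified, never the stored weights.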
What we did
In this PR we implement the "bypass forward" mode for both training and inference (currently only LoRA/LoKr/OFT support bypass-mode training).
I also modified the trainer node to output only the LoRA state dict and the step/loss list, as suggested by @rattus128: the returned ModelPatcher is injected/patched by the training module, not the inference module, so using it for inference is not a good idea. We therefore removed the "trainer ModelPatcher" output.