37 commits
06e7d2b  phase 1 and 2 (avtc, Dec 1, 2025)
d0915cc  phase 3 (avtc, Dec 2, 2025)
7e59bb3  fix for the case when :moe module does not have expert modules on cur… (avtc, Dec 2, 2025)
15c74bf  log activations collected, raise StopForward in named_module forward … (avtc, Dec 2, 2025)
5eff9af  optimization to prevent shared experts double activation (avtc, Dec 2, 2025)
6b7f190  fix circular dependency (avtc, Dec 2, 2025)
791192d  debug hooks not used (avtc, Dec 2, 2025)
c9e2c96  integrate to _run_forward_batches_parallel (avtc, Dec 2, 2025)
9b3b783  remove redundant overrides (avtc, Dec 2, 2025)
b17bd95  debug subset content (avtc, Dec 2, 2025)
48c00a7  cache subset_modules, refactor and fix confusion with names/modules i… (avtc, Dec 2, 2025)
a7032dc  debug (avtc, Dec 2, 2025)
8cbf894  refining forward_to_all_experts (avtc, Dec 2, 2025)
6d05f70  refactor get_experts_module_name (avtc, Dec 2, 2025)
77a8bfd  fix import (avtc, Dec 2, 2025)
00ce12d  optimization for _extract_moe_block_prefix, fix in case subset can ha… (avtc, Dec 2, 2025)
ce2a4a8  deduplicate moe_forward_wrapper (avtc, Dec 2, 2025)
10d8861  update _masked_pre_hook_wrapper to recent variant. revisit hooks_paus… (avtc, Dec 2, 2025)
169fa63  fix signature (avtc, Dec 2, 2025)
a76af2a  fix param missing (avtc, Dec 3, 2025)
d4bc837  remove comments and unused variables (avtc, Dec 3, 2025)
21b0d5e  fix indentation, and move after torch_sync (avtc, Dec 3, 2025)
1520b87  trying to fix for data-parallel (avtc, Dec 3, 2025)
c6697c3  fix indentation (avtc, Dec 3, 2025)
f6f4e1a  config option wait_for_layer_completion (avtc, Dec 3, 2025)
d175149  clean moe_contexts in finally (avtc, Dec 3, 2025)
6a1f90d  enable moe_lifecycle_hooks for qwen3_moe (avtc, Dec 3, 2025)
61cf213  trying to fix VRAM leak (avtc, Dec 4, 2025)
19e6b32  fix typo (avtc, Dec 4, 2025)
745944b  update model definitions to support pass_whole_dataset_to_each_expert… (avtc, Dec 4, 2025)
c1a6707  add warnings (avtc, Dec 4, 2025)
fcd174f  remove redundant log (avtc, Dec 4, 2025)
e27abcb  clear cache after forward pass (avtc, Dec 4, 2025)
38a5f78  handle replicated modules, ability to save vram on device 0, to defra… (avtc, Dec 6, 2025)
6a5405c  missing save change (avtc, Dec 6, 2025)
81b409c  skip iterating over all experts during replay forward (avtc, Dec 7, 2025)
27914f0  turn off auto-gc when vram_opt_memory_cleanup_on_stage_end is set (avtc, Dec 7, 2025)
2 changes: 1 addition & 1 deletion gptqmodel/looper/loop_processor.py
@@ -97,7 +97,7 @@ def __init__(

        self.inputs_cache: InputCache = InputCache(None, None, None, None)
        self.tasks = {}

        self.pb = None
        self.fwd_time = None
        self.layer_count = None
417 changes: 318 additions & 99 deletions gptqmodel/looper/module_looper.py

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions gptqmodel/looper/named_module.py
@@ -34,6 +34,10 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde
         # persistent work state for named module (used by some LoopProcessors)
         # store all `processed()` work state/data/result here
         self.state = {}
+
+        # Forward hook mechanism (compatible with HookedLinear)
+        self.forward_hook = None
+        self.forward_hook_last = False
 
         # print(f"NamedModule init: name: `{name}, full-name: `{full_name}`")
 
@@ -128,6 +132,18 @@ def __getattr__(self, name: str):
 
     # setattr is always called by python even if attr exists in `self`
     def __setattr__(self, name: str, value: Any) -> None:
+        # Proxy forward_hook to inner module if it supports it (e.g. HookedLinear)
+        if name in ["forward_hook", "forward_hook_last"]:
+            try:
+                module = object.__getattribute__(self, "module")
+                if hasattr(module, name):
+                    setattr(module, name, value)
+            except AttributeError:
+                pass  # module not set yet during __init__
+            # Also set on self for consistency
+            object.__setattr__(self, name, value)
+            return
+
         if name in [
             "module",
             "module_dtype",
@@ -137,6 +153,7 @@ def __setattr__(self, name: str, value: Any) -> None:
             "state",
             "_parent_lock",
             "target_device",
+            "target_device_stream",
             "register_buffer",
             "unregister_buffer",
             "register_parameter",
@@ -155,6 +172,26 @@ def __setattr__(self, name: str, value: Any) -> None:
         else:
             with lock:
                 setattr(module, name, value)
+
+    def forward(self, *args, **kwargs):
+        """Forward pass with optional hook support (compatible with HookedLinear)."""
+        output = self.module(*args, **kwargs)
+
+        # Call forward_hook if it exists and wasn't proxied to inner module
+        if self.forward_hook:
+            # Check if inner module has forward_hook (meaning we proxied it)
+            if not hasattr(self.module, 'forward_hook') or self.module.forward_hook is None:
+                # Extract first positional arg as input for hook
+                input_tensor = args[0] if args else None
+                self.forward_hook(self, (input_tensor,), output)
+
+        # If forward_hook_last is True, this should stop execution (like HookedLinear)
+        # The hook may raise StopForward, which should propagate
+        if self.forward_hook_last:
+            from ..nn_modules.hooked_linear import StopForward  # Local import to avoid circular dependency
+            raise StopForward()
+
+        return output
 
     def stream_state_payload_to_cpu(
        self,
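For context, the hook contract added to NamedModule mirrors HookedLinear: a single forward_hook(module, (input,), output) callback plus a forward_hook_last flag that raises StopForward so the calibration forward stops once the last hooked module has captured its activations. A minimal, self-contained sketch of that mechanism; TinyNamedModule, capture_hook and the local StopForward class are illustrative stand-ins, not repository code:

```python
import torch

class StopForward(Exception):
    """Raised after the last hooked module has fired, to cut the forward pass short."""

class TinyNamedModule(torch.nn.Module):
    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.module = inner
        self.forward_hook = None        # called as hook(module, (input,), output)
        self.forward_hook_last = False  # when True, raise StopForward after the hook

    def forward(self, x):
        out = self.module(x)
        if self.forward_hook is not None:
            self.forward_hook(self, (x,), out)
        if self.forward_hook_last:
            raise StopForward()
        return out

captured = {}

def capture_hook(mod, inputs, output):
    # Record the input activation that a quantization processor would consume.
    captured["x"] = inputs[0].detach()

layer = TinyNamedModule(torch.nn.Linear(8, 8))
layer.forward_hook = capture_hook
layer.forward_hook_last = True  # last module in the subset: stop right after capture

try:
    layer(torch.randn(2, 8))
except StopForward:
    pass  # expected: the rest of this batch's computation is skipped

print(captured["x"].shape)  # torch.Size([2, 8])
```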
33 changes: 24 additions & 9 deletions gptqmodel/looper/stage_layer.py
@@ -25,7 +25,7 @@
 from ..utils.logger import log_time_block, setup_logger
 from ..utils.model import find_modules, get_module
 from ..utils.offload import offload_to_disk
-from ..utils.torch import CPU, torch_sync
+from ..utils.torch import CPU, torch_empty_cache, torch_sync
 from .stage_subset import SubsetForwardContext, run_subset_stage
 
 if TYPE_CHECKING: # pragma: no cover - type hints only
@@ -243,6 +243,9 @@ def run_layer_stage(
     # )
 
     try:
+        # Reset current subset for MoE lifecycle hooks as we do not need to collect activation at this stage,
+        # and need to collect only outputs produced by original forward
+        looper._current_subset = None
         layer_outputs = looper._run_forward_batches(
             module=module,
             processor=processor,
@@ -501,18 +504,30 @@ def _drain_finalize_futures(
         )
 
     if finalize_futures_snapshot:
-        # Drain finalize futures asynchronously so the main loop can continue scheduling work.
-        threading.Thread(
-            target=_drain_finalize_futures,
-            args=(
+        if looper.gptq_model.quantize_config.vram_opt_memory_cleanup_on_stage_end:
+            # Synchronous: wait for all finalization to complete before proceeding to next layer
+            # This ensures all packing and writing tasks are done
+            _drain_finalize_futures(
                 [future for future, *_ in finalize_futures_snapshot],
                 finalize_pb,
                 finalize_count,
                 layer_index,
-            ),
-            name="SubmoduleFinalizeWatcher",
-            daemon=True,
-        ).start()
+            )
+            torch_empty_cache()
+        else:
+            # Asynchronous (current/default behavior): drain in background thread
+            # This allows next layer to start while current layer finalizes
+            threading.Thread(
+                target=_drain_finalize_futures,
+                args=(
+                    [future for future, *_ in finalize_futures_snapshot],
+                    finalize_pb,
+                    finalize_count,
+                    layer_index,
+                ),
+                name="SubmoduleFinalizeWatcher",
+                daemon=True,
+            ).start()
     else:
         looper._emit_layer_complete(
             layer_idx=layer_index,
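A rough sketch of the finalize-drain split introduced above: with the cleanup option set, finalize futures are drained synchronously so the CUDA cache can be emptied before the next layer starts; otherwise they are drained in a background daemon thread as before. The cleanup_on_stage_end flag and drain helper below are hypothetical stand-ins for quantize_config.vram_opt_memory_cleanup_on_stage_end and _drain_finalize_futures:

```python
import threading
from concurrent.futures import ThreadPoolExecutor, wait

# Hypothetical stand-in for quantize_config.vram_opt_memory_cleanup_on_stage_end.
cleanup_on_stage_end = True

def drain(futures):
    # Block until every finalize task (packing, saving) has completed.
    wait(futures)

pool = ThreadPoolExecutor(max_workers=2)
futures = [pool.submit(pow, i, 2) for i in range(4)]

if cleanup_on_stage_end:
    # Synchronous path: the next layer waits for finalization, then cached VRAM
    # can be released (the real code calls torch_empty_cache() here).
    drain(futures)
else:
    # Default path: drain in a daemon thread so the next layer can start immediately.
    threading.Thread(target=drain, args=(futures,),
                     name="SubmoduleFinalizeWatcher", daemon=True).start()
```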
51 changes: 45 additions & 6 deletions gptqmodel/looper/stage_subset.py
@@ -22,7 +22,7 @@
 from ..quantization.config import VRAMStrategy
 from ..utils.device import get_device
 from ..utils.logger import setup_logger
-from ..utils.torch import torch_sync
+from ..utils.torch import torch_empty_cache, torch_sync
 
 if TYPE_CHECKING: # pragma: no cover - typing only
     from .module_looper import ModuleLooper
@@ -195,7 +195,7 @@ def emit_subset_event(stage: str) -> None:
     assignable_group_keys: List[str] = []
     for group_key, module_names in expert_groups.items():
         suffixes = {name.rsplit(".", 1)[-1] for name in module_names}
-        if {"gate_proj", "up_proj"}.issubset(suffixes):
+        if {"gate_proj", "up_proj"}.issubset(suffixes) or {"w1", "w3"}.issubset(suffixes):
             assignable_group_keys.append(group_key)
 
     if assignable_group_keys:
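As a small illustration of the suffix check above: an expert group counts as assignable when its member names end in the LLaMA-style gate_proj/up_proj pair or, with this change, the w1/w3 pair used by Mixtral-style experts. The module names below are invented for the example:

```python
# Hypothetical expert group, keyed by its shared prefix.
expert_groups = {
    "mlp.experts.0": [
        "mlp.experts.0.w1",
        "mlp.experts.0.w2",
        "mlp.experts.0.w3",
    ],
}

assignable = []
for group_key, module_names in expert_groups.items():
    suffixes = {name.rsplit(".", 1)[-1] for name in module_names}
    # Accept either gate_proj/up_proj (LLaMA-style) or w1/w3 (Mixtral-style) naming.
    if {"gate_proj", "up_proj"}.issubset(suffixes) or {"w1", "w3"}.issubset(suffixes):
        assignable.append(group_key)

print(assignable)  # ['mlp.experts.0']
```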
@@ -219,6 +219,8 @@ def emit_subset_event(stage: str) -> None:
         for named_module in subset.values():
             setattr(named_module, "moe_enabled", False)
 
+    subset_forward_serial = subset_forward_serial or looper.gptq_model.quantize_config.force_subset_forward_serial
+
     handle = []
 
     # some processes are simple and not require forward captures
@@ -238,6 +240,20 @@ def emit_subset_event(stage: str) -> None:
             forward_row_counts.extend([1] * (batch_count - len(forward_row_counts)))
 
     subset_size = len(subset)
+
+    # Determine MoE block name for hook selection
+    moe_block_name = None
+    if looper.gptq_model and hasattr(looper.gptq_model, 'moe_lifecycle_hooks'):
+        hooks = looper.gptq_model.moe_lifecycle_hooks
+        if hooks is not None:
+            moe_block = hooks.get_moe_block(module, looper.gptq_model.__class__)
+            if moe_block is not None:
+                # Get the full name/path of the MoE block
+                for mod_name, mod in module.named_modules():
+                    if mod is moe_block:
+                        moe_block_name = mod_name
+                        break
+
     for idx, (name, m) in enumerate(subset.items()):
         # Register the forward hook that captures activations for quantization.
         # The final module optionally flips a flag so processors can trigger
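The added block resolves the fully qualified name of the MoE block by scanning named_modules() for an identity match. A self-contained sketch of that lookup, using a toy layer in place of the real decoder layer and of what a get_moe_block()-style hook would return:

```python
import torch

# Toy decoder layer: the attribute layout is illustrative only.
layer = torch.nn.Module()
layer.self_attn = torch.nn.Linear(8, 8)
layer.mlp = torch.nn.Module()
layer.mlp.experts = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])

moe_block = layer.mlp  # what a get_moe_block()-style hook would return

moe_block_name = None
for mod_name, mod in layer.named_modules():
    if mod is moe_block:  # identity match, not equality
        moe_block_name = mod_name
        break

print(moe_block_name)  # 'mlp'
```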
@@ -248,18 +264,31 @@ def emit_subset_event(stage: str) -> None:
         hook_source = getattr(m, "name", name)
         if hook_source is None:
             hook_source = str(name)
 
+        # Determine if this module is part of MoE block (needs pre-hook to avoid StopForward)
+        is_moe_module = moe_block_name and name.startswith(moe_block_name + ".")
+
         if hasattr(subset[name], 'forward_hook'):
             original_hook = processor.pre_process_fwd_hook(name)
-            subset[name].forward_hook = looper._masked_hook_wrapper(processor, original_hook, hook_source)
+            # Use pre-hook for MoE modules to fire before StopForward
+            if is_moe_module:
+                subset[name].forward_hook = looper._masked_pre_hook_wrapper(processor, original_hook, hook_source)
+            else:
+                subset[name].forward_hook = looper._masked_hook_wrapper(processor, original_hook, hook_source)
             enable_stop = processor.fwd_after_process or getattr(processor, "subset_forward_early_stop", False)
             if is_last and enable_stop:
                 subset[name].forward_hook_last = True
         else:
             original_hook = processor.pre_process_fwd_hook(name)
-            handle.append(subset[name].register_forward_hook(
-                looper._masked_hook_wrapper(processor, original_hook, hook_source)
-            ))
+            # Use pre-hook registration for MoE modules
+            if is_moe_module:
+                handle.append(subset[name].register_forward_hook(
+                    looper._masked_pre_hook_wrapper(processor, original_hook, hook_source)
+                ))
+            else:
+                handle.append(subset[name].register_forward_hook(
+                    looper._masked_hook_wrapper(processor, original_hook, hook_source)
+                ))
 
         if DEBUG_ON and logger.isEnabledFor(logging.DEBUG):
             if is_awq_processor:
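The MoE branch above swaps _masked_hook_wrapper for _masked_pre_hook_wrapper; those wrapper internals are not part of this diff, but the motivation matches standard PyTorch hook semantics: a forward pre-hook observes a module's input before its body runs, so activation capture still happens even if StopForward is raised before an output is produced. A minimal illustration, where HaltingLinear and this local StopForward are invented for the example:

```python
import torch

class StopForward(Exception):
    """Stand-in for gptqmodel's StopForward control-flow exception."""

captured = {}

def pre_hook(module, inputs):
    # Fires before Linear.forward executes, so the input is captured
    # even if the forward body below raises StopForward.
    captured["input_shape"] = inputs[0].shape

def post_hook(module, inputs, output):
    captured["output_shape"] = output.shape  # never reached in this run

class HaltingLinear(torch.nn.Linear):
    def forward(self, x):
        raise StopForward()  # simulate an early stop inside the module body

m = HaltingLinear(4, 4)
m.register_forward_pre_hook(pre_hook)
m.register_forward_hook(post_hook)

try:
    m(torch.randn(2, 4))
except StopForward:
    pass

print(captured)  # {'input_shape': torch.Size([2, 4])} -- only the pre-hook fired
```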
@@ -311,6 +340,8 @@ def emit_subset_event(stage: str) -> None:
     )
 
     try:
+        # Set the current subset for MoE lifecycle hooks
+        looper._current_subset = subset
         forward_outputs = looper._run_forward_batches(
             module=module,
             processor=processor,
@@ -368,6 +399,11 @@ def emit_subset_event(stage: str) -> None:
                 if hasattr(subset[name], 'forward_hook'):
                     subset[name].forward_hook = None
                     subset[name].forward_hook_last = False
+
+            if looper.gptq_model.quantize_config.vram_opt_memory_cleanup_on_stage_end:
+                torch_sync()
+                torch_empty_cache()
+
         else:
             if DEBUG_ON:
                 logger.debug(
@@ -508,6 +544,9 @@ def _process_on_worker(
         processed_subset[name] = named_module
     torch_sync()
 
+    if looper.gptq_model.quantize_config.vram_opt_memory_cleanup_on_stage_end:
+        torch_empty_cache()
+
     emit_subset_event("quant_complete")
 
     context = SubsetForwardContext(