@rahul-tuli rahul-tuli commented Mar 28, 2025

In the latest release, the removal of unexpected_keys filtering for compressed tensors models reintroduced warnings that were previously resolved in #36152. This PR addresses that regression, enhances the user experience for the run_compressed flag, and updates the test folder naming to avoid conflicts and align with conventions.

Changes and Objectives

This pull request accomplishes three key improvements:

  1. Restores Filtering of Unexpected Keys for Compressed Tensors Models

    • The removal of unexpected_keys filtering caused warnings to reappear when loading compressed tensors models. This PR reintroduces the necessary logic by adding `unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)` in modeling_utils.py, ensuring unexpected keys are properly managed during model loading, eliminating the warnings, and restoring the behavior from #36152 (a sketch of this hook appears after this list).
  2. Enhances User Experience for run_compressed Misconfiguration

    • Previously, setting run_compressed=True in unsupported cases (e.g., sparsified models or non-compressed quantized models) triggered a ValueError and halted execution. This PR improves this by:
      • Adding checks in quantizer_compressed_tensors.py to identify unsupported scenarios (is_sparsification_compressed or is_quantized without is_quantization_compressed).
      • Issuing a logger.warn message instead of raising an error, notifying users that run_compressed is unsupported for the given model type.
      • Automatically setting run_compressed=False in these cases, allowing the process to proceed gracefully.
    • This change enhances usability by replacing hard failures with warnings and safe fallbacks.
  3. Renames Test Folder to Avoid Name Collisions

    • The test folder tests/quantization/compressed_tensors has been renamed to tests/quantization/compressed_tensors_integration. This prevents potential name collisions when running pytest, ensuring smoother test execution. The new name also aligns with the naming conventions of other integration tests in the repository, improving consistency.
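
For context, here is a minimal sketch of how the restored filtering fits together. The `update_unexpected_keys` call is the one named in the PR description; the quantizer-side class, helper name, and key suffixes below are illustrative assumptions, not the actual code in quantizer_compressed_tensors.py.

```python
# Sketch only: the loading-side call mirrors the line restored in
# modeling_utils.py; the quantizer-side hook is a hypothetical illustration
# of its intent.
def collect_unexpected_keys(model, state_dict, prefix, hf_quantizer=None):
    unexpected_keys = [k for k in state_dict if k not in model.state_dict()]
    if hf_quantizer is not None:
        # Restored in this PR: let the quantizer drop keys that belong to the
        # compressed checkpoint format so they are not reported as warnings.
        unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)
    return unexpected_keys


class CompressedTensorsQuantizerSketch:
    # Hypothetical suffixes, used only for illustration.
    _COMPRESSION_SUFFIXES = ("weight_packed", "weight_shape", "scale", "zero_point")

    def update_unexpected_keys(self, model, unexpected_keys, prefix):
        # Keep only keys that are genuinely unexpected, i.e. not part of the
        # compressed-tensors serialization format.
        return [k for k in unexpected_keys if not k.endswith(self._COMPRESSION_SUFFIXES)]
```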

Impact

  • Fixes Regression: Eliminates unexpected keys warnings for compressed tensors models, ensuring a seamless loading experience.
  • Better UX: Replaces abrupt failures with warnings and automatic corrections for run_compressed, making the library more robust and user-friendly.
  • Improved Testing: Avoids pytest name collisions caused by the old test folder name and keeps test naming consistent with other integration tests.

Files Modified

  • src/transformers/modeling_utils.py: Added update_unexpected_keys call to restore filtering.
  • src/transformers/quantizers/quantizer_compressed_tensors.py: Updated run_compressed logic with warnings and overrides.
  • tests/quantization/compressed_tensors/*: Renamed folder to compressed_tensors_integration (including __init__.py, test_compressed_models.py, and test_compressed_tensors.py).

Local Testing

This test verifies that compressed and uncompressed models can be loaded using AutoModelForCausalLM with various run_compressed settings. It also surfaces any warnings, decompression events, or fallbacks.

Loading Script

```python
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig
import traceback

# List of model stubs to test
model_stubs = [
    "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
    "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
    "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
    "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
    "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
    "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
]

print("\n=== Starting Model Load Tests ===\n")

for stub in model_stubs:
    print("=" * 40)
    print(f"Testing model stub: {stub}")

    # Infer model style (check 'uncompressed' before 'compressed')
    style = "uncompressed" if "uncompressed" in stub else "compressed"

    for run_compressed in [True, False]:
        print(f"\n→ Attempting load with run_compressed={run_compressed} (model style: {style})")

        try:
            model = AutoModelForCausalLM.from_pretrained(
                stub,
                torch_dtype="auto",
                device_map="auto",
                quantization_config=CompressedTensorsConfig(run_compressed=run_compressed),
            )
            print(f"✓ Successfully loaded ({style}, run_compressed={run_compressed})")
        except Exception:
            print(f"✗ Failed to load ({style}, run_compressed={run_compressed})")
            print("Traceback:")
            traceback.print_exc()

    print(f"\nFinished testing: {stub}")
    print("=" * 40 + "\n")

print("=== ✅ All Tests Completed ===")
```

Test Output

```
=== Starting Model Load Tests ===

========================================
Testing model stub: nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed

→ Attempting load with run_compressed=True (model style: compressed)
2025-04-02 02:31:01.841947: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-02 02:31:01.882386: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-02 02:31:01.882424: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-02 02:31:01.883778: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-02 02:31:01.890147: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-04-02 02:31:02.655687: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/rahul/upstream-transformers/src/transformers/quantizers/auto.py:212: UserWarning: You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading already has a `quantization_config` attribute. The `quantization_config` from the model will be used.However, loading attributes (e.g. ['run_compressed']) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored.
  warnings.warn(warning_msg)
✓ Successfully loaded (compressed, run_compressed=True)

→ Attempting load with run_compressed=False (model style: compressed)
Decompressing model: 56it [00:00, 996.06it/s]
✓ Successfully loaded (compressed, run_compressed=False)

Finished testing: nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed
========================================

========================================
Testing model stub: nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed

→ Attempting load with run_compressed=True (model style: uncompressed)
`run_compressed` is only supported for compressed models.Setting `run_compressed=False`
✓ Successfully loaded (uncompressed, run_compressed=True)

→ Attempting load with run_compressed=False (model style: uncompressed)
✓ Successfully loaded (uncompressed, run_compressed=False)

Finished testing: nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed
========================================

========================================
Testing model stub: nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed

→ Attempting load with run_compressed=True (model style: compressed)
`run_compressed` is only supported for quantized_compressed models and not for sparsified models. Setting `run_compressed=False`
Decompressing model: 75it [00:00, 631.24it/s]
✓ Successfully loaded (compressed, run_compressed=True)

→ Attempting load with run_compressed=False (model style: compressed)
Decompressing model: 75it [00:00, 618.48it/s]
✓ Successfully loaded (compressed, run_compressed=False)

Finished testing: nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed
========================================

========================================
Testing model stub: nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed

→ Attempting load with run_compressed=True (model style: uncompressed)
✓ Successfully loaded (uncompressed, run_compressed=True)

→ Attempting load with run_compressed=False (model style: uncompressed)
✓ Successfully loaded (uncompressed, run_compressed=False)

Finished testing: nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed
========================================

========================================
Testing model stub: nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed

→ Attempting load with run_compressed=True (model style: compressed)
`run_compressed` is only supported for quantized_compressed models and not for sparsified models. Setting `run_compressed=False`
Decompressing model: 131it [00:00, 278.01it/s]
Decompressing model: 56it [00:00, 5845.43it/s]
✓ Successfully loaded (compressed, run_compressed=True)

→ Attempting load with run_compressed=False (model style: compressed)
Decompressing model: 131it [00:00, 878.16it/s]
Decompressing model: 56it [00:00, 12393.47it/s]
✓ Successfully loaded (compressed, run_compressed=False)

Finished testing: nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed
========================================

========================================
Testing model stub: nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed

→ Attempting load with run_compressed=True (model style: uncompressed)
`run_compressed` is only supported for compressed models.Setting `run_compressed=False`
✓ Successfully loaded (uncompressed, run_compressed=True)

→ Attempting load with run_compressed=False (model style: uncompressed)
✓ Successfully loaded (uncompressed, run_compressed=False)

Finished testing: nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed
========================================

=== ✅ All Tests Completed ===
```

Cases when run_compressed=True is not supported and overridden to False

  • Model is uncompressed
  • Model is sparse (both sparse and sparse quantized)
| Model ID | Run Compressed | Overridden |
|---|---|---|
| nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed | True | None |
| nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed | False | None |
| nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed | True | run_compressed set to False |
| nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed | False | None |
| nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed | True | run_compressed set to False |
| nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed | False | None |
| nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed | True | run_compressed set to False |
| nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed | False | None |
| nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed | True | run_compressed set to False |
| nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed | False | None |
| nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed | True | run_compressed set to False |
| nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed | False | None |

@brian-dellabetta brian-dellabetta left a comment

looks like the other changes in #36152 are all still on main, so this looks good to me!

@rahul-tuli rahul-tuli changed the title from "Restore Unexpected Keys Filtering and Improve run_compressed Handling for Compressed Tensors Models" to "Fix: Unexpected Keys, Improve run_compressed, Rename Test Folder" on Mar 28, 2025
@rahul-tuli rahul-tuli marked this pull request as ready for review March 28, 2025 15:51
@SunMarc SunMarc left a comment

Thanks, just a nit

Comment on lines 121 to 128

```python
            logger.warn(
                "`run_compressed` is only supported for quantized_compressed models"
                " and not for sparsified models. Setting `run_compressed=False`"
            )
            self.run_compressed = False
        elif self.is_quantized and not self.is_quantization_compressed:
            logger.warn("`run_compressed` is only supported for compressed models.Setting `run_compressed=False`")
            self.run_compressed = False
```

Can you do these checks in the config post_init method? It will be better, I think.
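
For context, relocating the checks into the config could look roughly like the sketch below; it assumes the flags from the snippet above (`is_sparsification_compressed`, `is_quantized`, `is_quantization_compressed`) are available on `CompressedTensorsConfig`, and it is not the merged implementation.

```python
import logging

from transformers.utils.quantization_config import QuantizationConfigMixin

logger = logging.getLogger(__name__)


class CompressedTensorsConfig(QuantizationConfigMixin):
    # Sketch: the same run_compressed checks as above, moved from the
    # quantizer into the config's post_init method.
    def post_init(self):
        if self.run_compressed and self.is_sparsification_compressed:
            logger.warn(
                "`run_compressed` is only supported for quantized_compressed models"
                " and not for sparsified models. Setting `run_compressed=False`"
            )
            self.run_compressed = False
        elif self.run_compressed and self.is_quantized and not self.is_quantization_compressed:
            logger.warn("`run_compressed` is only supported for compressed models. Setting `run_compressed=False`")
            self.run_compressed = False
```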

@rahul-tuli rahul-tuli (Contributor, Author) replied:

Done!

SunMarc commented Apr 1, 2025

friendly ping @rahul-tuli

@rahul-tuli rahul-tuli force-pushed the fix-ct-model-loading branch from 0ee5778 to 8dd325f Compare April 2, 2025 02:35
@rahul-tuli rahul-tuli (Contributor, Author) replied:

> friendly ping @rahul-tuli

Apologies for the delayed response — I was investigating a weird warning that showed up after moving some logic to post_init.

The issue is that when we instantiate the quantization_config like this:

```python
model = AutoModelForCausalLM.from_pretrained(
    stub,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=CompressedTensorsConfig(run_compressed=run_compressed),
)
```

…the post_init() checks run too early — before the config has been merged with the values from config.json. This causes problems since other fields haven’t been fully resolved yet.

The fix was to move the post_init() call into the Quantizer's __init__, so it runs after config merging is complete.

The diff should be good to go now!
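
For illustration, here is a rough sketch of that ordering fix; only `run_compressed`, `post_init`, and the idea of calling it from the quantizer's `__init__` come from the discussion above, and the class and import paths are assumptions rather than the exact merged code.

```python
from transformers.quantizers.base import HfQuantizer


class CompressedTensorsHfQuantizer(HfQuantizer):
    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        # By the time the quantizer is constructed, the user-supplied
        # CompressedTensorsConfig has been merged with the values from the
        # checkpoint's config.json, so the post_init() checks see fully
        # resolved fields instead of partially populated defaults.
        self.quantization_config.post_init()
```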

@dsikka dsikka left a comment

We should verify serialization and deserialization with the condition change
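
For reference, a round-trip check along those lines might look like the sketch below; it reuses a stub from the loading script above and is only an illustration, not the PR's actual test.

```python
import tempfile

from transformers import AutoModelForCausalLM

# Stub reused from the loading script above.
stub = "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed"

model = AutoModelForCausalLM.from_pretrained(stub, torch_dtype="auto")

with tempfile.TemporaryDirectory() as tmp_dir:
    # Serialize: the quantization_config (including run_compressed) should be
    # written back to config.json without being altered by the new checks.
    model.save_pretrained(tmp_dir)

    # Deserialize: reloading should not re-trigger the run_compressed warning
    # or change the stored quantization settings.
    reloaded = AutoModelForCausalLM.from_pretrained(tmp_dir, torch_dtype="auto")

    original_qc = model.config.to_dict().get("quantization_config")
    reloaded_qc = reloaded.config.to_dict().get("quantization_config")
    print("quantization_config preserved:", original_qc == reloaded_qc)
```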

@rahul-tuli rahul-tuli requested a review from dsikka April 2, 2025 22:45
@dsikka dsikka left a comment

LGTM given we've tested:

  1. Compressed
  2. Quantized, not saved compressed
  3. Sparse-only
  4. Sparse + Quantized

@SunMarc SunMarc left a comment

LGTM ! Just a nit

@SunMarc SunMarc merged commit ebe47ce into huggingface:main Apr 4, 2025
18 checks passed
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Apr 5, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025