
Conversation

dsikka (Collaborator) commented Nov 10, 2025

Summary

  • scale_dtype is no longer overwritten with the observed_dtype when it is not already set. This way we do not touch the quantization_args once initialized, and a None in the scheme is taken to mean the default weight / observed dtype. As a result, scale_dtype shows up as None in the config.json whenever the scheme does not set it.
  • If not running asymmetric quantization, zp_dtype is serialized as null (see the sketch below).
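
A minimal sketch (not the actual compressed-tensors implementation) of the intended serialization behavior, assuming a pydantic-style model; the QuantizationArgsSketch class and its field defaults are hypothetical:

from typing import Optional

import torch
from pydantic import BaseModel, ConfigDict, field_serializer


class QuantizationArgsSketch(BaseModel):
    """Hypothetical stand-in for QuantizationArgs; field names mirror config.json."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    num_bits: int = 8
    symmetric: bool = True
    scale_dtype: Optional[torch.dtype] = None  # stays None unless the scheme sets it
    zp_dtype: Optional[torch.dtype] = None

    @field_serializer("scale_dtype")
    def _serialize_scale_dtype(self, value: Optional[torch.dtype]):
        # scale_dtype is no longer back-filled from the observed dtype, so None stays None.
        return str(value) if value is not None else None

    @field_serializer("zp_dtype")
    def _serialize_zp_dtype(self, value: Optional[torch.dtype]):
        # Only asymmetric quantization carries a zero point; otherwise emit null.
        if self.symmetric or value is None:
            return None
        return str(value)


# str(torch.int8) == "torch.int8", so serialized dtypes keep the "torch." prefix.
print(QuantizationArgsSketch(symmetric=False, zp_dtype=torch.int8).model_dump_json(indent=2))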

Testing:

Asym:

 "weights":  {
    "actorder": null,
    "block_structure": null,
    "dynamic": false,
    "group_size": 128,
    "num_bits": 4,
    "observer": "minmax",
    "observer_kwargs": {},
    "scale_dtype": null,
    "strategy": "group",
    "symmetric": false,
    "type": "int",
    "zp_dtype": "torch.int8"
 }

NVFP4:

 "input_activations": {
    "actorder": null,
    "block_structure": null,
    "dynamic": "local",
    "group_size": 16,
    "num_bits": 4,
    "observer": "static_minmax",
    "observer_kwargs": {},
    "scale_dtype": "torch.float8_e4m3fn",
    "strategy": "tensor_group",
    "symmetric": true,
    "type": "float",
    "zp_dtype": null
  },
  "output_activations": null,
  "targets": [
    "Linear"
  ],
  "weights": {
    "actorder": null,
    "block_structure": null,
    "dynamic": false,
    "group_size": 16,
    "num_bits": 4,
    "observer": "static_minmax",
    "observer_kwargs": {},
    "scale_dtype": "torch.float8_e4m3fn",
    "strategy": "tensor_group",
    "symmetric": true,
    "type": "float",
    "zp_dtype": null
 }

KV Cache:

"kv_cache_scheme": {
    "actorder": null,
    "block_structure": null,
    "dynamic": false,
    "group_size": null,
    "num_bits": 8,
    "observer": "minmax",
    "observer_kwargs": {},
    "scale_dtype": null,
    "strategy": "tensor",
    "symmetric": true,
    "type": "float",
    "zp_dtype": null
 },

dsikka marked this pull request as ready for review November 10, 2025 17:40
brian-dellabetta (Collaborator) commented:

We want the torch. prefix in these fields?

"scale_dtype": "torch.float8_e4m3fn",

dsikka (Collaborator, Author) commented Nov 10, 2025

We want the torch. prefix in these fields?

"scale_dtype": "torch.float8_e4m3fn",

We can go either way. Original config.json includes torch in its dtype: https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/blob/main/config.json#L22

brian-dellabetta (Collaborator) commented Nov 10, 2025

We can go either way. Original config.json includes torch in its dtype: https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/blob/main/config.json#L22

This looks like it does not include torch. prefix in dtype?

dsikka (Collaborator, Author) commented Nov 10, 2025

We can go either way. Original config.json includes torch in its dtype: https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/blob/main/config.json#L22

This looks like it does not include torch. prefix in dtype?

I mean they include torch_dtype in the name to make it obvious the dtype is from torch.

kylesayrs (Collaborator) left a comment:

There are still many places in the code base where scale_dtype is read but not checked for nullability, right?

https://github.com/search?q=repo%3Avllm-project%2Fcompressed-tensors%20scale_dtype&type=code

dsikka (Collaborator, Author) commented Nov 10, 2025

https://github.com/search?q=repo%3Avllm-project%2Fcompressed-tensors%20scale_dtype&type=code

No. Did you look through the link you sent?
Most of those instances are test cases where scale_dtype is being set; in the other instances, we already check that it is not None.

I can take a look if there's anything specific, but sending a generic lookup isn't helpful here.

dsikka requested a review from kylesayrs November 10, 2025 18:56
brian-dellabetta (Collaborator) commented:

I mean they include torch_dtype in the name to make it obvious the dtype is from torch.

Oh, I thought that's the deprecated field name from transformers that causes all the warnings like

torch_dtype is deprecated! Use dtype instead!

Either way, I think it's fine. They'll be loaded up the same anyway.
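
For illustration, a hedged sketch (not necessarily how compressed-tensors resolves dtypes) showing that both spellings round-trip to the same torch dtype; parse_dtype is a hypothetical helper:

import torch

def parse_dtype(name: str) -> torch.dtype:
    # Accept both "torch.float8_e4m3fn" and "float8_e4m3fn".
    attr = getattr(torch, name.removeprefix("torch."))
    if not isinstance(attr, torch.dtype):
        raise ValueError(f"{name} is not a torch dtype")
    return attr

assert parse_dtype("torch.float8_e4m3fn") is parse_dtype("float8_e4m3fn")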

kylesayrs (Collaborator) commented:

@dsikka I misread one of the checks. Here we assert that quantization_args.scale_dtype is not None, but it's fine to skip.

The other issue I'd like to call out is that maintaining the if scale_dtype not in [torch.bfloat16, ...] check means that when the observed weight is not full precision, there will be a mismatch between the dtype initialized (float16) and the one generated by the observers (the dtype of the observed weight). Fine to skip for now.
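
A hypothetical illustration of that edge case (the dtype tuple and helper name below are assumptions, not the actual check in the codebase):

import torch

# Assumed stand-in for the "if scale_dtype not in [torch.bfloat16, ...]" check.
FULL_PRECISION_DTYPES = (torch.bfloat16, torch.float16, torch.float32)

def initialized_scale_dtype(observed_dtype: torch.dtype) -> torch.dtype:
    # Non-full-precision observed weights fall back to float16 at initialization,
    # while observers would later compute scales in the observed weight's dtype.
    if observed_dtype not in FULL_PRECISION_DTYPES:
        return torch.float16
    return observed_dtype

print(initialized_scale_dtype(torch.float8_e4m3fn))  # torch.float16, vs. float8 from the observer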

dsikka (Collaborator, Author) commented Nov 10, 2025

@dsikka I misread one of the checks. Here we assert that quantization_args.scale_dtype is not None, but it's fine to skip.

The other issue I'd like to call out is that maintaining the if scale_dtype not in [torch.bfloat16, ...] check means that when the observed weight is not full precision, there will be a mismatch between the dtype initialized (float16) and the one generated by the observers (the dtype of the observed weight). Fine to skip for now.

If the weight isn't one of the dense dtypes, we likely have a bug, as all weights remain dense until compression.

kylesayrs (Collaborator) commented Nov 10, 2025

@dsikka Yeah, base models with non-full-precision weights are definitely an edge case we can ignore for now. But syntactically, we should remove this case altogether if we're not going to support it. Something for the future.

dsikka merged commit ba35114 into main Nov 10, 2025
3 checks passed
dsikka deleted the fix_serialization branch November 10, 2025 22:26