Conversation

@inho9606 inho9606 commented Sep 25, 2025

Description

This PR is a prototype to load weight_scale values from the Qwen3 checkpoint.

  • First commit
    • Reads the .safetensors file with the PT framework first, since the "flax" framework's NumPy backend does not support float8, and then converts the tensor to jnp. Because this commit modifies shared utility functions, it may cause unexpected errors in other models that use the same utilities.
  • Second commit
    • Reads the weight_scale values from the weight files and saves them in the Qwen3ForCausalLM instance under a new attribute named 'quant_scales'.
    • It might be better to keep them in an nnx.Module alongside the other layers, but that is quite complex. Since this is a prototype, I implemented the simpler approach first.
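The float8 detour described in the first commit can be sketched roughly as below. This is a minimal illustration, not the PR's actual code: `load_fp8_tensor_pt` is a hypothetical helper, and the demo uses `ml_dtypes` (a JAX dependency) to stand in for bytes that would really come from a torch float8 tensor read via `safe_open(..., framework="pt")`.

```python
import numpy as np
import jax.numpy as jnp
import ml_dtypes  # ships as a JAX dependency; provides numpy-compatible float8 dtypes


def load_fp8_tensor_pt(path: str, name: str):
    """Hypothetical helper mirroring the PR's approach: read a float8 tensor
    with the PT framework (the flax/NumPy path cannot represent float8),
    then reinterpret its raw bytes as a JAX float8 array."""
    from safetensors import safe_open  # requires safetensors + torch installed
    import torch

    with safe_open(path, framework="pt") as f:
        t = f.get_tensor(name)              # e.g. a torch.float8_e4m3fn tensor
    raw = t.view(torch.uint8).numpy()       # byte view; uint8 is numpy-safe
    return jnp.asarray(raw).view(jnp.float8_e4m3fn)


# Self-contained demo of the byte-level reinterpretation step, so no
# checkpoint file or torch install is needed:
vals = np.array([0.5, 1.0, -2.0], dtype=ml_dtypes.float8_e4m3fn)
raw = vals.view(np.uint8)                   # same bytes, numpy-friendly dtype
restored = jnp.asarray(raw).view(jnp.float8_e4m3fn)
print(restored.astype(jnp.float32))         # recovers [0.5, 1.0, -2.0]
```

The key point is that the conversion is a bitwise reinterpretation, not a numeric cast, so the float8 payload survives the trip through a dtype that NumPy handles natively.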

FIXES: b/446023123

Tests

The following command runs the model on the JAX path while loading weight_scales:

python3 examples/offline_inference.py --model=RedHatAI/Qwen3-32B-FP8-dynamic --tensor_parallel_size=8 --task=generate --max_model_len=1024 --download_dir=/mnt/disks/persist

@inho9606 inho9606 closed this Sep 25, 2025
@inho9606 inho9606 reopened this Sep 25, 2025
@inho9606 inho9606 requested review from BirdsOfAFthr and kyuyeunk and removed request for kyuyeunk September 25, 2025 05:19
Signed-off-by: inho9606 <inhoseo@google.com>
@inho9606 inho9606 force-pushed the prototyping_load_weight_scale_for_qwen3 branch from a4f5732 to ae345c3 Compare September 26, 2025 01:00
