Description
When the LLaMA 3.2-1B model is converted to an fp16 .pte using the -d fp16 parameter, why does the weight data type of the Linear layer show up as FP32 at runtime?
Convert command:
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint "/model_convert/Llama-3.2-1B/original/consolidated_00.pth" --params "/Llama-3.2-1B/original/params.json" --use_sdpa_with_kv_cache -X --xnnpack-extended-ops --output_name "llama3_2_fp16_direct_convert_runtime.pte" -kv -d fp16 --max_seq_length 256
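For reference, one way to sanity-check the dtypes stored in the source checkpoint before export is a short torch.load script. This is only a minimal sketch, assuming the checkpoint path from the command above and that the consolidated checkpoint is a plain state dict of tensors:

import torch

# Sanity check (sketch): inspect the dtypes stored in the original consolidated
# checkpoint before it goes through export_llama with -d fp16.
ckpt_path = "/model_convert/Llama-3.2-1B/original/consolidated_00.pth"  # path taken from the command above
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

# Print the dtype of each weight tensor so it can be compared with what
# XNNPACK reports for the Linear (fully-connected) nodes at runtime.
for name, tensor in state_dict.items():
    if name.endswith(".weight"):
        print(name, tuple(tensor.shape), tensor.dtype)

This only confirms what the checkpoint itself contains; the -d fp16 flag is what requests the fp16 cast during export.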
Runtime Linear weight dtype log:
@@@@@ kernel_value->datatype FP32, input_value->datatype FP16, output_value->datatype FP16
We print the Linear weight dtype in executorch/backends/xnnpack/third-party/XNNPACK/src/subgraph/fully-connected.c:1039:
enum xnn_status xnn_define_fully_connected(
    xnn_subgraph_t subgraph,
    float output_min, float output_max,
    uint32_t input_id,
    uint32_t filter_id, uint32_t bias_id,
    uint32_t output_id, uint32_t flags)
{
  ...
  /* Print the dtypes of the weight (kernel), input, and output values of this fully-connected node. */
  printf("@@@@@ kernel_value->datatype %s, input_value->datatype %s, output_value->datatype %s\n",
         xnn_datatype_to_string(kernel_value->datatype),
         xnn_datatype_to_string(input_value->datatype),
         xnn_datatype_to_string(output_value->datatype));
cc @digantdesai @mcr229 @cbilgin @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng