|
53 | 53 | get_quant_embedding_transform,
|
54 | 54 | get_quant_weight_transform,
|
55 | 55 | )
|
56 |
| -from .source_transformation.quantized_kv_cache import ( |
57 |
| - replace_kv_cache_with_quantized_kv_cache, |
58 |
| -) |
| 56 | + |
| 57 | +# from .source_transformation.quantized_kv_cache import ( |
| 58 | +# replace_kv_cache_with_quantized_kv_cache, |
| 59 | +# ) |
59 | 60 | from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
|
60 | 61 |
|
61 | 62 | from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
|
62 |
| -from .source_transformation.sdpa import ( |
63 |
| - replace_causal_mask, |
64 |
| - replace_kv_cache_with_coreml_kv_cache, |
65 |
| - replace_kv_cache_with_simple_kv_cache, |
66 |
| - replace_sdpa_with_coreml_sdpa, |
67 |
| - replace_sdpa_with_custom_op, |
68 |
| - replace_sdpa_with_flex_sdpa, |
69 |
| - replace_sdpa_with_simple_sdpa, |
70 |
| -) |
| 63 | + |
| 64 | +# from .source_transformation.sdpa import ( |
| 65 | +# replace_causal_mask, |
| 66 | +# replace_kv_cache_with_coreml_kv_cache, |
| 67 | +# replace_kv_cache_with_simple_kv_cache, |
| 68 | +# replace_sdpa_with_coreml_sdpa, |
| 69 | +# replace_sdpa_with_custom_op, |
| 70 | +# replace_sdpa_with_flex_sdpa, |
| 71 | +# replace_sdpa_with_simple_sdpa, |
| 72 | +# ) |
71 | 73 |
|
72 | 74 | IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False)
|
73 | 75 | FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
|
@@ -910,23 +912,20 @@ def _get_source_transforms( # noqa
|
910 | 912 | assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
|
911 | 913 | transforms.append(replace_kv_cache_with_quantized_kv_cache)
|
912 | 914 |
|
| 915 | + if args.qnn: |
| 916 | + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` |
| 917 | + from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d |
| 918 | + |
| 919 | + # transforms.append(replace_kv_cache_with_simple_kv_cache) |
| 920 | + # transforms.append(replace_sdpa_with_flex_sdpa) |
| 921 | + # transforms.append(replace_causal_mask) |
| 922 | + transforms.append(replace_rms_norm_with_native_rms_norm) |
| 923 | + if args.optimized_rotation_path: |
| 924 | + transforms.append(fuse_layer_norms) |
| 925 | + transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) |
| 926 | + transforms.append(convert_linear_to_conv2d) |
913 | 927 | if args.use_kv_cache:
|
914 |
| - if args.qnn: |
915 |
| - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` |
916 |
| - from executorch.backends.qualcomm.utils.utils import ( |
917 |
| - convert_linear_to_conv2d, |
918 |
| - ) |
919 |
| - |
920 |
| - transforms.append(replace_kv_cache_with_simple_kv_cache) |
921 |
| - transforms.append(replace_sdpa_with_flex_sdpa) |
922 |
| - transforms.append(replace_causal_mask) |
923 |
| - transforms.append(replace_rms_norm_with_native_rms_norm) |
924 |
| - if args.optimized_rotation_path: |
925 |
| - transforms.append(fuse_layer_norms) |
926 |
| - transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) |
927 |
| - transforms.append(convert_linear_to_conv2d) |
928 |
| - |
929 |
| - elif args.mps: |
| 928 | + if args.mps: |
930 | 929 | # Currently mps doesn't support sdpa op, use the simpler decomposition
|
931 | 930 | # to get free perf gain.
|
932 | 931 | transforms.append(replace_sdpa_with_simple_sdpa)
|
|
0 commit comments