|
| 1 | +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. |
| 2 | + |
| 3 | +# pyre-strict |
| 4 | + |
| 5 | +import argparse |
| 6 | +import json |
| 7 | + |
| 8 | +import sys |
| 9 | + |
| 10 | +import coremltools as ct |
| 11 | +import torch |
| 12 | +from executorch.backends.apple.coreml.compiler import CoreMLBackend # pyre-ignore |
| 13 | +from executorch.backends.apple.coreml.partition import CoreMLPartitioner # pyre-ignore |
| 14 | +from executorch.examples.models.llama.source_transformation.quantize import ( |
| 15 | + EmbeddingQuantHandler, |
| 16 | +) |
| 17 | + |
| 18 | +from executorch.exir.backend.utils import format_delegated_graph |
| 19 | +from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig |
| 20 | +from executorch.exir.passes import MemoryPlanningPass |
| 21 | +from executorch.exir.passes.quant_fusion_pass import QuantFusionPass |
| 22 | +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass |
| 23 | +from executorch.extension.export_util.utils import export_to_edge, save_pte_program |
| 24 | + |
| 25 | +sys.path.insert(0, ".") |
| 26 | +from llama_transformer import InputManager, ModelArgs, Transformer |
| 27 | + |
| 28 | + |
| 29 | +class SplitLinearModule(torch.nn.Module): |
| 30 | + def __init__(self, in_features, out_features, target_split_size, max_splits): |
| 31 | + super(SplitLinearModule, self).__init__() |
| 32 | + num_splits = max(out_features // target_split_size, 1) |
| 33 | + if num_splits > max_splits: |
| 34 | + num_splits = max_splits |
| 35 | + |
| 36 | + self.split_size = out_features // num_splits |
| 37 | + self.split_remainder = out_features % num_splits |
| 38 | + self.splits = torch.nn.ModuleList( |
| 39 | + [torch.nn.Linear(in_features, self.split_size) for _ in range(num_splits)] |
| 40 | + ) |
| 41 | + print( |
| 42 | + f"Splitting out_features={out_features} into {num_splits} of size {self.split_size}" |
| 43 | + ) |
| 44 | + if self.split_remainder > 0: |
| 45 | + print( |
| 46 | + f"Warning: remainder {self.split_remainder} after splitting out_features={out_features} into {num_splits} of size {self.split_size}" |
| 47 | + ) |
| 48 | + self.splits.append(torch.nn.Linear(in_features, self.split_remainder)) |
| 49 | + |
| 50 | + def split_sizes(self): |
| 51 | + return [split.out_features for split in self.splits] |
| 52 | + |
| 53 | + def forward(self, x): |
| 54 | + return torch.cat([split(x) for split in self.splits], dim=-1) |
| 55 | + |
| 56 | + |
| 57 | +def replace_linear_with_split_linear(model, target_split_size, max_splits): |
| 58 | + for name, module in model.named_children(): |
| 59 | + if isinstance(module, torch.nn.Linear): |
| 60 | + new_module = SplitLinearModule( |
| 61 | + module.in_features, module.out_features, target_split_size, max_splits |
| 62 | + ) |
| 63 | + split_sizes = new_module.split_sizes() |
| 64 | + if module.bias is not None: |
| 65 | + split_bias = module.bias.split(split_sizes) |
| 66 | + split_weights = module.weight.split(split_sizes, dim=0) |
| 67 | + for i, split in enumerate(new_module.splits): |
| 68 | + split.weight = torch.nn.Parameter(split_weights[i]) |
| 69 | + if module.bias is not None: |
| 70 | + split.bias = torch.nn.Parameter(split_bias[i]) |
| 71 | + else: |
| 72 | + split.bias = None |
| 73 | + setattr(model, name, new_module) |
| 74 | + else: |
| 75 | + replace_linear_with_split_linear(module, target_split_size, max_splits) |
| 76 | + |
| 77 | + |
| 78 | +def main() -> None: |
| 79 | + parser = argparse.ArgumentParser() |
| 80 | + parser.add_argument( |
| 81 | + "-n", |
| 82 | + "--output_name", |
| 83 | + default="model.pte", |
| 84 | + help="Override the output filename of the saved pte model file.", |
| 85 | + ) |
| 86 | + parser.add_argument( |
| 87 | + "-p", |
| 88 | + "--params", |
| 89 | + help="config.json", |
| 90 | + ) |
| 91 | + parser.add_argument( |
| 92 | + "-c", |
| 93 | + "--checkpoint", |
| 94 | + help="checkpoint path", |
| 95 | + ) |
| 96 | + parser.add_argument( |
| 97 | + "--seq_length", |
| 98 | + type=int, |
| 99 | + default=1, |
| 100 | + help="length sequence to evaluate", |
| 101 | + ) |
| 102 | + parser.add_argument( |
| 103 | + "--max_seq_length", |
| 104 | + type=int, |
| 105 | + default=128, |
| 106 | + help="maximum length sequence to evaluate", |
| 107 | + ) |
| 108 | + parser.add_argument( |
| 109 | + "--cache_size", |
| 110 | + type=int, |
| 111 | + default=None, |
| 112 | + help="Cache size. Old items are evicted from cache", |
| 113 | + ) |
| 114 | + parser.add_argument( |
| 115 | + "-E", |
| 116 | + "--embedding-quantize", |
| 117 | + default=None, |
| 118 | + type=str, |
| 119 | + help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.", |
| 120 | + ) |
| 121 | + parser.add_argument( |
| 122 | + "--coreml-quantize", |
| 123 | + default=None, |
| 124 | + choices=["b4w", "c4w"], |
| 125 | + help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)", |
| 126 | + ) |
| 127 | + parser.add_argument( |
| 128 | + "--use_cache_list", |
| 129 | + action="store_true", |
| 130 | + help="Use cache list to speed up model computation (does not work in pybindings)", |
| 131 | + ) |
| 132 | + parser.add_argument( |
| 133 | + "--target_split_size", |
| 134 | + type=int, |
| 135 | + default=None, |
| 136 | + help="Split linear layers into smaller chunks of target_split_size.", |
| 137 | + ) |
| 138 | + parser.add_argument( |
| 139 | + "--max_splits", |
| 140 | + type=int, |
| 141 | + default=8, |
| 142 | + help="Maximum number of splits to divide linear layers", |
| 143 | + ) |
| 144 | + |
| 145 | + export_args = parser.parse_args() |
| 146 | + params_path = export_args.params |
| 147 | + checkpoint_path = export_args.checkpoint |
| 148 | + |
| 149 | + # Load model args |
| 150 | + with open(params_path, "r") as f: |
| 151 | + params = json.loads(f.read()) |
| 152 | + |
| 153 | + args = ModelArgs( |
| 154 | + max_seq_len=export_args.max_seq_length, |
| 155 | + generate_full_logits=False, |
| 156 | + use_cache_list=export_args.use_cache_list, |
| 157 | + **params, |
| 158 | + ) |
| 159 | + |
| 160 | + with torch.device("meta"): |
| 161 | + model = Transformer(args) |
| 162 | + |
| 163 | + checkpoint = torch.load( |
| 164 | + checkpoint_path, map_location="cpu", mmap=True, weights_only=True |
| 165 | + ) |
| 166 | + if "model" in checkpoint: |
| 167 | + checkpoint = checkpoint["model"] |
| 168 | + |
| 169 | + missing, unexpected = model.load_state_dict( |
| 170 | + checkpoint, |
| 171 | + strict=False, |
| 172 | + assign=True, |
| 173 | + ) |
| 174 | + print("Missing keys: ", missing) |
| 175 | + print("Unexpected keys: ", unexpected) |
| 176 | + |
| 177 | + float_dtype = torch.float16 # dtype for model/inputs |
| 178 | + model.eval() |
| 179 | + model.to(float_dtype) |
| 180 | + |
| 181 | + if export_args.embedding_quantize: |
| 182 | + bitwidth, group_size = export_args.embedding_quantize.split(",") |
| 183 | + if group_size == "none" or group_size == "None" or group_size == "0": |
| 184 | + group_size = None |
| 185 | + else: |
| 186 | + group_size = int(group_size) |
| 187 | + bitwidth = int(bitwidth) |
| 188 | + model = EmbeddingQuantHandler( |
| 189 | + model, |
| 190 | + bitwidth=bitwidth, |
| 191 | + group_size=group_size, |
| 192 | + packed=(bitwidth in [2, 4]), |
| 193 | + ).quantized_model() |
| 194 | + |
| 195 | + if export_args.target_split_size is not None: |
| 196 | + replace_linear_with_split_linear( |
| 197 | + model, export_args.target_split_size, export_args.max_splits |
| 198 | + ) |
| 199 | + |
| 200 | + model = model.to(float_dtype) |
| 201 | + |
| 202 | + op_linear_quantizer_config = None |
| 203 | + if export_args.coreml_quantize == "b4w": |
| 204 | + op_linear_quantizer_config = { |
| 205 | + "mode": "linear_symmetric", |
| 206 | + "dtype": "int4", |
| 207 | + "granularity": "per_block", |
| 208 | + "block_size": 32, |
| 209 | + "weight_threshold": 512, |
| 210 | + } |
| 211 | + elif export_args.coreml_quantize == "c4w": |
| 212 | + op_linear_quantizer_config = { |
| 213 | + "mode": "linear_symmetric", |
| 214 | + "dtype": "int4", |
| 215 | + "granularity": "per_channel", |
| 216 | + } |
| 217 | + |
| 218 | + compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] |
| 219 | + minimum_deployment_target=ct.target.iOS18, |
| 220 | + compute_precision=ct.precision(ct.precision.FLOAT16.value), |
| 221 | + compute_unit=ct.ComputeUnit.CPU_AND_NE, |
| 222 | + model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16] |
| 223 | + op_linear_quantizer_config=op_linear_quantizer_config, |
| 224 | + ) |
| 225 | + partitioner = CoreMLPartitioner( # pyre-fixme[16] |
| 226 | + compile_specs=compile_specs, |
| 227 | + take_over_mutable_buffer=False, |
| 228 | + skip_ops_for_coreml_delegation=[ |
| 229 | + "quantized_decomposed.embedding_4bit.dtype", |
| 230 | + "aten.embedding.default", |
| 231 | + ], |
| 232 | + ) |
| 233 | + |
| 234 | + input_manager = InputManager( |
| 235 | + n_layers=args.n_layers, |
| 236 | + max_batch_size=args.max_batch_size, |
| 237 | + n_kv_heads=args.n_kv_heads, |
| 238 | + max_seq_length=args.max_seq_len, |
| 239 | + head_dim=args.head_dim, |
| 240 | + use_cache_list=export_args.use_cache_list, |
| 241 | + seq_length=export_args.seq_length, |
| 242 | + dtype=float_dtype, |
| 243 | + minus_infinity=-30000, |
| 244 | + cache_size=export_args.cache_size, |
| 245 | + ) |
| 246 | + example_inputs = input_manager.get_inputs(tokens=[0]) |
| 247 | + |
| 248 | + edge_manager = export_to_edge( |
| 249 | + model, |
| 250 | + example_inputs, |
| 251 | + edge_compile_config=EdgeCompileConfig( |
| 252 | + _check_ir_validity=False, |
| 253 | + _skip_type_promotion=(float_dtype == torch.float16), |
| 254 | + _skip_dim_order=True, |
| 255 | + ), |
| 256 | + ) |
| 257 | + print("Edge program") |
| 258 | + print(edge_manager.exported_program()) |
| 259 | + |
| 260 | + for node in edge_manager.exported_program().graph_module.graph.nodes: |
| 261 | + print(node.name, node.target, node.args, node.kwargs) |
| 262 | + |
| 263 | + edge_manager = edge_manager.to_backend(partitioner) |
| 264 | + |
| 265 | + print("Delegated program") |
| 266 | + |
| 267 | + print(format_delegated_graph(edge_manager.exported_program().graph_module)) |
| 268 | + |
| 269 | + executorch_program = edge_manager.to_executorch( |
| 270 | + ExecutorchBackendConfig( |
| 271 | + extract_delegate_segments=True, |
| 272 | + passes=[ |
| 273 | + QuantFusionPass(), |
| 274 | + ], |
| 275 | + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), |
| 276 | + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), |
| 277 | + ) |
| 278 | + ) |
| 279 | + |
| 280 | + filename = save_pte_program(executorch_program, export_args.output_name) |
| 281 | + print(f"Saved Executorch program to local {filename}") |
| 282 | + |
| 283 | + |
| 284 | +if __name__ == "__main__": |
| 285 | + main() # pragma: no cover |
0 commit comments