|
26 | 26 | from llama.llama_transformer import InputManager, ModelArgs, Transformer
|
27 | 27 |
|
28 | 28 |
|
class SplitLinearModule(torch.nn.Module):
    """Drop-in replacement for a wide ``torch.nn.Linear``.

    Splits the output dimension into several smaller Linear layers of
    roughly ``target_size`` features each; concatenating their outputs
    along the last dim reproduces the original ``out_features`` width.
    """

    def __init__(self, in_features, out_features, target_size):
        super().__init__()
        # At least one chunk, each about target_size wide.
        self.num_splits = max(out_features // target_size, 1)
        self.common_size = out_features // self.num_splits
        self.remainder = out_features % self.num_splits
        chunks = [
            torch.nn.Linear(in_features, self.common_size)
            for _ in range(self.num_splits)
        ]
        # Leftover output features go into one extra, narrower Linear.
        if self.remainder > 0:
            chunks.append(torch.nn.Linear(in_features, self.remainder))
        self.splits = torch.nn.ModuleList(chunks)

    def split_sizes(self):
        """Return the output width of each constituent Linear, in order."""
        return [chunk.out_features for chunk in self.splits]

    def forward(self, x):
        # Concatenating along the last dim restores the full out_features.
        outputs = [chunk(x) for chunk in self.splits]
        return torch.cat(outputs, dim=-1)
| 49 | + |
| 50 | + |
def replace_linear_with_split_linear(model, target_size):
    """Recursively replace every ``torch.nn.Linear`` inside ``model``
    with a ``SplitLinearModule`` carrying the same weights (and bias,
    when present).

    Mutates ``model`` in place; non-Linear children are descended into.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, torch.nn.Linear):
            # Not a Linear itself — recurse into its submodules.
            replace_linear_with_split_linear(child, target_size)
            continue
        replacement = SplitLinearModule(
            child.in_features, child.out_features, target_size
        )
        sizes = replacement.split_sizes()
        # Carve the original weight (and bias) into per-chunk pieces
        # along the output dimension.
        weight_pieces = child.weight.split(sizes, dim=0)
        bias_pieces = child.bias.split(sizes) if child.bias is not None else None
        for idx, chunk in enumerate(replacement.splits):
            chunk.weight = torch.nn.Parameter(weight_pieces[idx])
            if bias_pieces is None:
                chunk.bias = None
            else:
                chunk.bias = torch.nn.Parameter(bias_pieces[idx])
        setattr(model, child_name, replacement)
| 70 | + |
| 71 | + |
| 72 | + |
29 | 73 | def main() -> None:
|
30 | 74 | parser = argparse.ArgumentParser()
|
31 | 75 | parser.add_argument(
|
@@ -80,6 +124,12 @@ def main() -> None:
|
80 | 124 | action="store_true",
|
81 | 125 | help="Use cache list to speed up model computation (does not work in pybindings)",
|
82 | 126 | )
|
| 127 | + parser.add_argument( |
| 128 | + "--target_size", |
| 129 | + type=int, |
| 130 | + default=None, |
| 131 | + help="Split linear layers into smaller chunks of target_size", |
| 132 | + ) |
83 | 133 |
|
84 | 134 | export_args = parser.parse_args()
|
85 | 135 | params_path = export_args.params
|
@@ -129,6 +179,9 @@ def main() -> None:
|
129 | 179 | packed=(bitwidth in [2, 4]),
|
130 | 180 | ).quantized_model()
|
131 | 181 |
|
| 182 | + if export_args.target_size is not None: |
| 183 | + replace_linear_with_split_linear(model, export_args.target_size) |
| 184 | + |
132 | 185 | model = model.to(float_dtype)
|
133 | 186 |
|
134 | 187 | op_linear_quantizer_config = None
|
@@ -184,6 +237,9 @@ def main() -> None:
|
184 | 237 | print("Edge program")
|
185 | 238 | print(edge_manager.exported_program())
|
186 | 239 |
|
| 240 | + for node in edge_manager.exported_program().graph_module.graph.nodes: |
| 241 | + print(node.name, node.target, node.args, node.kwargs) |
| 242 | + |
187 | 243 | edge_manager = edge_manager.to_backend(partitioner)
|
188 | 244 |
|
189 | 245 | print("Delegated program")
|
|
0 commit comments