|
1 | 1 | import time |
2 | 2 | from itertools import count |
3 | | -from typing import Dict, List, Optional, Tuple, Union |
| 3 | +from typing import Dict, List, Optional, Tuple, Type, Union |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | import torch |
@@ -64,7 +64,7 @@ def __init__( |
64 | 64 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], |
65 | 65 | inference_config: InferenceConfig, |
66 | 66 | verbose: bool = False, |
67 | | - model_policy: Policy = None, |
| 67 | + model_policy: Union[Policy, Type[Policy]] = None, |
68 | 68 | ) -> None: |
69 | 69 | self.inference_config = inference_config |
70 | 70 | self.dtype = inference_config.dtype |
@@ -105,7 +105,7 @@ def __init__( |
105 | 105 |
|
106 | 106 | self._verify_args() |
107 | 107 |
|
108 | | - def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): |
| 108 | + def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None): |
109 | 109 | """ |
110 | 110 | Shard model or/and Load weight |
111 | 111 |
|
@@ -150,11 +150,17 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy |
150 | 150 | ) |
151 | 151 |
|
152 | 152 | if model_policy is None: |
153 | | - if self.inference_config.pad_input: |
154 | | - model_type = "padding_" + self.model_config.model_type |
155 | | - else: |
156 | | - model_type = "nopadding_" + self.model_config.model_type |
157 | | - model_policy = model_policy_map[model_type]() |
| 153 | + prefix = "nopadding" if not self.inference_config.pad_input else "padding" |
| 154 | + model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" |
| 155 | + model_policy = model_policy_map.get(model_policy_key) |
| 156 | + |
| 157 | + if not isinstance(model_policy, Policy): |
| 158 | + try: |
| 159 | + model_policy = model_policy() |
| 160 | + except Exception as e: |
| 161 | + raise ValueError(f"Unable to instantiate model policy: {e}") |
| 162 | + |
| 163 | + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" |
158 | 164 | pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) |
159 | 165 | tp_group = pg_mesh.get_group_along_axis(TP_AXIS) |
160 | 166 |
|
|
0 commit comments