 from fla.modules import RMSNorm
 from fla.modules.l2warp import l2_warp

-try:
-    from torch.distributed.tensor import DTensor
-except (ImportError, AttributeError):
-    DTensor = None
-
 if TYPE_CHECKING:
     from transformers.processing_utils import Unpack
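The deleted `DTensor` import was only used to guard the parameter copies below; the rewritten hunk instead skips initialization outright when the module's parameters live on the meta device. A minimal standalone sketch of that guard pattern (the `nn.Linear` modules and the helper name are illustrative, not part of the repository; it assumes PyTorch ≥ 2.0 for the device context manager):

```python
import torch
import torch.nn as nn

def can_init_in_place(module: nn.Module) -> bool:
    # Meta-device parameters are shape-only placeholders with no storage,
    # so in-place initialization (copy_, uniform_, ...) must be skipped.
    return next(module.parameters()).device.type != "meta"

with torch.device("meta"):       # parameters get shapes but no storage
    lazy = nn.Linear(8, 8)
eager = nn.Linear(8, 8)          # ordinary CPU parameters

print(can_init_in_place(lazy))   # False -> leave for later materialization
print(can_init_in_place(eager))  # True  -> safe to initialize now
```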
@@ -134,38 +129,17 @@ def _init_weights(
         prenorm_residual_strategy: Optional[str] = None,
         num_residuals_per_layer: int = 2,
     ):
-        if isinstance(module, GatedDeltaNet):
-
-            # --- A_log ---
-            A = torch.empty(module.num_v_heads, dtype=torch.float32).uniform_(0, 16)
+        if isinstance(module, GatedDeltaNet) and next(module.parameters()).device.type != 'meta':
             with torch.no_grad():
-                if not isinstance(module.A_log, DTensor):
-                    module.A_log.copy_(torch.log(A))
-                else:
-                    logger.warning_once("`A_log` is a DTensor, skipping initialization")
-            module.A_log._no_weight_decay = True
-
-            # --- dt_bias ---
-            # hard coded for now
-            dt_min = 0.001
-            dt_max = 0.1
-            dt_init_floor = 1e-4
-            dt = torch.exp(
-                torch.rand(module.num_v_heads) * (math.log(dt_max) - math.log(dt_min))
-                + math.log(dt_min)
-            )
-            dt = torch.clamp(dt, min=dt_init_floor)
-            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            with torch.no_grad():
-                if not isinstance(module.dt_bias, DTensor):
-                    module.dt_bias.copy_(inv_dt)
-                else:
-                    logger.warning_once("`dt_bias` is a DTensor, skipping initialization")
-            # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
-            # name.endswith("bias") in param_grouping.py
-            module.dt_bias._no_weight_decay = True
-            module.dt_bias._no_reinit = True
+                module.A_log.copy_(nn.init.uniform_(module.A_log, a=0, b=16).log())
+                module.A_log._no_weight_decay = True
+                dt = torch.exp(
+                    nn.init.uniform_(module.f_proj[1].bias) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
+                ).clamp(min=1e-4)
+                # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+                module.dt_bias.copy_(inv_dt)
+                module.dt_bias._no_weight_decay = True

         elif isinstance(module, (nn.Linear, nn.Conv1d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
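The `inv_dt` line relies on the softplus inverse identity: for x = dt + log(-expm1(-dt)) we get softplus(x) = dt, so whatever applies `softplus` to `dt_bias` downstream recovers the log-uniformly sampled time step. A small self-contained check of that round trip (the head count and dt bounds below are illustrative placeholders, not values read from the model config):

```python
import math
import torch
import torch.nn.functional as F

num_v_heads = 8                              # placeholder head count
dt_min, dt_max, dt_floor = 0.001, 0.1, 1e-4

# Sample dt log-uniformly in [dt_min, dt_max] and clamp, mirroring the diff.
dt = torch.exp(
    torch.rand(num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_floor)

# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))

# Round trip: softplus(inv_dt) reproduces the sampled dt.
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)
```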