
Commit e6a82e9

[GDN] Deal with init on meta device
1 parent 9d6eac9 commit e6a82e9

File tree

2 files changed, +20 -72 lines

2 files changed

+20
-72
lines changed

fla/models/comba/modeling_comba.py

Lines changed: 10 additions & 36 deletions
```diff
@@ -23,11 +23,6 @@
 from fla.modules import RMSNorm
 from fla.modules.l2warp import l2_warp
 
-try:
-    from torch.distributed.tensor import DTensor
-except (ImportError, AttributeError):
-    DTensor = None
-
 if TYPE_CHECKING:
     from transformers.processing_utils import Unpack
 
@@ -133,38 +128,17 @@ def _init_weights(
         prenorm_residual_strategy: Optional[str] = None,
         num_residuals_per_layer: int = 2,
     ):
-        if isinstance(module, Comba):
-
-            # --- A_log ---
-            A = torch.empty(module.num_v_heads, dtype=torch.float32).uniform_(0, 16)
+        if isinstance(module, Comba) and next(module.parameters()).device.type != 'meta':
             with torch.no_grad():
-                if not isinstance(module.A_log, DTensor):
-                    module.A_log.copy_(torch.log(A))
-                else:
-                    logger.warning_once("`A_log` is a DTensor, skipping initialization")
-            module.A_log._no_weight_decay = True
-
-            # --- dt_bias ---
-            # hard coded for now
-            dt_min = 0.001
-            dt_max = 0.1
-            dt_init_floor = 1e-4
-            dt = torch.exp(
-                torch.rand(module.num_v_heads) * (math.log(dt_max) - math.log(dt_min))
-                + math.log(dt_min)
-            )
-            dt = torch.clamp(dt, min=dt_init_floor)
-            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            with torch.no_grad():
-                if not isinstance(module.dt_bias, DTensor):
-                    module.dt_bias.copy_(inv_dt)
-                else:
-                    logger.warning_once("`dt_bias` is a DTensor, skipping initialization")
-            # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
-            # name.endswith("bias") in param_grouping.py
-            module.dt_bias._no_weight_decay = True
-            module.dt_bias._no_reinit = True
+                module.A_log.copy_(nn.init.uniform_(module.A_log, a=0, b=16).log())
+                module.A_log._no_weight_decay = True
+                dt = torch.exp(
+                    nn.init.uniform_(module.f_proj[1].bias) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
+                ).clamp(min=1e-4)
+                # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+                module.dt_bias.copy_(inv_dt)
+                module.dt_bias._no_weight_decay = True
 
         elif isinstance(module, (nn.Linear, nn.Conv1d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
```
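
The guard added in both files is the crux of the commit: when a model is instantiated on the meta device (as transformers does for deferred or low-memory loading), its parameters carry shape and dtype but no storage, so there is nothing meaningful to initialize until the weights are materialized. The commit also drops the DTensor special-casing and its guarded import, folding the `A_log` and `dt_bias` initialization into one `torch.no_grad()` block that writes in place into the module's own parameters. Below is a minimal sketch of the meta-device check (not from this commit), using a plain `nn.Linear` as a stand-in for the Comba/GatedDeltaNet modules:

```python
import torch
import torch.nn as nn

# Minimal sketch (not from this commit): parameters created on the meta
# device hold only shape/dtype metadata and no storage, so custom
# initialization is deferred until they are materialized on a real device.
with torch.device('meta'):
    layer = nn.Linear(4, 4)  # stand-in for a Comba/GatedDeltaNet block

if next(layer.parameters()).device.type != 'meta':
    with torch.no_grad():
        nn.init.uniform_(layer.weight, a=0, b=16)  # real storage: safe to fill
else:
    # Skip for now; to_empty() materializes the parameters on a real device,
    # after which an initialization pass can run.
    layer = layer.to_empty(device='cpu')
```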

fla/models/gated_deltanet/modeling_gated_deltanet.py

Lines changed: 10 additions & 36 deletions
```diff
@@ -23,11 +23,6 @@
 from fla.modules import RMSNorm
 from fla.modules.l2warp import l2_warp
 
-try:
-    from torch.distributed.tensor import DTensor
-except (ImportError, AttributeError):
-    DTensor = None
-
 if TYPE_CHECKING:
     from transformers.processing_utils import Unpack
 
@@ -134,38 +129,17 @@ def _init_weights(
         prenorm_residual_strategy: Optional[str] = None,
         num_residuals_per_layer: int = 2,
     ):
-        if isinstance(module, GatedDeltaNet):
-
-            # --- A_log ---
-            A = torch.empty(module.num_v_heads, dtype=torch.float32).uniform_(0, 16)
+        if isinstance(module, GatedDeltaNet) and next(module.parameters()).device.type != 'meta':
             with torch.no_grad():
-                if not isinstance(module.A_log, DTensor):
-                    module.A_log.copy_(torch.log(A))
-                else:
-                    logger.warning_once("`A_log` is a DTensor, skipping initialization")
-            module.A_log._no_weight_decay = True
-
-            # --- dt_bias ---
-            # hard coded for now
-            dt_min = 0.001
-            dt_max = 0.1
-            dt_init_floor = 1e-4
-            dt = torch.exp(
-                torch.rand(module.num_v_heads) * (math.log(dt_max) - math.log(dt_min))
-                + math.log(dt_min)
-            )
-            dt = torch.clamp(dt, min=dt_init_floor)
-            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            with torch.no_grad():
-                if not isinstance(module.dt_bias, DTensor):
-                    module.dt_bias.copy_(inv_dt)
-                else:
-                    logger.warning_once("`dt_bias` is a DTensor, skipping initialization")
-            # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
-            # name.endswith("bias") in param_grouping.py
-            module.dt_bias._no_weight_decay = True
-            module.dt_bias._no_reinit = True
+                module.A_log.copy_(nn.init.uniform_(module.A_log, a=0, b=16).log())
+                module.A_log._no_weight_decay = True
+                dt = torch.exp(
+                    nn.init.uniform_(module.f_proj[1].bias) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
+                ).clamp(min=1e-4)
+                # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+                module.dt_bias.copy_(inv_dt)
+                module.dt_bias._no_weight_decay = True
 
         elif isinstance(module, (nn.Linear, nn.Conv1d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
```
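
For reference, the `dt_bias` recipe kept by the new code is unchanged in spirit: `dt` is sampled log-uniformly in `[0.001, 0.1]`, floored at `1e-4`, and `dt_bias` is set to the inverse softplus of `dt`, so that `softplus(dt_bias)` reproduces `dt` at the start of training. A standalone sketch with an illustrative head count:

```python
import math
import torch
import torch.nn.functional as F

# Standalone sketch of the dt_bias recipe above; num_v_heads is illustrative.
num_v_heads = 8
dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4

# Sample dt log-uniformly in [dt_min, dt_max], then apply the floor.
dt = torch.exp(
    torch.rand(num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)

# Inverse of softplus (https://github.com/pytorch/pytorch/issues/72759),
# so that softplus(dt_bias) recovers dt when training starts.
inv_dt = dt + torch.log(-torch.expm1(-dt))

assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)
```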
