Skip to content

Commit f55652e

Browse files
authored
Fix deepspeed for single GPU (#187)
1 parent 6f67efb commit f55652e

File tree

4 files changed

+33
-21
lines changed

4 files changed

+33
-21
lines changed

poetry.lock

Lines changed: 21 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sparse_autoencoder/train/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def train_autoencoder(
241241
] = torch.zeros(
242242
(self.n_components, self.n_learned_features),
243243
dtype=torch.int64,
244-
device=autoencoder_device,
244+
device=torch.device("cpu"),
245245
)
246246

247247
for store_batch in activations_dataloader:
@@ -274,7 +274,7 @@ def train_autoencoder(
274274
# Store count of how many neurons have fired
275275
with torch.no_grad():
276276
fired = learned_activations > 0
277-
learned_activations_fired_count.add_(fired.sum(dim=0))
277+
learned_activations_fired_count.add_(fired.sum(dim=0).cpu())
278278

279279
# Backwards pass
280280
total_loss.backward()

sparse_autoencoder/train/sweep.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def setup_autoencoder_optimizer_scheduler(
128128
model=model,
129129
optimizer=optim,
130130
lr_scheduler=lr_scheduler, # type: ignore
131+
config={
132+
"train_batch_size": hyperparameters["pipeline"]["train_batch_size"],
133+
},
131134
)
132135

133136
return (model_engine, optimizer_engine, scheduler) # type: ignore

sparse_autoencoder/train/utils/get_model_device.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Get the device that the model is on."""
2+
from deepspeed import DeepSpeedEngine
23
import torch
34
from torch.nn import Module
5+
from torch.nn.parallel import DataParallel
46

57

6-
def get_model_device(model: Module) -> torch.device:
8+
def get_model_device(model: Module | DataParallel | DeepSpeedEngine) -> torch.device:
79
"""Get the device on which a PyTorch model is on.
810
911
Args:
@@ -15,6 +17,10 @@ def get_model_device(model: Module) -> torch.device:
1517
Raises:
1618
ValueError: If the model has no parameters.
1719
"""
20+
# Deepspeed models already have a device property, so just return that
21+
if hasattr(model, "device"):
22+
return model.device
23+
1824
# Check if the model has parameters
1925
if len(list(model.parameters())) == 0:
2026
exception_message = "The model has no parameters."

0 commit comments

Comments
 (0)