torchft/local_sgd_integ_test.py (48 changes: 43 additions & 5 deletions)

@@ -16,7 +16,7 @@
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
-from torchft.process_group import ProcessGroupGloo, ProcessGroupNCCL
+from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo

 logger: logging.Logger = logging.getLogger(__name__)

@@ -41,7 +41,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:

     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

-    pg = ProcessGroupGloo()
+    if device.type == "cuda":
+        pg = ProcessGroupBabyNCCL()
+    else:
+        pg = ProcessGroupGloo()
     manager = Manager(
         pg=pg,
         min_replica_size=2,
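
Note: this device-conditional process-group choice is repeated in both train loops. A minimal sketch of how it could be factored out; the `make_pg` helper is hypothetical and not part of this PR, and the subprocess behavior of ProcessGroupBabyNCCL is my understanding of the torchft docs rather than something this diff states:

```python
import torch

from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo


def make_pg(device: torch.device):
    """Pick a process group implementation for the given device."""
    if device.type == "cuda":
        # NCCL only supports CUDA tensors; the "Baby" variant runs the
        # process group in a subprocess, which keeps a wedged collective
        # from hanging the trainer itself.
        return ProcessGroupBabyNCCL()
    # Gloo handles the CPU path.
    return ProcessGroupGloo()
```
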
@@ -110,7 +113,12 @@ def diloco_train_loop(
     # pyre-ignore[53]
     def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
         m.load_state_dict(state_dict["model"])
+        m.to(device)
         diloco.original_parameters = state_dict["original_params"]
+        for name in diloco.original_parameters.keys():
+            diloco.original_parameters[name] = diloco.original_parameters[name].to(
+                device
+            )
         inner_optimizer.load_state_dict(state_dict["inner_optim"])
         outer_optimizer.load_state_dict(state_dict["outer_optim"])

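Note: the added `.to(device)` loop exists because restored tensors may be materialized on CPU while the model trains on CUDA, and DiLoCo's outer step diffs the live model against these saved parameters, so both must live on the same device. A standalone sketch of the same idea; the `params_to_device` helper is hypothetical:

```python
from typing import Dict

import torch


def params_to_device(
    params: Dict[str, torch.Tensor], device: torch.device
) -> Dict[str, torch.Tensor]:
    # Checkpointed tensors typically come back on CPU; move each saved
    # parameter onto the training device before it is compared against
    # the live model parameters.
    return {name: t.to(device) for name, t in params.items()}
```
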
@@ -124,7 +132,10 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]

     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

-    pg = ProcessGroupGloo()
+    if device.type == "cuda":
+        pg = ProcessGroupBabyNCCL()
+    else:
+        pg = ProcessGroupGloo()
     manager = Manager(
         pg=pg,
         min_replica_size=2,
@@ -138,6 +149,8 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
         world_size=runner.world_size,
         lighthouse_addr=runner.lighthouse_address,
         port=19530 + runner.replica_id,
+        connect_timeout=timedelta(seconds=10),
+        quorum_timeout=timedelta(seconds=10),
         timeout=timedelta(seconds=10),
         # pyre-fixme[6]: Incompatible parameter type
         **runner.manager_args,
@@ -155,6 +168,12 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
        sync_every=2,
    ) as diloco:
        while True:
+            manager_curr_step = manager.current_step()
+            if manager_curr_step not in all_state_dicts:
+                print(
+                    f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
+                )
+                all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
             batch_size = 1
             inputs = m.get_rand_inputs(batch_size).to(device)
             labels = m.get_rand_labels(batch_size).to(device)
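
Note: the snapshot moves to the top of the loop, is keyed by the integer manager step instead of its string form, and is guarded so each step is recorded at most once across recoveries. The `copy.deepcopy` is the load-bearing part: without it every stored entry would alias the live tensors and be silently rewritten by later optimizer steps. A self-contained sketch; the `record_step` helper is hypothetical:

```python
import copy
from typing import Callable, Dict


def record_step(
    snapshots: Dict[int, Dict[str, Dict[str, object]]],
    step: int,
    state_dict: Callable[[], Dict[str, Dict[str, object]]],
) -> None:
    # Record at most one snapshot per manager step, detached from the
    # live training state so later updates cannot mutate history.
    if step not in snapshots:
        snapshots[step] = copy.deepcopy(state_dict())
```
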
@@ -164,7 +183,6 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]

             inner_optimizer.zero_grad()
             loss.backward()
-            all_state_dicts[str(manager.current_step())] = state_dict()
             inner_optimizer.step()

             # after 4 model updates then break
@@ -181,10 +199,15 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
 class LocalSGDIntegTest(TestCase):
     @parameterized.expand(
         [
+            (True,),
             (False,),
         ]
     )
     def test_local_sgd_recovery(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
         lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
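
Note: the same skip guard is applied to all three tests below; each test runs two replicas, hence the two-GPU requirement for the CUDA variant. A self-contained sketch of the pattern; the class and test names here are hypothetical:

```python
from unittest import TestCase

import torch
from parameterized import parameterized


class DeviceParamExample(TestCase):
    @parameterized.expand([(True,), (False,)])
    def test_runs_on_device(self, use_cuda: bool) -> None:
        # Two replicas per test, so the CUDA variant needs two GPUs.
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        device = torch.device("cuda" if use_cuda else "cpu")
        t = torch.ones(1, device=device)
        self.assertEqual(t.item(), 1.0)
```
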
@@ -236,10 +259,15 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:

     @parameterized.expand(
         [
+            (True,),
             (False,),
         ]
     )
     def test_diloco_healthy(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         num_replicas = 2
         futures = []
@@ -289,7 +317,17 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
                 check_device=False,
             )

-    def test_diloco_recovery(self) -> None:
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_diloco_recovery(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
         lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,