Skip to content

Commit eb482e5

Browse files
Merge branch 'main' into HTTP/2_Multi_Room_Lighthouse
2 parents fedd473 + 93c230b commit eb482e5

19 files changed

+891
-169
lines changed

.github/workflows/lint.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ jobs:
2323
2424
sudo apt-get install -y protobuf-compiler
2525
26-
# use RC build
27-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
28-
2926
pip install lintrunner lintrunner-adapters
3027
lintrunner init
3128

.github/workflows/unittest.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ jobs:
1515
- runs-on: "linux.2xlarge"
1616
gpu-arch-type: "cpu"
1717
gpu-arch-version: ""
18-
torch-version: "test"
18+
torch-version: "stable"
1919
- runs-on: "linux.g5.12xlarge.nvidia.gpu"
2020
gpu-arch-type: "cuda"
2121
gpu-arch-version: "12.4"
22-
torch-version: "test"
22+
torch-version: "stable"
2323
- runs-on: "linux.g5.12xlarge.nvidia.gpu"
2424
gpu-arch-type: "cuda"
2525
gpu-arch-version: "12.4"

README.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,44 @@ for i in range(1000):
208208
optimizer.step()
209209
```
210210

211+
### Running DDP
212+
213+
After starting the lighthouse server by running:
214+
215+
```sh
216+
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
217+
```
218+
219+
A test DDP script can be launched with torchX with:
220+
221+
```sh
222+
torchx run
223+
```
224+
225+
See [.torchxconfig](.torchxconfig), [torchx.py](./torchft/torchx.py) and the [torchX documentation](https://pytorch.org/torchx/latest/) to understand how DDP is being run.
226+
227+
`torchx.py` can also launch HSDP jobs when `workers_per_replica` is set to > 1, provided the training script supports it. For an example HSDP training implementation with torchFT enabled, see [torchtitan](https://github.com/pytorch/torchtitan).
228+
229+
Alternatively, to test on a node with two GPUs, you can launch two replica groups running [train_ddp.py](./train_ddp.py) by:
230+
231+
On shell 1 (one replica group starts initial training):
232+
```sh
233+
export REPLICA_GROUP_ID=0
234+
export NUM_REPLICA_GROUPS=2
235+
236+
CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp.py
237+
```
238+
239+
On shell 2 (a second replica group joins):
240+
```sh
241+
export REPLICA_GROUP_ID=1
242+
export NUM_REPLICA_GROUPS=2
243+
244+
CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp.py
245+
```
246+
247+
By watching the outputs from both shells, you should observe process group reconfiguration and live checkpoint recovery.
248+
211249
### Example Parameter Server
212250

213251
torchft has a fault-tolerant parameter server implementation built on its

proto/torchft.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ message QuorumMember {
4242
int64 step = 4;
4343
uint64 world_size = 5;
4444
bool shrink_only = 6;
45+
int64 commit_failures = 8;
4546
// User passing in data stored as JSON string.
4647
string data = 7;
4748
}
@@ -77,6 +78,7 @@ message ManagerQuorumRequest {
7778
string checkpoint_metadata = 3;
7879
bool shrink_only = 4;
7980
bool init_sync = 5;
81+
int64 commit_failures = 6;
8082
}
8183

8284
message ManagerQuorumResponse {
@@ -93,6 +95,7 @@ message ManagerQuorumResponse {
9395
int64 replica_rank = 9;
9496
int64 replica_world_size = 10;
9597
bool heal = 11;
98+
int64 commit_failures = 12;
9699
}
97100

98101
message CheckpointMetadataRequest {

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ classifiers = [
1212
]
1313
dynamic = ["version"]
1414
dependencies = [
15-
"torch"
15+
"torch>=2.7"
1616
]
1717

1818
[project.urls]

src/lib.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer;
4141
use crate::torchftpb::manager_service_client::ManagerServiceClient;
4242
use crate::torchftpb::LighthouseHeartbeatRequest;
4343
use crate::torchftpb::{
44-
CheckpointMetadataRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest,
44+
CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
45+
ManagerQuorumRequest, ShouldCommitRequest,
4546
};
4647
use pyo3::prelude::*;
4748
use pyo3::types::{PyDict, PyString};
@@ -182,6 +183,7 @@ impl ManagerClient {
182183
checkpoint_metadata: String,
183184
shrink_only: bool,
184185
init_sync: bool,
186+
commit_failures: i64,
185187
timeout: Duration,
186188
) -> Result<QuorumResult, StatusError> {
187189
py.allow_threads(move || {
@@ -191,6 +193,7 @@ impl ManagerClient {
191193
checkpoint_metadata: checkpoint_metadata,
192194
shrink_only: shrink_only,
193195
init_sync: init_sync,
196+
commit_failures: commit_failures,
194197
});
195198

196199
// This timeout is processed on the server side so we also enable
@@ -562,6 +565,7 @@ impl LighthouseClient {
562565
world_size: world_size,
563566
shrink_only: shrink_only,
564567
data: data_string,
568+
commit_failures: 0,
565569
}),
566570
});
567571

@@ -615,6 +619,7 @@ impl LighthouseClient {
615619
}
616620
req
617621
}
622+
618623
}
619624

620625
/// LighthouseServer is a GRPC server for the lighthouse service.
@@ -741,11 +746,17 @@ fn setup_logging() -> Result<(), Box<dyn std::error::Error>> {
741746
.debug(Color::Blue)
742747
.trace(Color::Magenta);
743748
let level_filter = match env::var("RUST_LOG").as_deref() {
744-
Ok("error") => LevelFilter::Error,
745-
Ok("warn") => LevelFilter::Warn,
746-
Ok("info") => LevelFilter::Info,
747-
Ok("debug") => LevelFilter::Debug,
748-
Ok("trace") => LevelFilter::Trace,
749+
Ok(value) => {
750+
let value_lower = value.to_lowercase();
751+
match value_lower.as_str() {
752+
"error" => LevelFilter::Error,
753+
"warn" => LevelFilter::Warn,
754+
"info" => LevelFilter::Info,
755+
"debug" => LevelFilter::Debug,
756+
"trace" => LevelFilter::Trace,
757+
_ => LevelFilter::Info,
758+
}
759+
}
749760
_ => LevelFilter::Info,
750761
};
751762
fern::Dispatch::new()

0 commit comments

Comments
 (0)