Skip to content

Commit 7882427

Browse files
committed
added logging to examples
1 parent a13ffbd commit 7882427

File tree

4 files changed

+17
-3
lines changed

4 files changed

+17
-3
lines changed

scripts/examples/accelerate_train.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import annotations
1414

1515
import functools
16+
import logging
1617
import os
1718
from dataclasses import dataclass
1819
from pathlib import Path
@@ -27,6 +28,8 @@
2728

2829
import torchrunx
2930

31+
logging.basicConfig(level=logging.INFO)
32+
3033

3134
@dataclass
3235
class ModelConfig:
@@ -114,14 +117,18 @@ def main(
114117
output_dir: Path,
115118
):
116119
model = AutoModelForCausalLM.from_pretrained(model_config.name)
117-
train_dataset = load_training_data(tokenizer_name=model_config.name, dataset_config=dataset_config)
120+
train_dataset = load_training_data(
121+
tokenizer_name=model_config.name, dataset_config=dataset_config
122+
)
118123

119124
# Launch training
120125
results = launcher.run(train, model, train_dataset, batch_size, output_dir)
121126

122127
# Loading trained model from checkpoint
123128
checkpoint_path = results.rank(0)
124-
trained_model = AutoModelForCausalLM.from_pretrained(model_config.name, state_dict=torch.load(checkpoint_path))
129+
trained_model = AutoModelForCausalLM.from_pretrained(
130+
model_config.name, state_dict=torch.load(checkpoint_path)
131+
)
125132

126133

127134
if __name__ == "__main__":

scripts/examples/deepspeed_train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import functools
18+
import logging
1819
import os
1920
from dataclasses import dataclass
2021
from pathlib import Path
@@ -30,6 +31,8 @@
3031

3132
import torchrunx
3233

34+
logging.basicConfig(level=logging.INFO)
35+
3336

3437
@dataclass
3538
class DatasetConfig:

scripts/examples/lightning_train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
from __future__ import annotations
1515

1616
import functools
17+
import logging
1718
import os
1819
from dataclasses import dataclass
1920
from typing import Annotated
2021

2122
import lightning as L
2223
import torch
23-
2424
import tyro
2525
from datasets import load_dataset
2626
from torch.utils.data import Dataset
@@ -29,6 +29,8 @@
2929
import torchrunx
3030
from torchrunx.integrations.lightning import TorchrunxClusterEnvironment
3131

32+
logging.basicConfig(level=logging.INFO)
33+
3234

3335
@dataclass
3436
class ModelConfig:

scripts/examples/transformers_train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import annotations
1414

1515
import functools
16+
import logging
1617
import os
1718
from dataclasses import dataclass
1819
from typing import Annotated
@@ -30,6 +31,7 @@
3031

3132
import torchrunx
3233

34+
logging.basicConfig(level=logging.INFO)
3335

3436
@dataclass
3537
class ModelConfig:

0 commit comments

Comments
 (0)