Merge pull request #1686 from Akegarasu/sd3
fix: fix some distributed training error in windows
kohya-ss authored Oct 12, 2024
2 parents 43bfeea + 9f4dac5 commit d005652
Showing 1 changed file with 13 additions and 11 deletions.
library/train_util.py: 24 changes (13 additions & 11 deletions)
@@ -34,6 +34,7 @@
 # from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from tqdm import tqdm
+from packaging.version import Version
 
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
@@ -5077,17 +5078,18 @@ def prepare_accelerator(args: argparse.Namespace):
     if args.torch_compile:
         dynamo_backend = args.dynamo_backend
 
-    kwargs_handlers = (
-        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
-        (
-            DistributedDataParallelKwargs(
-                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
-            )
-            if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
-            else None
-        ),
-    )
-    kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
+    kwargs_handlers = [
+        InitProcessGroupKwargs(
+            backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+            init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
+            timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
+        ) if torch.cuda.device_count() > 1 else None,
+        DistributedDataParallelKwargs(
+            gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
+            static_graph=args.ddp_static_graph
+        ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
+    ]
+    kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
 
     accelerator = Accelerator(
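For context, the new kwargs_handlers construction boils down to a backend/init-method selection that avoids NCCL on Windows and on CUDA-less machines, and opts out of the libuv-based TCPStore that PyTorch 2.4+ uses by default. The sketch below restates that logic outside of prepare_accelerator; the helper name pick_init_process_group_kwargs and its standalone use are illustrative, not part of the repository.

import datetime
import os

import torch
from packaging.version import Version
from accelerate.utils import InitProcessGroupKwargs


def pick_init_process_group_kwargs(ddp_timeout_minutes=None):
    # NCCL is unavailable on Windows and without CUDA, so fall back to gloo there.
    backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"

    # PyTorch 2.4 made libuv the default TCPStore backend, which can fail on
    # Windows; the query string switches it back off, matching the patch.
    init_method = (
        "env://?use_libuv=False"
        if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0")
        else None
    )

    timeout = datetime.timedelta(minutes=ddp_timeout_minutes) if ddp_timeout_minutes else None
    return InitProcessGroupKwargs(backend=backend, init_method=init_method, timeout=timeout)

As in the diff, the resulting handler is only created when torch.cuda.device_count() > 1, and the filtered list is then passed to Accelerator through its kwargs_handlers argument.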
