
Commit

fix to work with cpu_count() == 1 closes bmaltais#1134
kohya-ss committed Feb 24, 2024
1 parent 488d187 commit f413201
Showing 11 changed files with 22 additions and 22 deletions.
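Every file receives the same two-line change: the worker cap `os.cpu_count() - 1` evaluates to 0 on a single-core machine, and PyTorch's DataLoader rejects `persistent_workers=True` when `num_workers` is 0 — the behaviour the new code comment warns about — so the cap now uses `os.cpu_count()` directly. A minimal sketch of the before/after arithmetic (the hard-coded cpu count and the `persistent_workers=True` flag are illustrative assumptions, not lines from this diff):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    max_data_loader_n_workers = 8  # stand-in for args.max_data_loader_n_workers
    cpu_count = 1                  # simulate a single-core host, where os.cpu_count() returns 1

    old_n_workers = min(max_data_loader_n_workers, cpu_count - 1)  # == 0 on a 1-CPU machine
    new_n_workers = min(max_data_loader_n_workers, cpu_count)      # == 1, a worker process is still used

    dataset = TensorDataset(torch.zeros(8, 3))

    try:
        # PyTorch refuses persistent_workers=True when num_workers == 0.
        DataLoader(dataset, num_workers=old_n_workers, persistent_workers=True)
    except ValueError as err:
        print("old cap:", err)

    loader = DataLoader(dataset, num_workers=new_n_workers, persistent_workers=True)
    print("new cap: num_workers =", loader.num_workers)
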
4 changes: 2 additions & 2 deletions fine_tune.py
@@ -211,8 +211,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
     batch_size=1,
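In each hunk the DataLoader call is truncated just after it begins. A hypothetical continuation (the shuffle, collate_fn, and persistent_data_loader_workers names below are illustrative guesses, not shown in the diff) indicates where the recomputed n_workers and the persistent_workers option meet:

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,                                            # illustrative; not visible in the hunk
        collate_fn=collator,                                     # illustrative; not visible in the hunk
        num_workers=n_workers,                                   # the value recomputed above
        persistent_workers=args.persistent_data_loader_workers,  # assumed option name; the flag the new comment warns about
    )
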
4 changes: 2 additions & 2 deletions sdxl_train.py
@@ -354,8 +354,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
     batch_size=1,
4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite.py
@@ -242,8 +242,8 @@ def train(args):
 _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite_old.py
@@ -210,8 +210,8 @@ def train(args):
 _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
4 changes: 2 additions & 2 deletions tools/cache_latents.py
@@ -116,8 +116,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
 # prepare the dataloader
 train_dataset_group.set_caching_mode("latents")
 
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
4 changes: 2 additions & 2 deletions tools/cache_text_encoder_outputs.py
@@ -121,8 +121,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
 # prepare the dataloader
 train_dataset_group.set_caching_mode("text")
 
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
4 changes: 2 additions & 2 deletions train_controlnet.py
@@ -239,8 +239,8 @@ def train(args):
 _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
4 changes: 2 additions & 2 deletions train_db.py
@@ -180,8 +180,8 @@ def train(args):
 _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
     batch_size=1,
4 changes: 2 additions & 2 deletions train_network.py
@@ -349,8 +349,8 @@ def train(self, args):
 optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
4 changes: 2 additions & 2 deletions train_textual_inversion.py
@@ -385,8 +385,8 @@ def train(self, args):
 _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
     batch_size=1,
4 changes: 2 additions & 2 deletions train_textual_inversion_XTI.py
@@ -305,8 +305,8 @@ def train(args):
 _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
 # prepare the dataloader
-# number of DataLoader processes: 0 means loading runs in the main process
-n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+# number of DataLoader processes: note that persistent_workers cannot be used with 0
+n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
 train_dataloader = torch.utils.data.DataLoader(
     train_dataset_group,
     batch_size=1,
