【hydra No.17】tempoGAN #592

Merged: 9 commits, Oct 27, 2023
Changes from 1 commit
Commit message: Fix
co63oc committed Oct 24, 2023
commit ed7aa7f694a9f3e0491e974f472e8fe86163ca21
34 changes: 32 additions & 2 deletions examples/tempoGAN/conf/tempogan.yaml
@@ -37,12 +37,42 @@ TILE_RATIO: 1

# model settings
MODEL:
input_keys: ["input_gen"]
output_keys: ["output_gen"]
gen_net:
input_keys: ["input_gen"] # 'NCHW'
output_keys: ["output_gen"]
in_channel: 1
out_channels_tuple: [[2, 8, 8], [128, 128, 128], [32, 8, 8], [2, 1, 1]]
kernel_sizes_tuple: [[[5, 5], [5, 5], [1, 1]], [[5, 5], [5, 5], [1, 1]], [[5, 5], [5, 5], [1, 1]], [[5, 5], [5, 5], [1, 1]]]
strides_tuple: [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]
use_bns_tuple: [[true, true, true], [true, true, true], [true, true, true], [False, False, False]]
acts_tuple: [['relu', null, null], ['relu', null, null], ['relu', null, null], ['relu', null, null]]
disc_net:
input_keys: ['input_disc_from_target', 'input_disc_from_gen'] # 'NCHW'
output_keys: ['out0_layer0', 'out0_layer1', 'out0_layer2', 'out0_layer3', 'out_disc_from_target', 'out1_layer0', 'out1_layer1', 'out1_layer2', 'out1_layer3', 'out_disc_from_gen']
in_channel: 2
out_channels: [32, 64, 128, 256]
fc_channel: 1048576
kernel_sizes: [[4, 4], [4, 4], [4, 4], [4, 4]]
strides: [2, 2, 2, 1]
use_bns: [false, true, true, true]
acts: ['leaky_relu', 'leaky_relu', 'leaky_relu', 'leaky_relu', null]
tempo_net:
input_keys: ['input_tempo_disc_from_target', 'input_tempo_disc_from_gen'] # 'NCHW'
output_keys: ['out0_tempo_layer0', 'out0_tempo_layer1', 'out0_tempo_layer2', 'out0_tempo_layer3', 'out_disc_tempo_from_target', 'out1_tempo_layer0', 'out1_tempo_layer1', 'out1_tempo_layer2', 'out1_tempo_layer3', 'out_disc_tempo_from_gen']
in_channel: 3
out_channels: [32, 64, 128, 256]
fc_channel: 1048576
kernel_sizes: [[4, 4], [4, 4], [4, 4], [4, 4]]
strides: [2, 2, 2, 1]
use_bns: [false, true, true, true]
acts: ['leaky_relu', 'leaky_relu', 'leaky_relu', 'leaky_relu', null]

# training settings
TRAIN:
epochs: 40000
epochs_gen: 1
epochs_disc: 1
epochs_disc_tempo: 1
iters_per_epoch: 2
batch_size:
sup_constraint: 8
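Note: the fc_channel: 1048576 value hard-coded for disc_net and tempo_net above is what the previous Python code derived from the data shape at runtime (see the removed block in tempoGAN.py below). A minimal sketch of that arithmetic, assuming square 512 x 512 high-resolution frames and TILE_RATIO = 1; the frame size is an assumption, only the formula comes from the removed code:

out_channels = (32, 64, 128, 256)                  # disc_net.out_channels
h = w = 512 // 1                                   # assumed frame size // TILE_RATIO
down_sample_ratio = 2 ** (len(out_channels) - 1)   # three stride-2 convs -> 8
fc_channel = int(out_channels[-1] * (h / down_sample_ratio) * (w / down_sample_ratio))
print(fc_channel)                                  # 1048576, matching the config value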
146 changes: 33 additions & 113 deletions examples/tempoGAN/tempoGAN.py
@@ -38,14 +38,9 @@ def train(cfg: DictConfig):
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")

    # initialize parameters and import classes
    USE_AMP = cfg.USE_AMP
    USE_SPATIALDISC = cfg.USE_SPATIALDISC
    USE_TEMPODISC = cfg.USE_TEMPODISC

    weight_gen = cfg.WEIGHT_GEN  # lambda_l1, lambda_l2, lambda_t
    weight_gen_layer = (
        cfg.WEIGHT_GEN_LAYER if USE_SPATIALDISC else None
        cfg.WEIGHT_GEN_LAYER if cfg.USE_SPATIALDISC else None
    )  # lambda_layer, lambda_layer1, lambda_layer2, lambda_layer3, lambda_layer4
    weight_disc = cfg.WEIGHT_DISC
    tile_ratio = cfg.TILE_RATIO
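This first hunk shows the pattern used throughout the PR: module-level aliases such as USE_AMP and USE_SPATIALDISC are removed and the Hydra config is read directly. A minimal, hypothetical sketch of such direct access, loading the YAML standalone with OmegaConf (the real script receives cfg from Hydra rather than loading the file itself, and the relative path is an assumption):

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/tempoGAN/conf/tempogan.yaml")  # assumed relative path
print(cfg.TRAIN.epochs)                      # 40000
print(cfg.TRAIN.iters_per_epoch)             # 2
print(OmegaConf.to_yaml(cfg.MODEL.gen_net))  # sub-config later unpacked as **kwargs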
@@ -58,116 +53,41 @@ def train(cfg: DictConfig):
    dataset_train = hdf5storage.loadmat(cfg.DATASET_PATH)
    dataset_valid = hdf5storage.loadmat(cfg.DATASET_PATH_VALID)

    # init Generator params
    in_channel = 1
    rb_channel0 = (2, 8, 8)
    rb_channel1 = (128, 128, 128)
    rb_channel2 = (32, 8, 8)
    rb_channel3 = (2, 1, 1)
    out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
    kernel_sizes_tuple = (((5, 5),) * 2 + ((1, 1),),) * 4
    strides_tuple = ((1, 1, 1),) * 4
    use_bns_tuple = ((True, True, True),) * 3 + ((False, False, False),)
    acts_tuple = (("relu", None, None),) * 4

    # define Generator model
    model_gen = ppsci.arch.Generator(
        cfg.MODEL.input_keys,  # 'NCHW'
        cfg.MODEL.output_keys,
        in_channel,
        out_channels_tuple,
        kernel_sizes_tuple,
        strides_tuple,
        use_bns_tuple,
        acts_tuple,
    )
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)

    model_gen.register_input_transform(gen_funcs.transform_in)
    dics_funcs.model_gen = model_gen

    model_tuple = (model_gen,)

    # init Discriminators params
    in_channel = 2
    in_channel_tempo = 3
    out_channels = (32, 64, 128, 256)
    in_shape = np.shape(dataset_train["density_high"][0])
    h, w = in_shape[1] // tile_ratio, in_shape[2] // tile_ratio
    down_sample_ratio = 2 ** (len(out_channels) - 1)
    fc_channel = int(
        out_channels[-1] * (h / down_sample_ratio) * (w / down_sample_ratio)
    )
    kernel_sizes = ((4, 4),) * 4
    strides = (2,) * 3 + (1,)
    use_bns = (False,) + (True,) * 3
    acts = ("leaky_relu",) * 4 + (None,)

    # define Discriminators
    if USE_SPATIALDISC:
        output_keys_disc = (
            tuple(f"out0_layer{i}" for i in range(4))
            + ("out_disc_from_target",)
            + tuple(f"out1_layer{i}" for i in range(4))
            + ("out_disc_from_gen",)
        )
        model_disc = ppsci.arch.Discriminator(
            ("input_disc_from_target", "input_disc_from_gen"),  # 'NCHW'
            output_keys_disc,
            in_channel,
            out_channels,
            fc_channel,
            kernel_sizes,
            strides,
            use_bns,
            acts,
        )
    if cfg.USE_SPATIALDISC:
        model_disc = ppsci.arch.Discriminator(**cfg.MODEL.disc_net)
        model_disc.register_input_transform(dics_funcs.transform_in)
        model_tuple += (model_disc,)

    # define temporal Discriminators
    if USE_TEMPODISC:
        output_keys_disc_tempo = (
            tuple(f"out0_tempo_layer{i}" for i in range(4))
            + ("out_disc_tempo_from_target",)
            + tuple(f"out1_tempo_layer{i}" for i in range(4))
            + ("out_disc_tempo_from_gen",)
        )
        model_disc_tempo = ppsci.arch.Discriminator(
            ("input_tempo_disc_from_target", "input_tempo_disc_from_gen"),  # 'NCHW'
            output_keys_disc_tempo,
            in_channel_tempo,
            out_channels,
            fc_channel,
            kernel_sizes,
            strides,
            use_bns,
            acts,
        )
    if cfg.USE_TEMPODISC:
        model_disc_tempo = ppsci.arch.Discriminator(**cfg.MODEL.tempo_net)
        model_disc_tempo.register_input_transform(dics_funcs.transform_in_tempo)
        model_tuple += (model_disc_tempo,)

    # define model_list
    model_list = ppsci.arch.ModelList(model_tuple)

    # set training hyper-parameters
    ITERS_PER_EPOCH = cfg.TRAIN.iters_per_epoch
    EPOCHS = cfg.TRAIN.epochs
    EPOCHS_GEN = EPOCHS_DISC = EPOCHS_DISC_TEMPO = 1
    BATCH_SIZE = cfg.TRAIN.batch_size.sup_constraint

    # initialize Adam optimizer
    lr_scheduler_gen = ppsci.optimizer.lr_scheduler.Step(
        step_size=EPOCHS // 2, **cfg.TRAIN.lr_scheduler
        step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
    )()
    optimizer_gen = ppsci.optimizer.Adam(lr_scheduler_gen)((model_gen,))
    if USE_SPATIALDISC:
    if cfg.USE_SPATIALDISC:
        lr_scheduler_disc = ppsci.optimizer.lr_scheduler.Step(
            step_size=EPOCHS // 2, **cfg.TRAIN.lr_scheduler
            step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
        )()
        optimizer_disc = ppsci.optimizer.Adam(lr_scheduler_disc)((model_disc,))
    if USE_TEMPODISC:
    if cfg.USE_TEMPODISC:
        lr_scheduler_disc_tempo = ppsci.optimizer.lr_scheduler.Step(
            step_size=EPOCHS // 2, **cfg.TRAIN.lr_scheduler
            step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
        )()
        optimizer_disc_tempo = ppsci.optimizer.Adam(lr_scheduler_disc_tempo)(
            (model_disc_tempo,)
@@ -192,7 +112,7 @@ def train(cfg: DictConfig):
                },
            ),
        },
        "batch_size": BATCH_SIZE,
        "batch_size": cfg.TRAIN.batch_size.sup_constraint,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
@@ -203,7 +123,7 @@ def train(cfg: DictConfig):
        name="sup_constraint_gen",
    )
    constraint_gen = {sup_constraint_gen.name: sup_constraint_gen}
    if USE_TEMPODISC:
    if cfg.USE_TEMPODISC:
        sup_constraint_gen_tempo = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
@@ -221,7 +141,7 @@ def train(cfg: DictConfig):
                        },
                    ),
                },
                "batch_size": int(BATCH_SIZE // 3),
                "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
@@ -235,7 +155,7 @@ def train(cfg: DictConfig):

    # Discriminators
    # maunally build constraint(s)
    if USE_SPATIALDISC:
    if cfg.USE_SPATIALDISC:
        sup_constraint_disc = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
@@ -262,7 +182,7 @@ def train(cfg: DictConfig):
                        },
                    ),
                },
                "batch_size": BATCH_SIZE,
                "batch_size": cfg.TRAIN.batch_size.sup_constraint,
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
@@ -278,7 +198,7 @@ def train(cfg: DictConfig):

    # temporal Discriminators
    # maunally build constraint(s)
    if USE_TEMPODISC:
    if cfg.USE_TEMPODISC:
        sup_constraint_disc_tempo = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
@@ -305,7 +225,7 @@ def train(cfg: DictConfig):
                        },
                    ),
                },
                "batch_size": int(BATCH_SIZE // 3),
                "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
@@ -326,55 +246,55 @@ def train(cfg: DictConfig):
        cfg.output_dir,
        optimizer_gen,
        lr_scheduler_gen,
        EPOCHS_GEN,
        ITERS_PER_EPOCH,
        cfg.TRAIN.epochs_gen,
        cfg.TRAIN.iters_per_epoch,
        eval_during_train=cfg.TRAIN.eval_during_train,
        use_amp=USE_AMP,
        use_amp=cfg.USE_AMP,
        amp_level=cfg.TRAIN.amp_level,
    )
    if USE_SPATIALDISC:
    if cfg.USE_SPATIALDISC:
        solver_disc = ppsci.solver.Solver(
            model_list,
            constraint_disc,
            cfg.output_dir,
            optimizer_disc,
            lr_scheduler_disc,
            EPOCHS_DISC,
            ITERS_PER_EPOCH,
            cfg.TRAIN.epochs_disc,
            cfg.TRAIN.iters_per_epoch,
            eval_during_train=cfg.TRAIN.eval_during_train,
            use_amp=USE_AMP,
            use_amp=cfg.USE_AMP,
            amp_level=cfg.TRAIN.amp_level,
        )
    if USE_TEMPODISC:
    if cfg.USE_TEMPODISC:
        solver_disc_tempo = ppsci.solver.Solver(
            model_list,
            constraint_disc_tempo,
            cfg.output_dir,
            optimizer_disc_tempo,
            lr_scheduler_disc_tempo,
            EPOCHS_DISC_TEMPO,
            ITERS_PER_EPOCH,
            cfg.TRAIN.epochs_disc_tempo,
            cfg.TRAIN.iters_per_epoch,
            eval_during_train=cfg.TRAIN.eval_during_train,
            use_amp=USE_AMP,
            use_amp=cfg.USE_AMP,
            amp_level=cfg.TRAIN.amp_level,
        )

    PRED_INTERVAL = 200
    for i in range(1, EPOCHS + 1):
    for i in range(1, cfg.TRAIN.epochs + 1):
        logger.message(f"\nEpoch: {i}\n")
        # plotting during training
        if i == 1 or i % PRED_INTERVAL == 0 or i == EPOCHS:
        if i == 1 or i % PRED_INTERVAL == 0 or i == cfg.TRAIN.epochs:
            func_module.predict_and_save_plot(
                cfg.output_dir, i, solver_gen, dataset_valid, tile_ratio
            )

        dics_funcs.model_gen = model_gen
        # train disc, input: (x,y,G(x))
        if USE_SPATIALDISC:
        if cfg.USE_SPATIALDISC:
            solver_disc.train()

        # train disc tempo, input: (y_3,G(x)_3)
        if USE_TEMPODISC:
        if cfg.USE_TEMPODISC:
            solver_disc_tempo.train()

        # train gen, input: (x,)
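For context, a minimal sketch of the Hydra entry point that this kind of refactor targets. The decorator arguments and the cfg.mode switch follow the usual layout of PaddleScience examples and are assumptions here, not part of this diff:

import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="./conf", config_name="tempogan.yaml")
def main(cfg: DictConfig):
    # `train` is the function whose body is shown in the diff above.
    if cfg.mode == "train":
        train(cfg)
    else:
        raise ValueError(f"only mode='train' is sketched here, got {cfg.mode!r}")


if __name__ == "__main__":
    main()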