Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Weights and Biases Integration #44

Merged
merged 4 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config/sample_ddpm_128.json
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "generation_ffhq_ddpm"
}
}
3 changes: 3 additions & 0 deletions config/sample_sr3_128.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "generation_ffhq_sr3"
}
}
3 changes: 3 additions & 0 deletions config/sr_ddpm_16_128.json
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "sr_ffhq"
}
}
3 changes: 3 additions & 0 deletions config/sr_sr3_16_128.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "sr_ffhq"
}
}
3 changes: 3 additions & 0 deletions config/sr_sr3_64_512.json
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "distributed_high_sr_ffhq"
}
}
19 changes: 19 additions & 0 deletions core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def parse(args):
phase = args.phase
opt_path = args.config
gpu_ids = args.gpu_ids
enable_wandb = args.enable_wandb
# remove comments starting with '//'
json_str = ''
with open(opt_path, 'r') as f:
Expand Down Expand Up @@ -72,6 +73,24 @@ def parse(args):
if phase == 'train':
opt['datasets']['val']['data_len'] = 3

# W&B Logging
try:
log_wandb_ckpt = args.log_wandb_ckpt
opt['log_wandb_ckpt'] = log_wandb_ckpt
except:
pass
try:
log_eval = args.log_eval
opt['log_eval'] = log_eval
except:
pass
try:
log_infer = args.log_infer
opt['log_infer'] = log_infer
except:
pass
opt['enable_wandb'] = enable_wandb

return opt


Expand Down
116 changes: 116 additions & 0 deletions core/wandb_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os

class WandbLogger:
    """
    Log using `Weights and Biases`.

    Wraps the `wandb` module: starts a run if none is active, and exposes
    helpers to log metrics, images, model checkpoints (as artifacts) and
    row-wise evaluation / inference tables.
    """

    # Column layouts of the two result tables.
    _EVAL_COLUMNS = ('fake_image', 'sr_image', 'hr_image', 'psnr', 'ssim')
    _INFER_COLUMNS = ('fake_image', 'sr_image', 'hr_image')

    def __init__(self, opt):
        """
        Import `wandb`, start a run when none is active and set up tables.

        opt: option mapping; `opt['wandb']['project']` names the W&B project
             and the whole mapping is passed to `wandb.init` as the run config.

        Raises ImportError when the `wandb` package is not installed.
        """
        try:
            import wandb
        except ImportError as err:
            # The two sentences are separate literals; the trailing space
            # keeps them from running together in the final message.
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from err

        self._wandb = wandb

        # Initialize a W&B run (only when one is not already active)
        if self._wandb.run is None:
            self._wandb.init(
                project=opt['wandb']['project'],
                config=opt,
                dir='./experiments'
            )

        self.config = self._wandb.config

        # A table is only created up-front when its flag is set in the config.
        if self.config.get('log_eval', None):
            self.eval_table = self._new_table(self._EVAL_COLUMNS)
        else:
            self.eval_table = None

        if self.config.get('log_infer', None):
            self.infer_table = self._new_table(self._INFER_COLUMNS)
        else:
            self.infer_table = None

    def _new_table(self, columns):
        """Create an empty W&B table with the given column names."""
        return self._wandb.Table(columns=list(columns))

    def log_metrics(self, metrics, commit=True):
        """
        Log train/validation metrics onto W&B.

        metrics: dictionary of metrics to be logged
        commit: forwarded unchanged to `wandb.log`
        """
        self._wandb.log(metrics, commit=commit)

    def log_image(self, key_name, image_array):
        """
        Log image array onto W&B.

        key_name: name of the key
        image_array: numpy array of image.
        """
        self._wandb.log({key_name: self._wandb.Image(image_array)})

    def log_images(self, key_name, list_images):
        """
        Log list of image array onto W&B

        key_name: name of the key
        list_images: list of numpy image arrays
        """
        self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]})

    def log_checkpoint(self, current_epoch, current_step):
        """
        Log the model checkpoint as W&B artifacts

        current_epoch: the current epoch
        current_step: the current batch step

        Adds the `I{step}_E{epoch}_gen.pth` and `I{step}_E{epoch}_opt.pth`
        files found under the configured checkpoint directory to one artifact
        and logs it with the alias "latest".
        """
        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )

        checkpoint_dir = self.config.path['checkpoint']
        # Generator weights first, optimizer state second (same order as before).
        for suffix in ('gen', 'opt'):
            file_name = 'I{}_E{}_{}.pth'.format(current_step, current_epoch, suffix)
            model_artifact.add_file(os.path.join(checkpoint_dir, file_name))

        self._wandb.log_artifact(model_artifact, aliases=["latest"])

    def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None):
        """
        Add data row-wise to the initialized table.

        When both `psnr` and `ssim` are given the row goes to the evaluation
        table, otherwise the three images go to the inference table. If the
        target table was not created in `__init__` (its config flag was not
        set) it is created here on first use instead of failing on `None`.
        """
        images = (
            self._wandb.Image(fake_img),
            self._wandb.Image(sr_img),
            self._wandb.Image(hr_img),
        )

        if psnr is not None and ssim is not None:
            if self.eval_table is None:
                self.eval_table = self._new_table(self._EVAL_COLUMNS)
            self.eval_table.add_data(*images, psnr, ssim)
        else:
            if self.infer_table is None:
                self.infer_table = self._new_table(self._INFER_COLUMNS)
            self.infer_table.add_data(*images)

    def log_eval_table(self, commit=False):
        """
        Log the table

        Logs every table that exists ('eval_data' and/or 'infer_data') in a
        single `wandb.log` call; does nothing when neither table exists.
        """
        # Compare against None explicitly: whether a table object is truthy
        # is not something this class should depend on.
        tables = {}
        if self.eval_table is not None:
            tables['eval_data'] = self.eval_table
        if self.infer_table is not None:
            tables['infer_data'] = self.infer_table

        if tables:
            self._wandb.log(tables, commit=commit)
16 changes: 16 additions & 0 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import logging
import core.logger as Logger
import core.metrics as Metrics
from core.wandb_logger import WandbLogger
from tensorboardX import SummaryWriter
import os
import numpy as np
import wandb

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand All @@ -16,6 +18,8 @@
parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val')
parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
parser.add_argument('-debug', '-d', action='store_true')
parser.add_argument('-enable_wandb', action='store_true')
parser.add_argument('-log_infer', action='store_true')

# parse configs
args = parser.parse_args()
Expand All @@ -34,6 +38,12 @@
logger.info(Logger.dict2str(opt))
tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])

# Initialize WandbLogger
if opt['enable_wandb']:
wandb_logger = WandbLogger(opt)
else:
wandb_logger = None

# dataset
for phase, dataset_opt in opt['datasets'].items():
if phase == 'val':
Expand Down Expand Up @@ -85,3 +95,9 @@
hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
Metrics.save_img(
fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))

if wandb_logger and opt['log_infer']:
wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img)

if wandb_logger and opt['log_infer']:
wandb_logger.log_eval_table(commit=True)
28 changes: 28 additions & 0 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import logging
import core.logger as Logger
import core.metrics as Metrics
from core.wandb_logger import WandbLogger
from tensorboardX import SummaryWriter
import os
import numpy as np
import wandb

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand All @@ -17,6 +19,8 @@
help='Run either train(training) or val(generation)', default='train')
parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
parser.add_argument('-debug', '-d', action='store_true')
parser.add_argument('-enable_wandb', action='store_true')
parser.add_argument('-log_wandb_ckpt', action='store_true')

# parse configs
args = parser.parse_args()
Expand All @@ -35,6 +39,13 @@
logger.info(Logger.dict2str(opt))
tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])

# Initialize WandbLogger
if opt['enable_wandb']:
wandb_logger = WandbLogger(opt)
val_step = 0
else:
wandb_logger = None

# dataset
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train' and args.phase != 'val':
Expand Down Expand Up @@ -78,6 +89,9 @@
tb_logger.add_scalar(k, v, current_step)
logger.info(message)

if wandb_logger:
wandb_logger.log_metrics(logs)

# validation
if current_step % opt['train']['val_freq'] == 0:
result_path = '{}/{}'.format(opt['path']
Expand All @@ -100,19 +114,28 @@
'Iter_{}'.format(current_step),
np.transpose(sample_img, [2, 0, 1]),
idx)

if wandb_logger:
wandb_logger.log_image(f'validation_{idx}', sample_img)

diffusion.set_new_noise_schedule(
opt['model']['beta_schedule']['train'], schedule_phase='train')

if current_step % opt['train']['save_checkpoint_freq'] == 0:
logger.info('Saving models and training states.')
diffusion.save_network(current_epoch, current_step)

if wandb_logger and opt['log_wandb_ckpt']:
wandb_logger.log_checkpoint(current_epoch, current_step)

# save model
logger.info('End of training.')
else:
logger.info('Begin Model Evaluation.')

result_path = '{}'.format(opt['path']['results'])
os.makedirs(result_path, exist_ok=True)
sample_imgs = []
for idx in range(sample_sum):
idx += 1
diffusion.sample(continous=True)
Expand All @@ -133,3 +156,8 @@
sample_img, '{}/{}_{}_sample_process.png'.format(result_path, current_step, idx))
Metrics.save_img(
Metrics.tensor2img(visuals['SAM'][-1]), '{}/{}_{}_sample.png'.format(result_path, current_step, idx))

sample_imgs.append(Metrics.tensor2img(visuals['SAM'][-1]))

if wandb_logger:
wandb_logger.log_images('eval_images', sample_imgs)
Loading