
Commit

Merge pull request #44 from ayulockin/wandb
Add Weights and Biases Integration
Janspiry authored Jan 13, 2022
2 parents 5fb791f + 18e12f2 commit b247799
Showing 10 changed files with 249 additions and 4 deletions.
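
In short, the integration is opt-in: each JSON config gains a small "wandb" block naming the W&B project, core/logger.py forwards the new command-line switches into the parsed option dict, a new core/wandb_logger.py wraps the wandb package in a thin WandbLogger class, and infer.py and sample.py construct that logger only when -enable_wandb is passed. A plausible training invocation, assuming the scripts' pre-existing -c/--config flag, is `python sample.py -p train -c config/sample_ddpm_128.json -enable_wandb -log_wandb_ckpt`.
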
3 changes: 3 additions & 0 deletions config/sample_ddpm_128.json
@@ -90,5 +90,8 @@
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
    "wandb": {
        "project": "generation_ffhq_ddpm"
    }
}
3 changes: 3 additions & 0 deletions config/sample_sr3_128.json
@@ -89,5 +89,8 @@
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
    "wandb": {
        "project": "generation_ffhq_sr3"
    }
}
3 changes: 3 additions & 0 deletions config/sr_ddpm_16_128.json
@@ -90,5 +90,8 @@
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
    "wandb": {
        "project": "sr_ffhq"
    }
}
3 changes: 3 additions & 0 deletions config/sr_sr3_16_128.json
@@ -89,5 +89,8 @@
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
    "wandb": {
        "project": "sr_ffhq"
    }
}
3 changes: 3 additions & 0 deletions config/sr_sr3_64_512.json
@@ -92,5 +92,8 @@
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
    "wandb": {
        "project": "distributed_high_sr_ffhq"
    }
}
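
All five configs add the same three-line block; the project name is the only field that is read from it, and the new WandbLogger below hands it to wandb.init. A minimal sketch of that consumption, with a hand-written dict standing in for a parsed config and assuming wandb is installed and `wandb login` has been run:

# Minimal sketch of how the new "wandb" config block is consumed
# (mirrors what core/wandb_logger.py below does via wandb.init).
import wandb

opt = {'wandb': {'project': 'sr_ffhq'}}   # stand-in for a parsed config such as config/sr_sr3_16_128.json
run = wandb.init(project=opt['wandb']['project'], config=opt)
run.finish()
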
19 changes: 19 additions & 0 deletions core/logger.py
@@ -22,6 +22,7 @@ def parse(args):
    phase = args.phase
    opt_path = args.config
    gpu_ids = args.gpu_ids
    enable_wandb = args.enable_wandb
    # remove comments starting with '//'
    json_str = ''
    with open(opt_path, 'r') as f:
@@ -72,6 +73,24 @@ def parse(args):
        if phase == 'train':
            opt['datasets']['val']['data_len'] = 3

    # W&B Logging
    try:
        log_wandb_ckpt = args.log_wandb_ckpt
        opt['log_wandb_ckpt'] = log_wandb_ckpt
    except:
        pass
    try:
        log_eval = args.log_eval
        opt['log_eval'] = log_eval
    except:
        pass
    try:
        log_infer = args.log_infer
        opt['log_infer'] = log_infer
    except:
        pass
    opt['enable_wandb'] = enable_wandb

    return opt


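
parse() copies each switch into the option dict inside its own try/except, so a script that never defines a given flag (infer.py, for instance, defines -log_infer but not -log_wandb_ckpt) simply leaves that key out of opt instead of failing on the missing argparse attribute. A rough sketch of the flow, assuming the pre-existing -c/--config flag alongside the -p, -gpu and -debug flags shown in the scripts below:

# Sketch of how the new flags travel from argparse into the option dict
# (assumed flag set; each real script defines only the switches it needs).
import argparse
import core.logger as Logger

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='config/sr_sr3_16_128.json')
parser.add_argument('-p', '--phase', type=str, default='train')
parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
parser.add_argument('-debug', '-d', action='store_true')
parser.add_argument('-enable_wandb', action='store_true')
parser.add_argument('-log_wandb_ckpt', action='store_true')

args = parser.parse_args()
opt = Logger.parse(args)
# opt['enable_wandb'] and opt['log_wandb_ckpt'] now mirror the CLI;
# opt['log_eval'] and opt['log_infer'] stay unset because those flags were not defined here.
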
116 changes: 116 additions & 0 deletions core/wandb_logger.py
@@ -0,0 +1,116 @@
import os

class WandbLogger:
    """
    Log using `Weights and Biases`.
    """
    def __init__(self, opt):
        try:
            import wandb
        except ImportError:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            )

        self._wandb = wandb

        # Initialize a W&B run
        if self._wandb.run is None:
            self._wandb.init(
                project=opt['wandb']['project'],
                config=opt,
                dir='./experiments'
            )

        self.config = self._wandb.config

        if self.config.get('log_eval', None):
            self.eval_table = self._wandb.Table(columns=['fake_image',
                                                         'sr_image',
                                                         'hr_image',
                                                         'psnr',
                                                         'ssim'])
        else:
            self.eval_table = None

        if self.config.get('log_infer', None):
            self.infer_table = self._wandb.Table(columns=['fake_image',
                                                          'sr_image',
                                                          'hr_image'])
        else:
            self.infer_table = None

    def log_metrics(self, metrics, commit=True):
        """
        Log train/validation metrics onto W&B.
        metrics: dictionary of metrics to be logged
        """
        self._wandb.log(metrics, commit=commit)

    def log_image(self, key_name, image_array):
        """
        Log image array onto W&B.
        key_name: name of the key
        image_array: numpy array of image.
        """
        self._wandb.log({key_name: self._wandb.Image(image_array)})

    def log_images(self, key_name, list_images):
        """
        Log list of image arrays onto W&B.
        key_name: name of the key
        list_images: list of numpy image arrays
        """
        self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]})

    def log_checkpoint(self, current_epoch, current_step):
        """
        Log the model checkpoint as W&B artifacts.
        current_epoch: the current epoch
        current_step: the current batch step
        """
        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )

        gen_path = os.path.join(
            self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch))
        opt_path = os.path.join(
            self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch))

        model_artifact.add_file(gen_path)
        model_artifact.add_file(opt_path)
        self._wandb.log_artifact(model_artifact, aliases=["latest"])

    def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None):
        """
        Add data row-wise to the initialized table.
        """
        if psnr is not None and ssim is not None:
            self.eval_table.add_data(
                self._wandb.Image(fake_img),
                self._wandb.Image(sr_img),
                self._wandb.Image(hr_img),
                psnr,
                ssim
            )
        else:
            self.infer_table.add_data(
                self._wandb.Image(fake_img),
                self._wandb.Image(sr_img),
                self._wandb.Image(hr_img)
            )

    def log_eval_table(self, commit=False):
        """
        Log the table.
        """
        if self.eval_table:
            self._wandb.log({'eval_data': self.eval_table}, commit=commit)
        elif self.infer_table:
            self._wandb.log({'infer_data': self.infer_table}, commit=commit)
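
A minimal usage sketch of the class, with a hand-built opt dict standing in for the parsed config the real scripts pass in, and assuming wandb is installed, `wandb login` has been run, and the ./experiments directory that WandbLogger points wandb.init at already exists:

# Usage sketch for WandbLogger (hypothetical opt dict; random placeholder images).
import numpy as np
from core.wandb_logger import WandbLogger

opt = {
    'wandb': {'project': 'sr_ffhq'},
    'log_eval': True,                    # create the PSNR/SSIM eval table
}

wb = WandbLogger(opt)                    # wandb.init(project='sr_ffhq', config=opt, dir='./experiments')

img = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)   # stand-in HxWxC uint8 image
wb.log_metrics({'validation/val_psnr': 27.3})
wb.log_image('validation_1', img)
wb.log_eval_data(img, img, img, psnr=27.3, ssim=0.78)         # adds a row to eval_table
wb.log_eval_table(commit=True)                                # logged under the 'eval_data' key
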
16 changes: 16 additions & 0 deletions infer.py
@@ -5,9 +5,11 @@
import logging
import core.logger as Logger
import core.metrics as Metrics
from core.wandb_logger import WandbLogger
from tensorboardX import SummaryWriter
import os
import numpy as np
import wandb

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -16,6 +18,8 @@
parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val')
parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
parser.add_argument('-debug', '-d', action='store_true')
parser.add_argument('-enable_wandb', action='store_true')
parser.add_argument('-log_infer', action='store_true')

# parse configs
args = parser.parse_args()
@@ -34,6 +38,12 @@
logger.info(Logger.dict2str(opt))
tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])

# Initialize WandbLogger
if opt['enable_wandb']:
wandb_logger = WandbLogger(opt)
else:
wandb_logger = None

# dataset
for phase, dataset_opt in opt['datasets'].items():
if phase == 'val':
@@ -85,3 +95,9 @@
hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
Metrics.save_img(
fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))

if wandb_logger and opt['log_infer']:
wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img)

if wandb_logger and opt['log_infer']:
wandb_logger.log_eval_table(commit=True)
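
Because infer.py calls log_eval_data without PSNR/SSIM values, the rows land in infer_table and are pushed once at the end under the 'infer_data' key. A plausible run, assuming the script's pre-existing -c/--config flag, is `python infer.py -c config/sr_sr3_16_128.json -enable_wandb -log_infer`.
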
28 changes: 28 additions & 0 deletions sample.py
@@ -5,9 +5,11 @@
import logging
import core.logger as Logger
import core.metrics as Metrics
from core.wandb_logger import WandbLogger
from tensorboardX import SummaryWriter
import os
import numpy as np
import wandb

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -17,6 +19,8 @@
help='Run either train(training) or val(generation)', default='train')
parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
parser.add_argument('-debug', '-d', action='store_true')
parser.add_argument('-enable_wandb', action='store_true')
parser.add_argument('-log_wandb_ckpt', action='store_true')

# parse configs
args = parser.parse_args()
@@ -35,6 +39,13 @@
logger.info(Logger.dict2str(opt))
tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])

# Initialize WandbLogger
if opt['enable_wandb']:
wandb_logger = WandbLogger(opt)
val_step = 0
else:
wandb_logger = None

# dataset
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train' and args.phase != 'val':
@@ -78,6 +89,9 @@
tb_logger.add_scalar(k, v, current_step)
logger.info(message)

if wandb_logger:
wandb_logger.log_metrics(logs)

# validation
if current_step % opt['train']['val_freq'] == 0:
result_path = '{}/{}'.format(opt['path']
@@ -100,19 +114,28 @@
'Iter_{}'.format(current_step),
np.transpose(sample_img, [2, 0, 1]),
idx)

if wandb_logger:
wandb_logger.log_image(f'validation_{idx}', sample_img)

diffusion.set_new_noise_schedule(
opt['model']['beta_schedule']['train'], schedule_phase='train')

if current_step % opt['train']['save_checkpoint_freq'] == 0:
logger.info('Saving models and training states.')
diffusion.save_network(current_epoch, current_step)

if wandb_logger and opt['log_wandb_ckpt']:
wandb_logger.log_checkpoint(current_epoch, current_step)

# save model
logger.info('End of training.')
else:
logger.info('Begin Model Evaluation.')

result_path = '{}'.format(opt['path']['results'])
os.makedirs(result_path, exist_ok=True)
sample_imgs = []
for idx in range(sample_sum):
idx += 1
diffusion.sample(continous=True)
@@ -133,3 +156,8 @@
sample_img, '{}/{}_{}_sample_process.png'.format(result_path, current_step, idx))
Metrics.save_img(
Metrics.tensor2img(visuals['SAM'][-1]), '{}/{}_{}_sample.png'.format(result_path, current_step, idx))

sample_imgs.append(Metrics.tensor2img(visuals['SAM'][-1]))

if wandb_logger:
wandb_logger.log_images('eval_images', sample_imgs)
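
On the training side, sample.py now mirrors the logs that already go to TensorBoard over to W&B, logs each validation sample image, uploads the generator and optimizer checkpoints as a model artifact when -log_wandb_ckpt is set, and, in val mode, gathers the sampled images into a single 'eval_images' panel. A plausible evaluation run, again assuming the pre-existing -c/--config flag, is `python sample.py -p val -c config/sample_ddpm_128.json -enable_wandb`.
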
(One more changed file is not shown above.)

1 comment on commit b247799

@far-rainbow


Traceback (most recent call last):
  File "infer.py", line 12, in <module>
    import wandb
ModuleNotFoundError: No module named 'wandb'
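
The error comes from the unconditional `import wandb` this commit adds near the top of infer.py, so the package is required even when -enable_wandb is not passed; installing it with `pip install wandb` resolves the traceback.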
