-
Notifications
You must be signed in to change notification settings - Fork 470
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #44 from ayulockin/wandb
Add Weights and Baises Integration
- Loading branch information
Showing
10 changed files
with
249 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,5 +90,8 @@ | |
"update_ema_every": 1, | ||
"ema_decay": 0.9999 | ||
} | ||
}, | ||
"wandb": { | ||
"project": "generation_ffhq_ddpm" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,5 +89,8 @@ | |
"update_ema_every": 1, | ||
"ema_decay": 0.9999 | ||
} | ||
}, | ||
"wandb": { | ||
"project": "generation_ffhq_sr3" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,5 +90,8 @@ | |
"update_ema_every": 1, | ||
"ema_decay": 0.9999 | ||
} | ||
}, | ||
"wandb": { | ||
"project": "sr_ffhq" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,5 +89,8 @@ | |
"update_ema_every": 1, | ||
"ema_decay": 0.9999 | ||
} | ||
}, | ||
"wandb": { | ||
"project": "sr_ffhq" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,5 +92,8 @@ | |
"update_ema_every": 1, | ||
"ema_decay": 0.9999 | ||
} | ||
}, | ||
"wandb": { | ||
"project": "distributed_high_sr_ffhq" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
|
||
class WandbLogger: | ||
""" | ||
Log using `Weights and Biases`. | ||
""" | ||
def __init__(self, opt): | ||
try: | ||
import wandb | ||
except ImportError: | ||
raise ImportError( | ||
"To use the Weights and Biases Logger please install wandb." | ||
"Run `pip install wandb` to install it." | ||
) | ||
|
||
self._wandb = wandb | ||
|
||
# Initialize a W&B run | ||
if self._wandb.run is None: | ||
self._wandb.init( | ||
project=opt['wandb']['project'], | ||
config=opt, | ||
dir='./experiments' | ||
) | ||
|
||
self.config = self._wandb.config | ||
|
||
if self.config.get('log_eval', None): | ||
self.eval_table = self._wandb.Table(columns=['fake_image', | ||
'sr_image', | ||
'hr_image', | ||
'psnr', | ||
'ssim']) | ||
else: | ||
self.eval_table = None | ||
|
||
if self.config.get('log_infer', None): | ||
self.infer_table = self._wandb.Table(columns=['fake_image', | ||
'sr_image', | ||
'hr_image']) | ||
else: | ||
self.infer_table = None | ||
|
||
def log_metrics(self, metrics, commit=True): | ||
""" | ||
Log train/validation metrics onto W&B. | ||
metrics: dictionary of metrics to be logged | ||
""" | ||
self._wandb.log(metrics, commit=commit) | ||
|
||
def log_image(self, key_name, image_array): | ||
""" | ||
Log image array onto W&B. | ||
key_name: name of the key | ||
image_array: numpy array of image. | ||
""" | ||
self._wandb.log({key_name: self._wandb.Image(image_array)}) | ||
|
||
def log_images(self, key_name, list_images): | ||
""" | ||
Log list of image array onto W&B | ||
key_name: name of the key | ||
list_images: list of numpy image arrays | ||
""" | ||
self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]}) | ||
|
||
def log_checkpoint(self, current_epoch, current_step): | ||
""" | ||
Log the model checkpoint as W&B artifacts | ||
current_epoch: the current epoch | ||
current_step: the current batch step | ||
""" | ||
model_artifact = self._wandb.Artifact( | ||
self._wandb.run.id + "_model", type="model" | ||
) | ||
|
||
gen_path = os.path.join( | ||
self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch)) | ||
opt_path = os.path.join( | ||
self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch)) | ||
|
||
model_artifact.add_file(gen_path) | ||
model_artifact.add_file(opt_path) | ||
self._wandb.log_artifact(model_artifact, aliases=["latest"]) | ||
|
||
def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None): | ||
""" | ||
Add data row-wise to the initialized table. | ||
""" | ||
if psnr is not None and ssim is not None: | ||
self.eval_table.add_data( | ||
self._wandb.Image(fake_img), | ||
self._wandb.Image(sr_img), | ||
self._wandb.Image(hr_img), | ||
psnr, | ||
ssim | ||
) | ||
else: | ||
self.infer_table.add_data( | ||
self._wandb.Image(fake_img), | ||
self._wandb.Image(sr_img), | ||
self._wandb.Image(hr_img) | ||
) | ||
|
||
def log_eval_table(self, commit=False): | ||
""" | ||
Log the table | ||
""" | ||
if self.eval_table: | ||
self._wandb.log({'eval_data': self.eval_table}, commit=commit) | ||
elif self.infer_table: | ||
self._wandb.log({'infer_data': self.infer_table}, commit=commit) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
b247799
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Traceback (most recent call last):
File "infer.py", line 12, in
import wandb
ModuleNotFoundError: No module named 'wandb'