-
Notifications
You must be signed in to change notification settings - Fork 2
/
tb_callback.py
103 lines (89 loc) · 3.66 KB
/
tb_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import io
import platform
import matplotlib
# Select the matplotlib backend before matplotlib.pyplot is imported on the
# next line.
# NOTE(review): TkAgg is forced on macOS only — presumably to work around the
# default macOS backend; confirm. TB_Summary only renders figures into PNG
# buffers (savefig) and closes them, so a non-interactive backend may suffice.
if platform.system() == 'Darwin':
    matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import typing as t
class TB_Summary:
    """Helper class to write TensorBoard summaries.

    Holds two ``tf.summary`` file writers, one logging under
    ``<output_dir>/train`` and one under ``<output_dir>/validate``. Every
    summary method routes its data to one of them via a ``training`` flag.
    """

    def __init__(self, output_dir: str, dpi: int = 120):
        """
        Args:
            output_dir: root directory for the TensorBoard event files; the
                ``train`` and ``validate`` sub-directories are used.
            dpi: resolution used when rasterising matplotlib figures
                (default 120, the previously hard-coded value).
        """
        self.dpi = dpi
        # The 'seaborn-deep' style was renamed 'seaborn-v0_8-deep' in
        # matplotlib 3.6 and the old name was removed in 3.8, where
        # plt.style.use('seaborn-deep') raises OSError. Use whichever name
        # this installation provides; keep the default style if neither is.
        for style in ('seaborn-v0_8-deep', 'seaborn-deep'):
            if style in plt.style.available:
                plt.style.use(style)
                break
        self.train_summary_writer = tf.summary.create_file_writer(
            os.path.join(output_dir, 'train'))
        self.validate_summary_writer = tf.summary.create_file_writer(
            os.path.join(output_dir, 'validate'))

    def get_writer(self, training: bool):
        """Return the train writer if ``training`` is truthy, else the validation writer."""
        return self.train_summary_writer if training else self.validate_summary_writer

    def scalar(self, tag, value, epoch, training):
        """Write scalar ``value`` under ``tag`` at step ``epoch``.

        Args:
            tag: data identifier
            value: scalar value to record
            epoch: global step value to record
            training: training summary or validation summary
        """
        with self.get_writer(training).as_default():
            tf.summary.scalar(tag, value, step=epoch)

    def losses(self, results):
        """Print the mean of every entry of ``results`` on a single line.

        Nothing is written to TensorBoard; this only prints to stdout.

        Args:
            results: mapping of name -> value; each value is reduced with
                ``tf.math.reduce_mean`` before printing.
        """
        for key, value in results.items():
            value = tf.math.reduce_mean(value)
            print('%s = %.4f ' % (key, value.numpy()), end='')
        print('\n')

    def image(self, tag, values, step: int = 0, training: bool = False):
        """Write a batch of images to the summary.

        Args:
            tag: data identifier
            values: image batch handed to ``tf.summary.image``; presumably
                shaped (N, H, W, C) as that API expects — all ``len(values)``
                images are written.
            step: global step value to record
            training: training summary or validation summary
        """
        with self.get_writer(training).as_default():
            tf.summary.image(tag, data=values, step=step, max_outputs=len(values))

    def figure(self,
               tag,
               figure,
               step: int = 0,
               training: bool = False,
               close: bool = True):
        """Write a single matplotlib figure to the summary as an RGBA PNG image.

        Args:
            tag: data identifier
            figure: a single matplotlib figure
            step: global step value to record
            training: training summary or validation summary
            close: close the figure afterwards (also done if rendering fails)
        """
        try:
            with io.BytesIO() as buffer:
                figure.savefig(buffer, dpi=self.dpi, format='png', bbox_inches='tight')
                png_bytes = buffer.getvalue()
        finally:
            # Close even when savefig raises, so figures don't accumulate.
            if close:
                plt.close(figure)
        image = tf.image.decode_png(png_bytes, channels=4)
        # Add a leading batch axis of size 1 for tf.summary.image.
        self.image(tag, tf.expand_dims(image, 0), step=step, training=training)

    def image_cycle(self,
                    tag: str,
                    images: t.List[np.ndarray],
                    labels: t.List[str],
                    step: int = 0,
                    training: bool = False):
        """Plot side-by-side image panels for each sample to TensorBoard.

        For every sample index ``i`` one figure is written under
        ``{tag}/sample_#{i:03d}``. The figure shows ``images[k][i]`` titled
        ``labels[k]`` for every panel ``k``. It was originally used with
        exactly 3 panels, and that layout is unchanged.

        Args:
            tag: data identifier
            images: list of np.ndarray, each of shape (N,H,W,C) with the same N
            labels: list of panel titles, one per entry of ``images``
            step: global step value to record
            training: training summary or validation summary

        Raises:
            ValueError: if ``images`` is empty or its length differs from
                ``labels``.
        """
        if not images or len(images) != len(labels):
            raise ValueError(
                'images and labels must be non-empty and of equal length, '
                f'got {len(images)} images and {len(labels)} labels')
        num_panels = len(images)
        for sample in range(len(images[0])):
            # squeeze=False keeps a 2-D axes array even for a single panel.
            figure, axes = plt.subplots(nrows=1,
                                        ncols=num_panels,
                                        figsize=(3 * num_panels, 3.25),
                                        dpi=self.dpi,
                                        squeeze=False)
            row = axes[0]
            for ax, panel, label in zip(row, images, labels):
                ax.imshow(panel[sample, ...], interpolation='none')
                ax.set_title(label)
            plt.setp(row, xticks=[], yticks=[])
            figure.tight_layout()
            figure.subplots_adjust(wspace=0.02, hspace=0.02)
            self.figure(tag=f'{tag}/sample_#{sample:03d}',
                        figure=figure,
                        step=step,
                        training=training,
                        close=True)