-
Notifications
You must be signed in to change notification settings - Fork 10
/
helpers.py
33 lines (31 loc) · 1.19 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
from mpl_toolkits.mplot3d import Axes3D
def draw_heatmaps_3d(heatmaps, region, region_mask=None, fontsize=15):
    """Show each heatmap as a flat image next to a 3D surface plot.

    One figure is created, shown and closed per heatmap. The left subplot
    draws the heatmap with the ``jet`` colormap and outlines ``region`` as an
    unfilled polygon; the right subplot renders the same values as a 3D
    surface, with the Y limits reversed so rows run in the same direction as
    in the image plot.

    Args:
        heatmaps: Array-like stack of heatmaps of shape
            ``(n, height, width)``. A single 2D heatmap of shape
            ``(height, width)`` is also accepted and treated as a stack of one.
        region: Vertices passed to ``matplotlib.patches.Polygon`` and drawn as
            an outline on the image subplot (presumably ``(x, y)`` pixel
            coordinates -- confirm against callers).
        region_mask: Optional array indexable against a ``(height, width)``
            heatmap. Pixels where the mask is falsy are set to 0 before
            plotting. The caller's heatmaps are never modified.
        fontsize: Font size for the 3D plot's tick labels and axis labels.

    Raises:
        ValueError: If ``heatmaps`` is neither 2D nor 3D.
    """
    heatmaps = np.asarray(heatmaps)
    if heatmaps.ndim == 2:
        # Allow a single heatmap to be passed without wrapping it in a stack.
        heatmaps = heatmaps[np.newaxis]
    if heatmaps.ndim != 3:
        raise ValueError(
            f"heatmaps must be 2D or 3D, got shape {heatmaps.shape}")

    height, width = heatmaps.shape[1:]
    # Pixel coordinate grids for plot_surface, each of shape (height, width):
    # xs holds column indices, ys holds row indices.
    xs, ys = np.meshgrid(np.arange(width), np.arange(height))

    # Invert once up front: True marks pixels outside the region to zero out.
    outside = None if region_mask is None else np.logical_not(region_mask)

    for heatmap in heatmaps:
        if outside is not None:
            # Copy so the caller's data is never modified in place.
            heatmap = heatmap.copy()
            heatmap[outside] = 0

        fig = plt.figure(figsize=(30, 10))
        try:
            ax = fig.add_subplot(121)
            ax.axis('off')
            ax.imshow(heatmap, cmap='jet')
            ax.add_patch(Polygon(region, fill=False))

            ax3d = fig.add_subplot(122, projection='3d')
            ax3d.plot_surface(xs, ys, heatmap, cmap='jet')
            # Reverse the Y limits so rows run the same way as in the image
            # plot (imshow puts row 0 at the top by default).
            y_low, y_high = ax3d.get_ylim()
            ax3d.set_ylim(y_high, y_low)
            for axis in ('x', 'y', 'z'):
                ax3d.tick_params(axis, labelsize=fontsize)
            ax3d.set_xlabel('X', fontsize=fontsize)
            ax3d.set_ylabel('Y', fontsize=fontsize)
            ax3d.set_zlabel('Value', fontsize=fontsize)
            plt.show()
        finally:
            # Always release the figure, even if plotting fails, so repeated
            # calls do not accumulate open figures.
            plt.close(fig)