Skip to content

Commit dd49999

Browse files
author
SAAS R7 User1
committed
add new visualization
1 parent 66a0de1 commit dd49999

File tree

2 files changed

+147
-74
lines changed

2 files changed

+147
-74
lines changed

examples/Untitled.ipynb

Lines changed: 38 additions & 48 deletions
Large diffs are not rendered by default.

exnn/base.py

Lines changed: 109 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,96 @@ def train_step_init(self, inputs, labels):
160160
@tf.function
161161
def train_step_finetune(self, inputs, labels):
162162
pass
163-
163+
164+
@property
165+
def projection_indices_(self):
166+
"""Return the projection indices.
167+
Returns
168+
-------
169+
projection_indices_ : ndarray of shape (d, )
170+
"""
171+
projection_indices = np.array([])
172+
if self.nfeature_num_ > 0:
173+
active_sim_subnets = [item["indice"] for key, item in self.active_subnets_.items()]
174+
projection_indices = self.proj_layer.proj_weights.numpy()[:, active_sim_subnets]
175+
return projection_indices
176+
177+
@property
178+
def orthogonality_measure_(self):
179+
"""Return the orthogonality measure (the lower, the better).
180+
Returns
181+
-------
182+
orthogonality_measure_ : float scalar
183+
"""
184+
ortho_measure = np.nan
185+
if self.nfeature_num_ > 0:
186+
ortho_measure = np.linalg.norm(np.dot(self.projection_indices_.T,
187+
self.projection_indices_) - np.eye(self.projection_indices_.shape[1]))
188+
if self.projection_indices_.shape[1] > 1:
189+
ortho_measure /= self.projection_indices_.shape[1]
190+
return ortho_measure
191+
192+
@property
193+
def importance_ratios_(self):
194+
"""Return the estimator importance ratios (the higher, the more important the feature).
195+
Returns
196+
-------
197+
importance_ratios_ : ndarray of shape (n_estimators,)
198+
The estimator importances.
199+
"""
200+
importance_ratios_ = {**self.active_subnets_, **self.active_dummy_subnets_}
201+
return importance_ratios_
202+
203+
@property
204+
def active_subnets_(self):
205+
"""
206+
Return the information of sim subnetworks
207+
"""
208+
if self.bn_flag:
209+
beta = self.output_layer.output_weights.numpy()
210+
else:
211+
subnet_norm = [self.subnet_blocks.subnets[i].moving_norm.numpy()[0] for i in range(self.subnet_num)]
212+
categ_norm = [self.categ_blocks.categnets[i].moving_norm.numpy()[0]for i in range(self.cfeature_num_)]
213+
beta = self.output_layer.output_weights.numpy() * np.hstack([subnet_norm, categ_norm]).reshape([-1, 1])
214+
215+
beta = beta * self.output_layer.output_switcher.numpy()
216+
importance_ratio = (np.abs(beta) / np.sum(np.abs(beta))).reshape([-1])
217+
sorted_index = np.argsort(importance_ratio)
218+
active_index = sorted_index[importance_ratio[sorted_index].cumsum() > 0][::-1]
219+
active_subnets = {"Subnet " + str(indice + 1):{"type":"sim_net",
220+
"indice":indice,
221+
"rank":idx,
222+
"beta":self.output_layer.output_weights.numpy()[indice],
223+
"ir":importance_ratio[indice]}
224+
for idx, indice in enumerate(active_index) if indice in range(self.subnet_num)}
225+
226+
return active_subnets
227+
228+
@property
229+
def active_dummy_subnets_(self):
230+
"""
231+
Return the information of active categorical features
232+
"""
233+
if self.bn_flag:
234+
beta = self.output_layer.output_weights.numpy()
235+
else:
236+
subnet_norm = [self.subnet_blocks.subnets[i].moving_norm.numpy()[0] for i in range(self.subnet_num)]
237+
categ_norm = [self.categ_blocks.categnets[i].moving_norm.numpy()[0]for i in range(self.cfeature_num_)]
238+
beta = self.output_layer.output_weights.numpy() * np.hstack([subnet_norm, categ_norm]).reshape([-1, 1])
239+
240+
beta = beta * self.output_layer.output_switcher.numpy()
241+
importance_ratio = (np.abs(beta) / np.sum(np.abs(beta))).reshape([-1])
242+
sorted_index = np.argsort(importance_ratio)
243+
active_index = sorted_index[importance_ratio[sorted_index].cumsum() > 0][::-1]
244+
245+
active_dummy_subnets = {self.cfeature_list_[indice - self.subnet_num]:{"type":"dummy_net",
246+
"indice":indice,
247+
"rank":idx,
248+
"beta":self.output_layer.output_weights.numpy()[indice],
249+
"ir":importance_ratio[indice]}
250+
for idx, indice in enumerate(active_index) if indice in range(self.subnet_num, self.subnet_num + self.cfeature_num_)}
251+
return active_dummy_subnets
252+
164253
def estimate_density(self, x):
165254

166255
density, bins = np.histogram(x, bins=10, density=True)
@@ -383,41 +472,36 @@ def visualize(self, folder="./results/", name="demo", save_png=False, save_eps=F
383472

384473
def visualize_new(self, cols_per_row=3, folder="./results/", name="demo", save_png=False, save_eps=False):
385474

386-
input_size = self.nfeature_num_
387-
coef_index = self.proj_layer.proj_weights.numpy()
388-
active_index, active_categ_index, beta, subnets_scale = self.get_active_subnets()
389-
max_ids = len(active_index) + len(active_categ_index)
390-
391475
input_size = self.nfeature_num_
392476
coef_index = self.proj_layer.proj_weights.numpy()
393477

478+
max_ids = len(self.active_subnets_) + len(self.active_dummy_subnets_)
394479
fig = plt.figure(figsize=(8 * cols_per_row, 4.6 * int(np.ceil(max_ids / cols_per_row))))
395480
outer = gridspec.GridSpec(int(np.ceil(max_ids / cols_per_row)), cols_per_row, wspace=0.15, hspace=0.25)
396481

397-
if coef_index.shape[1] > 0:
398-
xlim_min = - max(np.abs(coef_index.min() - 0.1), np.abs(coef_index.max() + 0.1))
399-
xlim_max = max(np.abs(coef_index.min() - 0.1), np.abs(coef_index.max() + 0.1))
400-
401-
idx = 0
402-
for i, indice in enumerate(active_index):
482+
if self.projection_indices_.shape[1] > 0:
483+
xlim_min = - max(np.abs(self.projection_indices_.min() - 0.1), np.abs(self.projection_indices_.max() + 0.1))
484+
xlim_max = max(np.abs(self.projection_indices_.min() - 0.1), np.abs(self.projection_indices_.max() + 0.1))
485+
for idx, (key, item) in enumerate(self.active_subnets_.items()):
403486

487+
indice = item["indice"]
404488
inner = outer[idx].subgridspec(2, 2, wspace=0.15, height_ratios=[6, 1], width_ratios=[3, 1])
405489
ax1_main = fig.add_subplot(inner[0, 0])
406490
subnet = self.subnet_blocks.subnets[indice]
407491
min_ = self.subnet_input_min[indice]
408492
max_ = self.subnet_input_max[indice]
409493
density, bins = self.subnet_input_density[indice]
410494
xgrid = np.linspace(min_, max_, 1000).reshape([-1, 1])
411-
ygrid = beta[indice] * subnet.__call__(tf.cast(tf.constant(xgrid), tf.float32)).numpy()
495+
ygrid = np.sign(item["beta"]) * subnet.__call__(tf.cast(tf.constant(xgrid), tf.float32)).numpy()
412496

413497
if coef_index[np.argmax(np.abs(coef_index[:, indice])), indice] < 0:
414498
coef_index[:, indice] = - coef_index[:, indice]
415499
xgrid = - xgrid
416500

417501
ax1_main.plot(xgrid, ygrid, color="red")
418502
ax1_main.set_xticklabels([])
419-
ax1_main.set_title("SIM " + str(idx + 1) +
420-
" (IR: " + str(np.round(100 * subnets_scale[indice], 2)) + "%)", fontsize=16)
503+
ax1_main.set_title("SIM " + str(idx + 1) +
504+
" (IR: " + str(np.round(100 * item["ir"], 2)) + "%)", fontsize=16)
421505
fig.add_subplot(ax1_main)
422506

423507
ax1_density = fig.add_subplot(inner[1, 0])
@@ -447,19 +531,19 @@ def visualize_new(self, cols_per_row=3, folder="./results/", name="demo", save_p
447531
ax2.set_xlim(xlim_min, xlim_max)
448532
ax2.set_ylim(-1, len(coef_index.T[indice, :input_size].ravel()))
449533
ax2.axvline(0, linestyle="dotted", color="black")
450-
idx = idx + 1
451534
fig.add_subplot(ax2)
452-
453-
for i, indice in enumerate(active_categ_index):
535+
536+
for idx, (key, item) in enumerate(self.active_dummy_subnets_.items()):
454537

538+
indice = item["indice"]
455539
feature_name = self.cfeature_list_[indice - self.subnet_num]
456540
norm = self.categ_blocks.categnets[indice - self.subnet_num].moving_norm.numpy()
457541
dummy_values = self.dummy_density_[feature_name]["density"]["values"]
458542
dummy_scores = self.dummy_density_[feature_name]["density"]["scores"]
459543
dummy_coef = self.categ_blocks.categnets[indice - self.subnet_num].categ_bias.numpy()
460-
dummy_coef = beta[indice] * dummy_coef[:, 0] / norm
544+
dummy_coef = np.sign(item["beta"]) * dummy_coef[:, 0] / norm
461545

462-
ax_main = fig.add_subplot(outer[len(active_index) + idx])
546+
ax_main = fig.add_subplot(outer[len(self.active_subnets_) + idx])
463547
ax_density = ax_main.twinx()
464548
ax_density.bar(np.arange(len(dummy_values)), dummy_scores, width=0.6)
465549
ax_density.set_ylim(0, dummy_scores.max() * 1.2)
@@ -480,16 +564,15 @@ def visualize_new(self, cols_per_row=3, folder="./results/", name="demo", save_p
480564
" (IR: " + str(np.round(100 * self.importance_ratios_[feature_name]["ir"], 2)) + "%)", fontsize=16)
481565
ax_main.set_zorder(ax_density.get_zorder() + 1)
482566
ax_main.patch.set_visible(False)
483-
idx = idx + 1
484-
567+
485568
plt.show()
486569
if max_ids > 0:
487570
save_path = folder + name
488-
if save_eps:
571+
if save_png:
489572
if not os.path.exists(folder):
490573
os.makedirs(folder)
491-
fig.savefig("%s.eps" % save_path, bbox_inches="tight", dpi=100)
492-
if save_png:
574+
f.savefig("%s.png" % save_path, bbox_inches='tight', dpi=100)
575+
if save_eps:
493576
if not os.path.exists(folder):
494577
os.makedirs(folder)
495-
fig.savefig("%s.png" % save_path, bbox_inches="tight", dpi=100)
578+
f.savefig("%s.eps" % save_path, bbox_inches='tight', dpi=100)

0 commit comments

Comments
 (0)