Skip to content

Commit

Permalink
Update csgd_prune.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DingXiaoH authored Jun 27, 2020
1 parent a2998cb commit 53146b4
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions csgd/csgd_prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,20 @@ def handle_vecs(key_name):
num_filters = k_value.shape[0]
fc_neurons_per_conv_kernel = follow_kernel_value.shape[1] // num_filters
print('{} filters, {} neurons per kernel'.format(num_filters, fc_neurons_per_conv_kernel))
base = np.arange(0, fc_neurons_per_conv_kernel * num_filters, num_filters)

for clst in clusters:
if len(clst) == 1:
continue
for i in clst[1:]:
fc_idx_to_delete.append(base + i)
fc_idx_to_delete.append(np.arange(i * fc_neurons_per_conv_kernel,
(i+1) * fc_neurons_per_conv_kernel))
to_concat = []
for i in clst:
corresponding_neurons_idx = base + i
corresponding_neurons_idx = np.arange(i * fc_neurons_per_conv_kernel,
(i+1) * fc_neurons_per_conv_kernel)
to_concat.append(np.expand_dims(follow_kernel_value[:, corresponding_neurons_idx], axis=0))
summed = np.sum(np.concatenate(to_concat, axis=0), axis=0)
reserved_idx = base + clst[0]
reserved_idx = np.arange(clst[0] * fc_neurons_per_conv_kernel, (clst[0]+1) * fc_neurons_per_conv_kernel)
follow_kernel_value[:, reserved_idx] = summed
if len(fc_idx_to_delete) > 0:
follow_kernel_value = delete_or_keep(follow_kernel_value, np.concatenate(fc_idx_to_delete, axis=0), axis=1)
Expand All @@ -115,4 +117,4 @@ def handle_vecs(key_name):
result['deps'] = new_deps

print('save {} values to {} after pruning'.format(len(result), save_file))
save_hdf5(result, save_file)
save_hdf5(result, save_file)

0 comments on commit 53146b4

Please sign in to comment.