Skip to content

Commit

Permalink
Fix prune worker. (PaddlePaddle#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghaoshuang committed Jan 6, 2020
1 parent 790a9ff commit 5312f46
Showing 1 changed file with 23 additions and 34 deletions.
57 changes: 23 additions & 34 deletions paddleslim/prune/prune_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ def prune(self, var, pruned_axis, pruned_idx):
pruned_axis(int): The axis to be pruned of root variable.
pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable.
"""
if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, pruned_idx)

def _visit(self, var, pruned_axis):
key = "_".join([str(self.op.idx()), var.name()])
if pruned_axis not in self.visited:
self.visited[pruned_axis] = {}
if key in self.visited[pruned_axis]:
return
return False
else:
self.visited[pruned_axis][key] = True
self._prune(var, pruned_axis, pruned_idx)
return True

def _prune(self, var, pruned_axis, pruned_idx):
raise NotImplementedError('Abstract method.')
Expand All @@ -83,16 +87,15 @@ def __init__(self, op, pruned_params, visited={}):
super(conv2d, self).__init__(op, pruned_params, visited)

def _prune(self, var, pruned_axis, pruned_idx):
data_format = sef.op.attr("data_format")
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
if var in self.op.inputs("Input"):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[1][key] = True
self._visit(filter_var, 1)
self.pruned_params.append((filter_var, 1, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
Expand All @@ -110,16 +113,14 @@ def _prune(self, var, pruned_axis, pruned_idx):
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
key = "_".join([str(self.op.idx()), output_var.name()])
self.visited[channel_axis][key] = True
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)

elif pruned_axis == 1:
input_var = self.op.inputs("Input")[0]
key = "_".join([str(self.op.idx()), input_var.name()])
self.visited[channel_axis][key] = True
self._visit(input_var, channel_axis)
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
Expand All @@ -128,8 +129,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
pruned_axis, var.name())

filter_var = self.op.inputs("Filter")[0]
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
self._visit(filter_var, 0)

self.pruned_params.append((filter_var, 0, pruned_idx))

Expand Down Expand Up @@ -158,8 +158,7 @@ def _prune(self, var, pruned_axis, pruned_idx):

if var in self.op.outputs("Y"):
in_var = self.op.inputs("X")[0]
key = "_".join([str(self.op.idx()), in_var.name()])
self.visited[pruned_axis][key] = True
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
Expand All @@ -171,8 +170,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
self.pruned_params.append((param_var, 0, pruned_idx))

out_var = self.op.outputs("Y")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
Expand Down Expand Up @@ -214,8 +212,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
self._prune_op(op, in_var, pruned_axis, pruned_idx)

out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
Expand Down Expand Up @@ -253,8 +250,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
self._prune_op(op, in_var, pruned_axis, pruned_idx)

out_var = self.op.outputs(self.output_name)[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
Expand Down Expand Up @@ -317,8 +313,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
Expand Down Expand Up @@ -363,8 +358,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
start += v.shape()[pruned_axis]

out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, idx, visited={})
Expand All @@ -373,8 +367,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
for op in v.inputs():
self._prune_op(op, v, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
Expand All @@ -386,7 +379,7 @@ def __init__(self, op, pruned_params, visited={}):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited)

def _prune(self, var, pruned_axis, pruned_idx):
data_format = sef.op.attr("data_format")
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
Expand All @@ -396,8 +389,7 @@ def _prune(self, var, pruned_axis, pruned_idx):

filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
self._visit(filter_var, 0)

new_groups = filter_var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups)
Expand Down Expand Up @@ -425,8 +417,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
self._prune_op(op, var, 0, pruned_idx)

output_var = self.op.outputs("Output")[0]
key = "_".join([str(self.op.idx()), output_var.name()])
self.visited[channel_axis][key] = True
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
Expand All @@ -436,8 +427,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
assert pruned_axis == channel_axis
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
self._visit(filter_var, 0)

new_groups = filter_var.shape()[0] - len(pruned_idx)
op.set_attr("groups", new_groups)
Expand All @@ -450,8 +440,7 @@ def _prune(self, var, pruned_axis, pruned_idx):
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))

in_var = self.op.inputs("Input")[0]
key = "_".join([str(self.op.idx()), in_var.name()])
self.visited[channel_axis][key] = True
self._visit(in_var, channel_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, channel_axis, pruned_idx)
Expand Down

0 comments on commit 5312f46

Please sign in to comment.