diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py index 16edbd992b022..4651e17e03e1d 100644 --- a/paddleslim/prune/prune_walker.py +++ b/paddleslim/prune/prune_walker.py @@ -49,14 +49,18 @@ def prune(self, var, pruned_axis, pruned_idx): pruned_axis(int): The axis to be pruned of root variable. pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable. """ + if self._visit(var, pruned_axis): + self._prune(var, pruned_axis, pruned_idx) + + def _visit(self, var, pruned_axis): key = "_".join([str(self.op.idx()), var.name()]) if pruned_axis not in self.visited: self.visited[pruned_axis] = {} if key in self.visited[pruned_axis]: - return + return False else: self.visited[pruned_axis][key] = True - self._prune(var, pruned_axis, pruned_idx) + return True def _prune(self, var, pruned_axis, pruned_idx): raise NotImplementedError('Abstract method.') @@ -83,7 +87,7 @@ def __init__(self, op, pruned_params, visited={}): super(conv2d, self).__init__(op, pruned_params, visited) def _prune(self, var, pruned_axis, pruned_idx): - data_format = sef.op.attr("data_format") + data_format = self.op.attr("data_format") channel_axis = 1 if data_format == "NHWC": channel_axis = 3 @@ -91,8 +95,7 @@ def _prune(self, var, pruned_axis, pruned_idx): assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format( pruned_axis, var.name()) filter_var = self.op.inputs("Filter")[0] - key = "_".join([str(self.op.idx()), filter_var.name()]) - self.visited[1][key] = True + self._visit(filter_var, 1) self.pruned_params.append((filter_var, 1, pruned_idx)) for op in filter_var.outputs(): self._prune_op(op, filter_var, 1, pruned_idx) @@ -110,16 +113,14 @@ def _prune(self, var, pruned_axis, pruned_idx): self.pruned_params.append( (self.op.inputs("Bias"), channel_axis, pruned_idx)) output_var = self.op.outputs("Output")[0] - key = "_".join([str(self.op.idx()), output_var.name()]) - self.visited[channel_axis][key] = True + self._visit(output_var, channel_axis) next_ops = output_var.outputs() for op in next_ops: self._prune_op(op, output_var, channel_axis, pruned_idx) elif pruned_axis == 1: input_var = self.op.inputs("Input")[0] - key = "_".join([str(self.op.idx()), input_var.name()]) - self.visited[channel_axis][key] = True + self._visit(input_var, channel_axis) pre_ops = input_var.inputs() for op in pre_ops: self._prune_op(op, input_var, channel_axis, pruned_idx) @@ -128,8 +129,7 @@ def _prune(self, var, pruned_axis, pruned_idx): pruned_axis, var.name()) filter_var = self.op.inputs("Filter")[0] - key = "_".join([str(self.op.idx()), filter_var.name()]) - self.visited[0][key] = True + self._visit(filter_var, 0) self.pruned_params.append((filter_var, 0, pruned_idx)) @@ -158,8 +158,7 @@ def _prune(self, var, pruned_axis, pruned_idx): if var in self.op.outputs("Y"): in_var = self.op.inputs("X")[0] - key = "_".join([str(self.op.idx()), in_var.name()]) - self.visited[pruned_axis][key] = True + self._visit(in_var, pruned_axis) pre_ops = in_var.inputs() for op in pre_ops: self._prune_op(op, in_var, pruned_axis, pruned_idx) @@ -171,8 +170,7 @@ def _prune(self, var, pruned_axis, pruned_idx): self.pruned_params.append((param_var, 0, pruned_idx)) out_var = self.op.outputs("Y")[0] - key = "_".join([str(self.op.idx()), out_var.name()]) - self.visited[pruned_axis][key] = True + self._visit(out_var, pruned_axis) next_ops = out_var.outputs() for op in next_ops: self._prune_op(op, out_var, pruned_axis, pruned_idx) @@ -214,8 +212,7 @@ def _prune(self, var, pruned_axis, pruned_idx): self._prune_op(op, in_var, pruned_axis, pruned_idx) out_var = self.op.outputs("Out")[0] - key = "_".join([str(self.op.idx()), out_var.name()]) - self.visited[pruned_axis][key] = True + self._visit(out_var, pruned_axis) next_ops = out_var.outputs() for op in next_ops: self._prune_op(op, out_var, pruned_axis, pruned_idx) @@ -253,8 +250,7 @@ def _prune(self, var, pruned_axis, pruned_idx): self._prune_op(op, in_var, pruned_axis, pruned_idx) out_var = self.op.outputs(self.output_name)[0] - key = "_".join([str(self.op.idx()), out_var.name()]) - self.visited[pruned_axis][key] = True + self._visit(out_var, pruned_axis) next_ops = out_var.outputs() for op in next_ops: self._prune_op(op, out_var, pruned_axis, pruned_idx) @@ -317,8 +313,7 @@ def _prune(self, var, pruned_axis, pruned_idx): for op in pre_ops: self._prune_op(op, in_var, pruned_axis, pruned_idx) out_var = self.op.outputs("Out")[0] - key = "_".join([str(self.op.idx()), out_var.name()]) - self.visited[pruned_axis][key] = True + self._visit(out_var, pruned_axis) next_ops = out_var.outputs() for op in next_ops: self._prune_op(op, out_var, pruned_axis, pruned_idx) @@ -363,8 +358,7 @@ def _prune(self, var, pruned_axis, pruned_idx): start += v.shape()[pruned_axis] out_var = self.op.outputs("Out")[0] - key = "_".join([str(self.op.idx()), out_var.name()]) - self.visited[pruned_axis][key] = True + self._visit(out_var, pruned_axis) next_ops = out_var.outputs() for op in next_ops: self._prune_op(op, out_var, pruned_axis, idx, visited={}) @@ -373,8 +367,7 @@ def _prune(self, var, pruned_axis, pruned_idx): for op in v.inputs(): self._prune_op(op, v, pruned_axis, pruned_idx) out_var = self.op.outputs("Out")[0] - key = "_".join([str(self.op.idx()), out_var.name()]) - self.visited[pruned_axis][key] = True + self._visit(out_var, pruned_axis) next_ops = out_var.outputs() for op in next_ops: self._prune_op(op, out_var, pruned_axis, pruned_idx) @@ -386,7 +379,7 @@ def __init__(self, op, pruned_params, visited={}): super(depthwise_conv2d, self).__init__(op, pruned_params, visited) def _prune(self, var, pruned_axis, pruned_idx): - data_format = sef.op.attr("data_format") + data_format = self.op.attr("data_format") channel_axis = 1 if data_format == "NHWC": channel_axis = 3 @@ -396,8 +389,7 @@ def _prune(self, var, pruned_axis, pruned_idx): filter_var = self.op.inputs("Filter")[0] self.pruned_params.append((filter_var, 0, pruned_idx)) - key = "_".join([str(self.op.idx()), filter_var.name()]) - self.visited[0][key] = True + self._visit(filter_var, 0) new_groups = filter_var.shape()[0] - len(pruned_idx) self.op.set_attr("groups", new_groups) @@ -425,8 +417,7 @@ def _prune(self, var, pruned_axis, pruned_idx): self._prune_op(op, var, 0, pruned_idx) output_var = self.op.outputs("Output")[0] - key = "_".join([str(self.op.idx()), output_var.name()]) - self.visited[channel_axis][key] = True + self._visit(output_var, channel_axis) next_ops = output_var.outputs() for op in next_ops: self._prune_op(op, output_var, channel_axis, pruned_idx) @@ -436,8 +427,7 @@ def _prune(self, var, pruned_axis, pruned_idx): assert pruned_axis == channel_axis filter_var = self.op.inputs("Filter")[0] self.pruned_params.append((filter_var, 0, pruned_idx)) - key = "_".join([str(self.op.idx()), filter_var.name()]) - self.visited[0][key] = True + self._visit(filter_var, 0) new_groups = filter_var.shape()[0] - len(pruned_idx) op.set_attr("groups", new_groups) @@ -450,8 +440,7 @@ def _prune(self, var, pruned_axis, pruned_idx): (self.op.inputs("Bias")[0], channel_axis, pruned_idx)) in_var = self.op.inputs("Input")[0] - key = "_".join([str(self.op.idx()), in_var.name()]) - self.visited[channel_axis][key] = True + self._visit(in_var, channel_axis) pre_ops = in_var.inputs() for op in pre_ops: self._prune_op(op, in_var, channel_axis, pruned_idx)