Skip to content

Check and only check the output varibles specified by self.outputs #4175

Closed
@Xreki

Description

@Xreki

很多Operator中都会用到一些中间输出变量,它们被声明为AsIntermediate()类型。根据输入的情况不同,有些中间输出变量实际不会被用到,比如在FCOp中:

  • 当输入变量XW包含多组数据时,计算的过程为
    • MulOut[i] = X[i] * W[i]
    • SumOut = MulOut[0] + ... + MulOut[n-1]
      此时需要用到MulOutSumOut两个中间输出变量。
  • 当输入变量XW只包含一组数据时,计算的过程为
    • MulOut[0] = X[0] * W[0]
      此时,我们不再需要使用SumOut

因此在单测中,这些不被用到的中间输出变量应该允许不指定,并且不进行检查

单测中,Op的输入输出都通过self.inputsself.outputs指定。在创建Op输入输出变量时,会遍历当前OpProto中所指定的输入输出变量,如果这些变量被self.inputsself.outputs指定了,则创建对应的Variable

    for out_name, out_dup in Operator.get_op_outputs(op_type):
        if out_name in outputs:
            kwargs[out_name] = []
            if out_dup:
                sub_out = outputs[out_name]
                for sub_out_name, _ in sub_out:
                    var = scope.new_var(sub_out_name)
                    kwargs[out_name].append(sub_out_name)
            else:
                var = scope.new_var(out_name)
                kwargs[out_name].append(out_name)

因此,在取输出变量校对结果时,也需要检查输出变量是否被self.outputs,因为没有被self.outputs指定的变量,很有可能是不需要的中间输出变量。至于Op计算实际需要,但是没有在self.outputs中指定的变量,则应由Op实现的C++代码实现检查和报错。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions