
TensorFlow InferShape Analysis

Qiao Longfei edited this page Sep 27, 2017 · 5 revisions

The most direct documentation is at: https://www.tensorflow.org/extend/adding_an_op

The methods chained onto REGISTER_OP (Input, Attr, SetShapeFn, Doc, ...) are defined in op_def_builder. For example:

REGISTER_OP("ResourceApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradShapeFn(c, false /* sparse */);
    })
    .Doc(R"doc(
Update '*var' according to the adagrad scheme.

accum += grad * grad
var -= lr * grad * (1 / sqrt(accum))

var: Should be from a Variable().
accum: Should be from a Variable().
lr: Scaling factor. Must be a scalar.
grad: The gradient.
use_locking: If `True`, updating of the var and accum tensors will be protected
  by a lock; otherwise the behavior is undefined, but may exhibit less
  contention.
)doc");
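SetShapeFn attaches the function that infers output shapes from input shapes. As a rough, dependency-free sketch of what such a function has to check for this op, based only on the Doc text above: var and accum must agree, lr must be a scalar, and, because the update is element-wise, grad must match var. Shape, kUnknownDim, MergeShapes and ToyAdagradShapeFn are hypothetical stand-ins for TensorFlow's InferenceContext machinery; -1 marks an unknown dimension, and unknown rank is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical toy shape: one entry per dimension; -1 marks an unknown size.
using Shape = std::vector<int64_t>;
constexpr int64_t kUnknownDim = -1;

// Merges two shapes that must describe the same tensor.
// Returns std::nullopt if the ranks or any known dimensions disagree.
std::optional<Shape> MergeShapes(const Shape& a, const Shape& b) {
  if (a.size() != b.size()) return std::nullopt;
  Shape out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] == kUnknownDim) {
      out[i] = b[i];  // take whatever b knows
    } else if (b[i] == kUnknownDim || b[i] == a[i]) {
      out[i] = a[i];
    } else {
      return std::nullopt;  // conflicting known sizes
    }
  }
  return out;
}

// Toy version of the constraints implied by the ResourceApplyAdagrad doc:
// var and accum agree, lr is a scalar (rank 0), grad matches var.
// Returns the merged var shape, or std::nullopt on a shape error.
std::optional<Shape> ToyAdagradShapeFn(const Shape& var, const Shape& accum,
                                       const Shape& lr, const Shape& grad) {
  std::optional<Shape> s = MergeShapes(var, accum);
  if (!s) return std::nullopt;
  if (!lr.empty()) return std::nullopt;  // lr must be a scalar
  return MergeShapes(*s, grad);
}
```

Merging rather than comparing lets partially known shapes (for example `{3, -1}` and `{-1, 4}`) refine each other into `{3, 4}`.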

These calls build an OpRegistrationData structure:

struct OpRegistrationData {
 public:
  OpRegistrationData() {}
  OpRegistrationData(const OpDef& def) : op_def(def) {}

  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
};

Registration mainly fills in its two member variables, op_def and shape_inference_fn.
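How the chained calls end up in those two fields can be sketched with a simplified fluent builder. ToyOpDef, ToyShapeFn, ToyRegistrationData and ToyOpBuilder are hypothetical: the real OpDefBuilder parses specs such as "lr: T" into an OpDef protobuf and validates them, while this sketch only records the raw spec strings and the shape function.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the OpDef proto: name plus raw spec strings.
struct ToyOpDef {
  std::string name;
  std::vector<std::string> inputs;  // e.g. "lr: T"
  std::vector<std::string> attrs;   // e.g. "use_locking: bool = false"
};

// Hypothetical shape function type: takes input ranks, returns success.
using ToyShapeFn = std::function<bool(const std::vector<int>& input_ranks)>;

// Mirrors the two members of OpRegistrationData.
struct ToyRegistrationData {
  ToyOpDef op_def;
  ToyShapeFn shape_inference_fn;
};

// Fluent builder in the style of REGISTER_OP("...").Input(...).Attr(...).
class ToyOpBuilder {
 public:
  explicit ToyOpBuilder(std::string name) {
    data_.op_def.name = std::move(name);
  }
  ToyOpBuilder& Input(std::string spec) {
    data_.op_def.inputs.push_back(std::move(spec));
    return *this;  // returning *this is what makes chaining work
  }
  ToyOpBuilder& Attr(std::string spec) {
    data_.op_def.attrs.push_back(std::move(spec));
    return *this;
  }
  ToyOpBuilder& SetShapeFn(ToyShapeFn fn) {
    data_.shape_inference_fn = std::move(fn);
    return *this;
  }
  // Produces the registration record: op_def plus shape_inference_fn.
  ToyRegistrationData Finalize() const { return data_; }

 private:
  ToyRegistrationData data_;
};
```

For example, `ToyOpBuilder("ResourceApplyAdagrad").Input("var: resource").Attr("T: numbertype").SetShapeFn(fn).Finalize()` yields a record whose op_def lists one input and one attr and whose shape_inference_fn is `fn`.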

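Call path

Python call
--> call_cpp_shape_fn
--> _call_cpp_shape_fn_impl
--> RunCppShapeInference (tensorflow/tensorflow/python/framework/cpp_shape_inference.cc)
--> RunCppShapeInferenceImpl (tensorflow/tensorflow/python/framework/cpp_shape_inference.cc)
----> inside RunCppShapeInferenceImpl, the op's registration data is looked up and its shape function is checked: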

  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(node.op(), &op_reg_data));

  if (op_reg_data->shape_inference_fn == nullptr) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node.op(),
        "', did you forget to define it?");
  }
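The guard above can be reproduced in a self-contained sketch: look the op up by name, fail if it is unknown, and fail with the same message if it has no shape function. ToyStatus, ToyShapeFn, ToyRegistrationData and CheckShapeFn are hypothetical simplifications of Status, OpShapeInferenceFn, OpRegistrationData and the code in RunCppShapeInferenceImpl; a plain std::map stands in for OpRegistry::Global(), and the "not registered" message is a toy one.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

// Hypothetical minimal status: an empty message means OK.
struct ToyStatus {
  std::string error;
  bool ok() const { return error.empty(); }
};

// Hypothetical shape function type, argument-less for brevity.
using ToyShapeFn = std::function<bool()>;

struct ToyRegistrationData {
  ToyShapeFn shape_inference_fn;  // may be empty, like a nullptr fn
};

// Mirrors the two checks: registry lookup first, then the null check.
ToyStatus CheckShapeFn(
    const std::map<std::string, ToyRegistrationData>& registry,
    const std::string& op) {
  auto it = registry.find(op);
  if (it == registry.end()) {
    return {"Op type not registered '" + op + "'"};  // toy message
  }
  if (!it->second.shape_inference_fn) {
    return {"No shape inference function exists for op '" + op +
            "', did you forget to define it?"};
  }
  return {};  // OK
}
```

Note that the two failures are distinct: an unknown op fails at the lookup, while a registered op without SetShapeFn fails at the null check.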

---> OpRegistry::LookUp(tensorflow/tensorflow/core/framework/op.h)

---> the registry lookup is a hash-map lookup keyed by op type name, as in OpListOpRegistry:

class OpListOpRegistry : public OpRegistryInterface {
 public:
  // Does not take ownership of op_list, *op_list must outlive *this.
  OpListOpRegistry(const OpList* op_list);
  ~OpListOpRegistry() override;
  Status LookUp(const string& op_type_name,
                const OpRegistrationData** op_reg_data) const override;

 private:
  // Values are owned.
  std::unordered_map<string, const OpRegistrationData*> index_;
};
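A minimal sketch of the same layout, assuming a toy ToyOpDef in place of the OpDef proto and a bool in place of Status. The constructor indexes a list of op definitions by name, the map owns its values (here through std::unique_ptr, where the original comment only says "Values are owned"), and LookUp hands out a const pointer that stays owned by the registry. All names are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the OpDef proto.
struct ToyOpDef {
  std::string name;
};

struct ToyRegistrationData {
  explicit ToyRegistrationData(const ToyOpDef& def) : op_def(def) {}
  ToyOpDef op_def;
};

class ToyOpListRegistry {
 public:
  // Builds one owned ToyRegistrationData per op; op_list is only read here,
  // so unlike the original it need not outlive the registry.
  explicit ToyOpListRegistry(const std::vector<ToyOpDef>& op_list) {
    for (const ToyOpDef& def : op_list) {
      index_[def.name] = std::make_unique<ToyRegistrationData>(def);
    }
  }

  // On success sets *op_reg_data to a pointer owned by this registry.
  bool LookUp(const std::string& op_type_name,
              const ToyRegistrationData** op_reg_data) const {
    auto it = index_.find(op_type_name);
    if (it == index_.end()) return false;
    *op_reg_data = it->second.get();
    return true;
  }

 private:
  // Values are owned (unique_ptr instead of a manual delete in a destructor).
  std::unordered_map<std::string, std::unique_ptr<const ToyRegistrationData>>
      index_;
};
```

Using std::unique_ptr makes the ownership comment enforceable by the type system and removes the need for a hand-written destructor.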

Conclusion

OpRegistry (the global registry returned by OpRegistry::Global()) holds a map of the form:

std::unordered_map<string, const OpRegistrationData*> index_;

and each OpRegistrationData mainly stores:

struct OpRegistrationData {
  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
};
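Putting the conclusion together, here is a minimal end-to-end sketch: a process-wide map from op name to registration data, filled by static registrar objects (the role REGISTER_OP plays), and a lookup that runs the stored shape function. ToyRegistry, ToyRegistrar, InferOutputRank and the rank-only shape function are hypothetical; TensorFlow's real registry does more work, such as locking and validating each OpDef.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical shape fn: input ranks in, output rank out (-1 means error).
using ToyShapeFn = std::function<int(const std::vector<int>& input_ranks)>;

struct ToyRegistrationData {
  std::string op_name;  // stand-in for OpDef op_def
  ToyShapeFn shape_inference_fn;
};

// Process-wide registry, created on first use (function-local static),
// which avoids static-initialization-order problems between registrars.
std::unordered_map<std::string, ToyRegistrationData>& ToyRegistry() {
  static std::unordered_map<std::string, ToyRegistrationData> registry;
  return registry;
}

// Constructing a static instance of this fills the registry before main().
struct ToyRegistrar {
  ToyRegistrar(const std::string& name, ToyShapeFn fn) {
    ToyRegistry()[name] = ToyRegistrationData{name, std::move(fn)};
  }
};

// Namespace-scope registration, analogous to a REGISTER_OP statement.
static ToyRegistrar register_toy_adagrad(
    "ToyApplyAdagrad", [](const std::vector<int>& ranks) {
      // Inputs: var, accum, lr, grad. lr must be a scalar and the others
      // must share var's rank; the result is var's rank.
      if (ranks.size() != 4 || ranks[2] != 0 || ranks[1] != ranks[0] ||
          ranks[3] != ranks[0]) {
        return -1;
      }
      return ranks[0];
    });

// Looks the op up and runs its shape function; -1 if unknown or no fn.
int InferOutputRank(const std::string& op, const std::vector<int>& ranks) {
  auto it = ToyRegistry().find(op);
  if (it == ToyRegistry().end() || !it->second.shape_inference_fn) return -1;
  return it->second.shape_inference_fn(ranks);
}
```

For example, `InferOutputRank("ToyApplyAdagrad", {2, 2, 0, 2})` returns 2, while passing a rank-1 lr returns -1.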