TensorFlow InferShape Analysis
Qiao Longfei edited this page Sep 27, 2017
The most direct documentation is at: https://www.tensorflow.org/extend/adding_an_op
op_def_builder defines the builder methods that are chained onto REGISTER_OP, such as Input, Attr, SetShapeFn and Doc. For example:
REGISTER_OP("ResourceApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradShapeFn(c, false /* sparse */);
    })
    .Doc(R"doc(
Update '*var' according to the adagrad scheme.
accum += grad * grad
var -= lr * grad * (1 / sqrt(accum))
var: Should be from a Variable().
accum: Should be from a Variable().
lr: Scaling factor. Must be a scalar.
grad: The gradient.
use_locking: If `True`, updating of the var and accum tensors will be protected
by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
)doc");
This builds an OpRegistrationData data structure:
struct OpRegistrationData {
 public:
  OpRegistrationData() {}
  OpRegistrationData(const OpDef& def) : op_def(def) {}
  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
};
In other words, the registration mainly fills in the two member variables op_def and shape_inference_fn.
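To make the registration step concrete, here is a minimal, self-contained C++ sketch of a fluent builder that fills the same two members. Everything in it is hypothetical: OpDefBuilderSketch, OpRegistrationDataSketch, OpDefSketch and the Shape/ShapeFn aliases are simplified stand-ins, not TensorFlow's real OpDefBuilder, OpDef protobuf or InferenceContext-based shape function. The toy shape rule only checks that lr is a scalar (as the op's doc string requires); it does not reproduce ApplyAdagradShapeFn.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical simplified stand-ins. In TensorFlow, op_def is an OpDef protobuf and
// shape_inference_fn takes a shape_inference::InferenceContext* and returns a Status.
using Shape = std::vector<long long>;
using ShapeFn = std::function<bool(const std::vector<Shape>& inputs,
                                   std::vector<Shape>* outputs)>;

struct OpDefSketch {
  std::string name;
  std::vector<std::string> inputs;  // e.g. "lr: T"
  std::vector<std::string> attrs;   // e.g. "T: numbertype"
  std::string doc;
};

struct OpRegistrationDataSketch {
  OpDefSketch op_def;
  ShapeFn shape_inference_fn;  // left empty if SetShapeFn() is never called
};

// Each chained call records one piece of the op definition and returns *this,
// which is what makes the REGISTER_OP(...).Input(...).Attr(...) style possible.
class OpDefBuilderSketch {
 public:
  explicit OpDefBuilderSketch(std::string name) {
    data_.op_def.name = std::move(name);
  }
  OpDefBuilderSketch& Input(std::string spec) {
    data_.op_def.inputs.push_back(std::move(spec));
    return *this;
  }
  OpDefBuilderSketch& Attr(std::string spec) {
    data_.op_def.attrs.push_back(std::move(spec));
    return *this;
  }
  OpDefBuilderSketch& SetShapeFn(ShapeFn fn) {
    data_.shape_inference_fn = std::move(fn);
    return *this;
  }
  OpDefBuilderSketch& Doc(std::string doc) {
    data_.op_def.doc = std::move(doc);
    return *this;
  }
  // Hands back the two members that the registry keeps for this op.
  OpRegistrationDataSketch Finalize() const {
    assert(!data_.op_def.name.empty() && "an op needs a name");
    return data_;
  }

 private:
  OpRegistrationDataSketch data_;
};

// Mirrors the ResourceApplyAdagrad registration above (doc text shortened).
OpRegistrationDataSketch BuildResourceApplyAdagradSketch() {
  return OpDefBuilderSketch("ResourceApplyAdagrad")
      .Input("var: resource")
      .Input("accum: resource")
      .Input("lr: T")
      .Input("grad: T")
      .Attr("T: numbertype")
      .Attr("use_locking: bool = false")
      .SetShapeFn([](const std::vector<Shape>& inputs, std::vector<Shape>* outputs) {
        // Toy rule (hypothetical): lr, input 2, must be a scalar (rank 0).
        outputs->clear();  // the resource variant produces no outputs
        return inputs.size() == 4 && inputs[2].empty();
      })
      .Doc("Update '*var' according to the adagrad scheme.")
      .Finalize();
}
```

Because every chained call returns `*this`, one expression can describe the whole op; the shape function is only stored next to the definition here and is not invoked until shape inference actually runs.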
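Call chain when shape inference is triggered from Python:
--> RunCppShapeInference (tensorflow/tensorflow/python/framework/cpp_shape_inference.cc)
--> RunCppShapeInferenceImpl (tensorflow/tensorflow/python/framework/cpp_shape_inference.cc)
----> which looks up the op's registration data and fails if no shape function was registered: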
const OpRegistrationData* op_reg_data;
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(node.op(), &op_reg_data));
if (op_reg_data->shape_inference_fn == nullptr) {
  return errors::InvalidArgument(
      "No shape inference function exists for op '", node.op(),
      "', did you forget to define it?");
}
---> OpRegistry::LookUp (tensorflow/tensorflow/core/framework/op.h)
---> LookUp is a virtual method of OpRegistryInterface; OpListOpRegistry, another implementation of that interface, declares it as:
class OpListOpRegistry : public OpRegistryInterface {
 public:
  // Does not take ownership of op_list, *op_list must outlive *this.
  OpListOpRegistry(const OpList* op_list);
  ~OpListOpRegistry() override;
  Status LookUp(const string& op_type_name,
                const OpRegistrationData** op_reg_data) const override;

 private:
  // Values are owned.
  std::unordered_map<string, const OpRegistrationData*> index_;
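};
OpRegistry keeps a global map of the same type, keyed by op name (named registry_ there rather than index_):
std::unordered_map<string, const OpRegistrationData*> registry_;
and each OpRegistrationData mainly holds:
struct OpRegistrationData {
  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
};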
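The lookup path can likewise be sketched in a few lines of C++, under much simpler assumptions than TensorFlow's: errors are plain strings (empty means OK) instead of Status, shapes are vectors of dimensions, and the map stores OpRegistrationData by value instead of owned pointers. OpRegistrySketch, RunShapeInferenceSketch and the "Op type not registered" text are hypothetical; only the "No shape inference function exists" message is taken from the excerpt above. The point is the control flow: look the op name up in a global map, fail if the op is unknown, fail if shape_inference_fn is empty, otherwise call it.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical simplified types; TensorFlow uses Status, an OpDef protobuf and
// shape_inference::InferenceContext instead.
using Shape = std::vector<long long>;
// A shape function fills *outputs and returns "" on success or an error message.
using ShapeFn = std::function<std::string(const std::vector<Shape>& inputs,
                                          std::vector<Shape>* outputs)>;

struct OpRegistrationData {
  std::string op_name;         // stands in for op_def
  ShapeFn shape_inference_fn;  // may be empty, as in TensorFlow
};

class OpRegistrySketch {
 public:
  // One process-wide instance, analogous to OpRegistry::Global().
  static OpRegistrySketch* Global() {
    static OpRegistrySketch* registry = new OpRegistrySketch;
    return registry;
  }

  void Register(OpRegistrationData data) {
    std::string name = data.op_name;  // copy the key before moving data
    assert(registry_.count(name) == 0 && "op registered twice");
    registry_[name] = std::move(data);
  }

  // Returns "" and points *op_reg_data at the entry, or an error message.
  std::string LookUp(const std::string& op_type_name,
                     const OpRegistrationData** op_reg_data) const {
    auto it = registry_.find(op_type_name);
    if (it == registry_.end()) {
      return "Op type not registered '" + op_type_name + "'";
    }
    *op_reg_data = &it->second;  // unordered_map nodes stay put on rehash
    return "";
  }

 private:
  std::unordered_map<std::string, OpRegistrationData> registry_;
};

// Same control flow as the RunCppShapeInferenceImpl excerpt above.
std::string RunShapeInferenceSketch(const std::string& op,
                                    const std::vector<Shape>& inputs,
                                    std::vector<Shape>* outputs) {
  const OpRegistrationData* op_reg_data = nullptr;
  std::string err = OpRegistrySketch::Global()->LookUp(op, &op_reg_data);
  if (!err.empty()) return err;  // plays the role of TF_RETURN_IF_ERROR
  if (!op_reg_data->shape_inference_fn) {
    return "No shape inference function exists for op '" + op +
           "', did you forget to define it?";
  }
  return op_reg_data->shape_inference_fn(inputs, outputs);
}
```

Registering an op with an empty shape function and then calling RunShapeInferenceSketch on it exercises the same error path as the null check in RunCppShapeInferenceImpl.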