forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay][AutoTVM] Relay op strategy (apache#4644)
* relay op strategy fix lint bitpack strategy bitserial_dense (#6) * update strategy * address comments fix a few topi test Dense strategy (#5) * dense * add biforst; remove comments * address comment Refactor x86 conv2d_NCHWc (#4) * Refactor x86 conv2d * Add x86 depthwise_conv2d_NCHWc * Add back topi x86 conv2d_nchw * Merge x86 conv2d_nchw and conv2d_NCHWc * Minor fix for x86 conv2d fix more strategy Add x86 conv2d_NCHWc_int8 strategy (#8) * Add x86 conv2d_NCHWc_int8 strategy * Remove contrib_conv2d_nchwc_int8 * Fix generic conv2d_NCHWc for int8 * Fix topi arm_cpu conv2d_NCHWc_int8 update x86 conv2d enable specify relay ops to be tuned for autotvm add cuda conv2d strategy add conv2d strategy for rocm add conv2d strategy for hls add conv2d strategy for arm cpu add conv2d strategy for mali add conv2d strategy for bifrost add conv2d strategy for intel graphics clean up and fix lint remove template keys from autotvm remove 2 in the func name address comments fix * fix bugs * lint * address comments * add name to op implement * Modify topi tests (#9) * Add pooling, reorg, softmax and vision * Add lrn * fix topi test * fix more topi test * lint * address comments * x * fix more tests & bugs * Modify more tests (#10) * Modify tests for bitserial_conv2d, bitserial_dense, bitserial_conv2d_rasp and bnn * Minor fix * More minor fix * fix more test * try to update vta using strategy * fix cpptest * x * fix rebase err * Fix two tests (#11) * change autotvm log format * lint * minor fix * try fix vta test * fix rebase err * tweak * tmp hack for vta pass * fix tutorial * fix * fix more tutorials * fix vta tutorial * minor * address comments * fix * address comments * fix cpptest * fix docs * change data structure name and api * address comments * lint * fix rebase err * updates * fix winograd test * fix doc * rebase * upgrade tophub version number * fix bug * re-enable vta tsim test after tophub is upgraded * fix vta test to use the correct args so the config can be found in tophub Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
- Loading branch information
Showing
270 changed files
with
8,466 additions
and
7,067 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/relay/op_strategy.h | ||
* \brief The Relay operator Strategy and related data structure. | ||
*/ | ||
|
||
#ifndef TVM_RELAY_OP_STRATEGY_H_ | ||
#define TVM_RELAY_OP_STRATEGY_H_ | ||
|
||
#include <tvm/te/tensor.h> | ||
#include <tvm/te/schedule.h> | ||
#include <tvm/relay/expr.h> | ||
#include <tvm/relay/op_attr_types.h> | ||
#include <tvm/target/target.h> | ||
#include <string> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
/*! | ||
* \brief Operator implementation that includes compute and schedule function. | ||
*/ | ||
class OpImplementationNode : public Object { | ||
public: | ||
/*! \brief Compute function */ | ||
FTVMCompute fcompute; | ||
/*! \brief Schedule function */ | ||
FTVMSchedule fschedule; | ||
/*! \brief Name of the implementation */ | ||
std::string name; | ||
/*! \brief Priority level */ | ||
int plevel; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("name", &name); | ||
v->Visit("plevel", &plevel); | ||
} | ||
|
||
static constexpr const char* _type_key = "relay.OpImplementation"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object); | ||
}; | ||
|
||
/*! | ||
* \brief Operator implementation class. | ||
*/ | ||
class OpImplementation : public ObjectRef { | ||
public: | ||
/*! | ||
* \brief Invoke the operator compute function. | ||
* \param attrs The attribute of the primitive | ||
* \param inputs The input tensors. | ||
* \param out_type The output type information. | ||
* \return The output compute description of the operator. | ||
*/ | ||
TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs, | ||
const Array<te::Tensor>& inputs, | ||
const Type& out_type); | ||
/*! | ||
* \brief Build the computation schedule. | ||
* \param attrs The attribute of the node. | ||
* \param outs The output tensors. | ||
* \param target The build target. | ||
* \return The computation schedule. | ||
*/ | ||
TVM_DLL te::Schedule Schedule(const Attrs& attrs, | ||
const Array<te::Tensor>& outs, | ||
const Target& target); | ||
|
||
TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode); | ||
}; | ||
|
||
/*! | ||
* \brief Specialized implementations for operators under certain conditions. | ||
*/ | ||
class OpSpecializationNode : public Object { | ||
public: | ||
/*! \brief List of implementations. */ | ||
Array<OpImplementation> implementations; | ||
/*! \brief Condition to enable the specialization. | ||
* Could be undefined to represent generic case. */ | ||
te::SpecializedCondition condition; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("condition", &condition); | ||
v->Visit("implementations", &implementations); | ||
} | ||
|
||
static constexpr const char* _type_key = "relay.OpSpecialization"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode); | ||
}; | ||
|
||
/*! | ||
* \brief Operator specialization class. | ||
*/ | ||
class OpSpecialization : public ObjectRef { | ||
public: | ||
/*! | ||
* \brief Add an implementation. | ||
* \param fcompute Compute function | ||
* \param fschedule Schedule function | ||
* \param name Name of the implementation | ||
* \param plevel Priority level of the implementation | ||
*/ | ||
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, | ||
std::string name, int plevel); | ||
|
||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode); | ||
}; | ||
|
||
/*! | ||
* \brief Operator strategy to choose implementation. | ||
*/ | ||
class OpStrategyNode : public Object { | ||
public: | ||
/*! \brief List of operator specializations. */ | ||
Array<OpSpecialization> specializations; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("specializations", &specializations); | ||
} | ||
|
||
static constexpr const char* _type_key = "relay.OpStrategy"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode); | ||
}; | ||
|
||
/*! | ||
* \brief Operator strategy class. | ||
*/ | ||
class OpStrategy : public ObjectRef { | ||
public: | ||
/*! | ||
* \brief Add an implementation. | ||
* \param fcompute Compute function | ||
* \param fschedule Schedule function | ||
* \param name Name of the implementation | ||
* \param plevel Priority level of the implementation | ||
*/ | ||
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, | ||
std::string name, int plevel); | ||
|
||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode); | ||
}; | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
#endif // TVM_RELAY_OP_STRATEGY_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.