Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Add Gradient Based Task Scheduler. #29

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ namespace meta_schedule {
*/
class TaskSchedulerNode : public runtime::Object {
public:
/*! \brief The function type of the objective function. */
using FObjectiveFunc = TypedPackedFunc<FloatImm(Array<FloatImm>)>;
/*! \brief The function type of the tag generation function. */
using FTagGenerationFunc = TypedPackedFunc<String(const IRModule&)>;

/*! \brief The tasks to be tuned */
Array<TuneContext> tasks;
/*! \brief The builder of the scheduler. */
Expand Down Expand Up @@ -259,6 +264,36 @@ class TaskScheduler : public runtime::ObjectRef {
Database database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
/*!
 * \brief Create a task scheduler that fetches tasks in a gradient based fashion.
 * \param tasks The tasks to be tuned.
 * \param builder The builder of the scheduler.
 * \param runner The runner of the scheduler.
 * \param database The database of the scheduler.
 * \param alpha The parameter alpha to control gradient computation.
 * \param beta The parameter beta to control gradient computation.
 * \param backward_window_size The parameter to control backward window size.
 * \param seed The random seed.
 * \param task_weights The weights of each task.
 * \param objective_func_name The name of objective function for gradient optimization.
 * \param tag_generation_func_name The name of function to generate similarity tag for workloads.
 * \param cost_model The cost model of the scheduler.
 * \param measure_callbacks The measure callbacks of the scheduler.
 * \return The task scheduler created.
 */
TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks,                        //
                                           Builder builder,                                //
                                           Runner runner,                                  //
                                           Database database,                              //
                                           double alpha,                                   //
                                           double beta,                                    //
                                           int backward_window_size,                       //
                                           support::LinearCongruentialEngine::TRandState seed,  //
                                           Array<FloatImm> task_weights,                   //
                                           String objective_func_name,                     //
                                           String tag_generation_func_name,                //
                                           Optional<CostModel> cost_model,                 //
                                           Optional<Array<MeasureCallback>> measure_callbacks);
/*!
* \brief Create a task scheduler with customized methods on the python-side.
* \param tasks The tasks to be tuned.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/task_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
"""
from .task_scheduler import TaskScheduler, PyTaskScheduler
from .round_robin import RoundRobin
from .gradient_based import GradientBased
132 changes: 132 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/gradient_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Gradient Based Task Scheduler"""
import math

from typing import TYPE_CHECKING, List, Optional, Callable
from tvm._ffi import register_object
from tvm._ffi.registry import register_func

from tvm.ir import IRModule
from tvm.tir import Schedule
from tvm.tir.function import PrimFunc
from ..measure_callback import MeasureCallback
from ..builder import Builder
from ..runner import Runner
from ..database import Database
from ..cost_model import CostModel
from .task_scheduler import TaskScheduler

from .. import _ffi_api

if TYPE_CHECKING:
from ..tune_context import TuneContext


@register_func("meta_schedule.task_scheduler.derive_similarity_tag")
def derive_similarity_tag(mod: IRModule, log_base: float = 1.618) -> str:
    """Derive a similarity tag for the workloads in an IRModule.

    The tag is the concatenation of the ``meta_scheduler_task_scheduler_tag``
    attribute of every global function that carries one, followed by the
    log-scaled FLOP count of the module (computed via
    ``TaskSchedulerFlopCount``), so structurally-similar workloads of similar
    size end up with the same tag.

    Parameters
    ----------
    mod : IRModule
        The workload module to tag.
    log_base : float
        The logarithm base used to bucket the FLOP count. Defaults to 1.618.

    Returns
    -------
    tag : str
        The similarity tag, or an empty string if no function in the module
        is tagged.
    """
    ret = ""
    for var in mod.get_global_vars():
        if "meta_scheduler_task_scheduler_tag" in mod[var].attrs:
            ret += mod[var].attrs.meta_scheduler_task_scheduler_tag + "_"
    if ret:
        # +1 guards against log(0) for workloads with no counted FLOPs.
        flop_count = _ffi_api.TaskSchedulerFlopCount(mod)  # type: ignore # pylint: disable=no-member
        ret += "%d" % int(math.log(flop_count + 1, log_base))
    return ret


@register_object("meta_schedule.GradientBased")
class GradientBased(TaskScheduler):
    """Gradient Based Task Scheduler."""

    def __init__(
        self,
        tasks: List["TuneContext"],
        builder: Builder,
        runner: Runner,
        database: Database,
        *,
        alpha: float = 0.2,
        beta: float = 2.0,
        backward_window_size: int = 3,
        seed: int = -1,
        task_weights: Optional[List[float]] = None,
        objective_func_name: str = "meta_schedule.task_scheduler.objective_func",
        tag_generation_func_name: str = "meta_schedule.task_scheduler.derive_similarity_tag",
        cost_model: Optional[CostModel] = None,
        measure_callbacks: Optional[List[MeasureCallback]] = None,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        tasks : List[TuneContext]
            List of tasks to schedule.
        builder : Builder
            The builder.
        runner : Runner
            The runner.
        database : Database
            The database.
        alpha : float
            The parameter alpha to control gradient computation.
        beta : float
            The parameter beta to control gradient computation.
        backward_window_size : int
            The parameter to control backward window size.
        seed : int
            The random seed.
        task_weights : Optional[List[float]]
            The weights of each task; defaults to a weight of 1.0 per task.
        objective_func_name : str
            The name of objective function for gradient optimization.
        tag_generation_func_name : str
            The name of function to generate similarity tag for workloads.
        cost_model : Optional[CostModel]
            The cost model of the scheduler.
        measure_callbacks : Optional[List[MeasureCallback]]
            The list of measure callbacks of the scheduler.
        """
        if task_weights is None:
            task_weights = [1.0 for _ in tasks]
        assert len(task_weights) == len(
            tasks
        ), "The given task weights should be same length as tasks."
        self.task_weights = task_weights

        # Default objective: weighted sum of the per-task latencies. Registered
        # with `override=True` so that constructing more than one scheduler in
        # the same process does not fail on duplicate registration.
        @register_func("meta_schedule.task_scheduler.objective_func", override=True)
        def weighted_sum(latencies: List[float]) -> float:
            return sum(latency * weight for latency, weight in zip(latencies, self.task_weights))

        self.__init_handle_by_constructor__(
            _ffi_api.TaskSchedulerGradientBased,  # type: ignore # pylint: disable=no-member
            tasks,
            builder,
            runner,
            database,
            alpha,
            beta,
            backward_window_size,
            seed,
            task_weights,
            objective_func_name,
            tag_generation_func_name,
            cost_model,
            measure_callbacks,
        )
199 changes: 1 addition & 198 deletions src/meta_schedule/measure_callback/echo_statistics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,204 +20,6 @@

#include "../utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Count the number of floating-point operations in an IRModule.
 *
 * Walks the body of every PrimFunc in the module, tallying arithmetic
 * operations per data type, then sums the counts of all non-boolean,
 * non-index (i.e. not int32/int64/uint1) types.
 *
 * \param mod The module whose PrimFuncs are to be counted.
 * \return The total FLOP count as a double (loop extents multiply the counts
 *         of their bodies, so the result can be large).
 */
double CountFlop(const IRModule& mod) {
  // Per-dtype operation tally: maps a DataType packed into an int32 key to an op count.
  struct TResult {
    using TTable = std::unordered_map<int32_t, double>;

    TResult() = default;

    explicit TResult(const tvm::DataType& dtype) { Add(dtype); }

    void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; }

    TResult operator+=(const TResult& rhs) {
      for (const auto& kv : rhs.data_) {
        data_[kv.first] += kv.second;
      }
      return *this;
    }

    // Scale every count; used to multiply a loop body's tally by the loop extent.
    TResult operator*=(int64_t rhs) {
      for (auto& kv : data_) {
        kv.second *= rhs;
      }
      return *this;
    }

    // Element-wise max with `rhs`; used to merge mutually-exclusive branches
    // (if/else, select) by taking the more expensive side per dtype.
    TResult MaxWith(const TResult& rhs) {
      for (const auto& kv : rhs.data_) {
        double& v = data_[kv.first];
        if (v < kv.second) {
          v = kv.second;
        }
      }
      return *this;
    }

    // Bit-field mirror of tvm::DataType, used to pack (code, bits, lanes)
    // into a single int32 hash key via type punning.
    struct DType {
      uint8_t code : 8;
      uint8_t bits : 8;
      uint16_t lanes : 16;
    };
    static_assert(sizeof(DType) == 4, "Incorrect size of DType");

    static String Int2Str(int32_t dtype) {
      union {
        DType dst;
        int32_t src;
      } converter;
      converter.src = dtype;
      static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"};
      std::ostringstream os;
      os << type_code_tab[converter.dst.code];
      os << static_cast<int>(converter.dst.bits);
      if (converter.dst.lanes != 1) {
        os << "x" << static_cast<int>(converter.dst.lanes);
      }
      return os.str();
    }

    static int32_t DataType2Int(const tvm::DataType& dtype) {
      union {
        DType src;
        int32_t dst;
      } converter;
      converter.src.code = dtype.code();
      converter.src.bits = dtype.bits();
      converter.src.lanes = dtype.lanes();
      return converter.dst;
    }

    TTable data_;
  };

  // Visitor that accumulates a TResult over a PrimFunc body. Each binary
  // arithmetic/comparison node counts one op of its own dtype plus whatever
  // its operands cost; leaves (vars, imms, loads) cost nothing.
  class FlopCounter : public ExprFunctor<TResult(const PrimExpr& n)>,
                      public StmtFunctor<TResult(const Stmt& n)> {
   public:
    ~FlopCounter() {}

    TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); }
    TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); }

    // Branches are merged conservatively: condition cost + max(then, else).
    TResult VisitStmt_(const IfThenElseNode* branch) override {
      TResult cond = VisitExpr(branch->condition);
      cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case));
      return cond;
    }

    TResult VisitStmt_(const BufferStoreNode* store) override {
      TResult result = VisitExpr(store->value);
      for (const PrimExpr& e : store->indices) {
        result += VisitExpr(e);
      }
      return result;
    }

    TResult VisitStmt_(const SeqStmtNode* seq) override {
      TResult result;
      for (const Stmt& stmt : seq->seq) {
        result += VisitStmt(stmt);
      }
      return result;
    }

    TResult VisitStmt_(const BlockRealizeNode* block) override {
      return VisitStmt(block->block->body);
    }

    TResult VisitStmt_(const BlockNode* block) override {
      TResult result;
      if (block->init.defined()) {
        result += VisitStmt(block->init.value());
      }
      result += VisitStmt(block->body);
      return result;
    }

    // A loop's cost is its body's cost multiplied by its (static) extent.
    // Symbolic extents are not supported and fail the ICHECK below.
    TResult VisitStmt_(const ForNode* loop) override {
      TResult result = VisitStmt(loop->body);
      const auto* int_imm = loop->extent.as<IntImmNode>();
      ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: "
                      << loop->extent->GetTypeKey();
      result *= int_imm->value;
      return result;
    }

// Each binary op node contributes one operation of its result dtype plus
// the cost of both operands.
#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \
  TResult VisitExpr_(const Node* op) final {        \
    TResult result(op->dtype);                      \
    result += VisitExpr(op->a);                     \
    result += VisitExpr(op->b);                     \
    return result;                                  \
  }
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode);
    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode);
#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY
    // Casts are free; only the operand is counted.
    TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); }
    // Leaves: variables, loads, and immediates contribute no operations.
    TResult VisitExpr_(const VarNode* op) override { return TResult(); }
    TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); }
    TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
    TResult VisitExpr_(const IntImmNode* op) override { return TResult(); }
    TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
    TResult VisitExpr_(const NotNode* op) override {
      TResult result(op->dtype);
      result += VisitExpr(op->a);
      return result;
    }
    // Like IfThenElse: condition cost + max of the two value branches.
    TResult VisitExpr_(const SelectNode* op) override {
      TResult cond = VisitExpr(op->condition);
      cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value));
      return cond;
    }
    // Calls themselves are not counted; only the cost of their arguments is.
    TResult VisitExpr_(const CallNode* op) override {
      TResult ret;
      for (const auto& x : op->args) {
        ret += VisitExpr(x);
      }
      return ret;
    }
  };
  FlopCounter counter;
  TResult result;
  for (const auto& kv : mod->functions) {
    const BaseFunc& base_func = kv.second;
    if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
      result += counter.VisitStmt(prim_func->body);
    }
  }
  // Sum everything except int32/int64/uint1 ops: index arithmetic and boolean
  // predicates are excluded so only "real" (floating-point-like) work counts.
  double cnt = 0.0;
  int i32 = TResult::DataType2Int(tvm::DataType::Int(32));
  int i64 = TResult::DataType2Int(tvm::DataType::Int(64));
  int u1 = TResult::DataType2Int(tvm::DataType::UInt(1));
  for (const auto& kv : result.data_) {
    if (kv.first != i32 && kv.first != i64 && kv.first != u1) {
      cnt += kv.second;
    }
  }
  return cnt;
}

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

Expand Down Expand Up @@ -312,6 +114,7 @@ class EchoStatisticsNode : public MeasureCallbackNode {
for (const TuneContext& task : tasks) {
task_info.push_back(TaskInfo(GetTaskName(task, task_id)));
TaskInfo& info = task_info.back();
// todo(@zxybazh): Avoid recount task flops
info.flop = tir::CountFlop(task->mod.value());
++task_id;
}
Expand Down
Loading