-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Meta Schedule][M3a] TuneContext (#9053)
* Add TuneContext class. Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> * Add tune context test. * Add meta_schedule to cmake. * Add type. * Rebase. * Disable MyPy for ethosu. * Add new line. * Remove duplicate line. * Minor fix. * Add comments. Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
- Loading branch information
1 parent
9258b96
commit cd15b79
Showing
7 changed files
with
306 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Meta Schedule tuning context.""" | ||
|
||
from typing import Optional | ||
|
||
from tvm import IRModule | ||
from tvm.runtime import Object | ||
from tvm.target import Target | ||
from tvm.meta_schedule.utils import cpu_count | ||
from tvm._ffi import register_object | ||
|
||
from . import _ffi_api | ||
|
||
|
||
@register_object("meta_schedule.TuneContext") | ||
class TuneContext(Object): | ||
""" | ||
The tune context class is designed to contain all resources for a tuning task. | ||
Different tuning tasks are separated in different TuneContext classes, but different classes in | ||
the same task can interact with each other through tune context. Most classes have a function | ||
to initialize with a tune context. | ||
Parameters | ||
---------- | ||
mod : Optional[IRModule] = None | ||
The workload to be optimized. | ||
target : Optional[Target] = None | ||
The target to be optimized for. | ||
task_name : Optional[str] = None | ||
The name of the tuning task. | ||
rand_state : int = -1 | ||
The random state. | ||
Need to be in integer in [1, 2^31-1], -1 means using random number. | ||
num_threads : int = None | ||
The number of threads to be used, None means using the logical cpu count. | ||
Note | ||
---- | ||
In most cases, mod and target should be available in the tuning context. They are "Optional" | ||
because we allow the user to customize the tuning context, along with other classes, sometimes | ||
without mod and target. E.g., we can have a stand alone search strategy that generates measure | ||
candidates without initializing with the tune context. | ||
""" | ||
|
||
mod: Optional[IRModule] | ||
target: Optional[Target] | ||
task_name: Optional[str] | ||
rand_state: int | ||
num_threads: int | ||
|
||
def __init__( | ||
self, | ||
mod: Optional[IRModule] = None, | ||
target: Optional[Target] = None, | ||
task_name: Optional[str] = None, | ||
rand_state: int = -1, | ||
num_threads: Optional[int] = None, | ||
): | ||
"""Constructor. | ||
Parameters | ||
---------- | ||
mod : Optional[IRModule] = None | ||
The workload to be optimized. | ||
target : Optional[Target] = None | ||
The target to be optimized for. | ||
task_name : Optional[str] = None | ||
The name of the tuning task. | ||
rand_state : int = -1 | ||
The random state. | ||
Need to be in integer in [1, 2^31-1], -1 means using random number. | ||
num_threads : Optional[int] = None | ||
The number of threads to be used, None means using the logical cpu count. | ||
""" | ||
if num_threads is None: | ||
num_threads = cpu_count() | ||
|
||
self.__init_handle_by_constructor__( | ||
_ffi_api.TuneContext, # type: ignore # pylint: disable=no-member | ||
mod, | ||
target, | ||
task_name, | ||
rand_state, | ||
num_threads, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#include "./tune_context.h" | ||
|
||
#include <random> | ||
#include <utility> | ||
|
||
namespace tvm { | ||
namespace meta_schedule { | ||
|
||
/*! | ||
* \brief Constructor function of TuneContext class. | ||
* \param mod The mod to be optimized. | ||
* \param target The target to be optimized for. | ||
* \param task_name The name of the tuning task. | ||
* \param rand_state The random state. | ||
* \param num_threads The number of threads to be used. | ||
* \param verbose The verbosity level. | ||
*/ | ||
TuneContext::TuneContext(Optional<IRModule> mod, // | ||
Optional<Target> target, // | ||
Optional<String> task_name, // | ||
support::LinearCongruentialEngine::TRandState rand_state, // | ||
int num_threads) { | ||
ObjectPtr<TuneContextNode> n = make_object<TuneContextNode>(); | ||
n->mod = mod; | ||
n->target = target; | ||
n->task_name = task_name; | ||
if (rand_state == -1) { | ||
rand_state = std::random_device()(); | ||
} | ||
support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); | ||
n->num_threads = num_threads; | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_NODE_TYPE(TuneContextNode); | ||
|
||
TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") | ||
.set_body_typed([](Optional<IRModule> mod, // | ||
Optional<Target> target, // | ||
Optional<String> task_name, // | ||
support::LinearCongruentialEngine::TRandState rand_state, // | ||
int num_threads) -> TuneContext { | ||
return TuneContext(mod, target, task_name, rand_state, num_threads); | ||
}); | ||
} // namespace meta_schedule | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_ | ||
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ | ||
|
||
#include <tvm/ir/module.h> | ||
#include <tvm/support/random_engine.h> | ||
#include <tvm/target/target.h> | ||
|
||
namespace tvm { | ||
namespace meta_schedule { | ||
|
||
/*! \brief The auto tuning context. */ | ||
class TuneContextNode : public runtime::Object { | ||
public: | ||
/*! \brief The workload to be tuned. */ | ||
Optional<IRModule> mod; | ||
/*! \brief The target to be tuned for. */ | ||
Optional<Target> target; | ||
/*! \brief The name of the tuning task. */ | ||
Optional<String> task_name; | ||
/*! \brief The random state. */ | ||
support::LinearCongruentialEngine::TRandState rand_state; | ||
/*! \brief The number of threads to be used. */ | ||
int num_threads; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("mod", &mod); | ||
v->Visit("target", &target); | ||
v->Visit("task_name", &task_name); | ||
v->Visit("rand_state", &rand_state); | ||
v->Visit("num_threads", &num_threads); | ||
} | ||
|
||
static constexpr const char* _type_key = "meta_schedule.TuneContext"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to TuneContextNode. | ||
* \sa TuneContextNode | ||
*/ | ||
class TuneContext : public runtime::ObjectRef { | ||
public: | ||
/*! | ||
* \brief Constructor. | ||
* \param mod The workload to be tuned. | ||
* \param target The target to be tuned for. | ||
* \param task_name The name of the tuning task. | ||
* \param rand_state The random state. | ||
* \param num_threads The number of threads to be used. | ||
*/ | ||
TVM_DLL explicit TuneContext(Optional<IRModule> mod, // | ||
Optional<Target> target, // | ||
Optional<String> task_name, // | ||
support::LinearCongruentialEngine::TRandState rand_state, // | ||
int num_threads); | ||
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); | ||
}; | ||
|
||
} // namespace meta_schedule | ||
} // namespace tvm | ||
|
||
#endif // TVM_META_SCHEDULE_TUNE_CONTEXT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Test the tune context of meta schedule.""" | ||
|
||
import sys | ||
import pytest | ||
|
||
import tvm | ||
from tvm import tir | ||
from tvm.script import ty | ||
from tvm.target import Target | ||
from tvm.meta_schedule import TuneContext | ||
|
||
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring | ||
|
||
|
||
@tvm.script.tir | ||
class Matmul: | ||
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument | ||
tir.func_attr({"global_symbol": "main", "tir.noalias": True}) | ||
A = tir.match_buffer(a, (1024, 1024), "float32") | ||
B = tir.match_buffer(b, (1024, 1024), "float32") | ||
C = tir.match_buffer(c, (1024, 1024), "float32") | ||
with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: | ||
with tir.init(): | ||
C[vi, vj] = 0.0 | ||
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] | ||
|
||
|
||
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring | ||
|
||
|
||
def test_tune_context_create(): | ||
mod = Matmul() | ||
context = TuneContext(mod, Target("llvm"), "Test Task") | ||
assert context.num_threads > 0 | ||
assert context.rand_state != -1 | ||
assert context.task_name == "Test Task" | ||
assert context.mod == mod or tvm.ir.structural_equal(context.mod, mod) | ||
|
||
|
||
if __name__ == "__main__": | ||
sys.exit(pytest.main([__file__] + sys.argv[1:])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters