Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Enable BertTuning with MetaScheduler #11

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions python/tvm/meta_schedule/testing/relay_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class MODEL_TYPE(Enum): # pylint: disable=invalid-name
VIDEO_CLASSIFICATION = (2,)
SEGMENTATION = (3,)
OBJECT_DETECTION = (4,)
TEXT_CLASSIFICATION = (5,)


# Specify the type of each model
Expand Down Expand Up @@ -95,6 +96,11 @@ class MODEL_TYPE(Enum): # pylint: disable=invalid-name
"r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
"mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
"r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
# Text classification
"bert_tiny": MODEL_TYPE.TEXT_CLASSIFICATION,
"bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
"bert_medium": MODEL_TYPE.TEXT_CLASSIFICATION,
"bert_large": MODEL_TYPE.TEXT_CLASSIFICATION,
}


Expand All @@ -121,6 +127,8 @@ def get_torch_model(

import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel
from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel
import transformers # type: ignore # pylint: disable=import-error,import-outside-toplevel
import os # type: ignore # pylint: disable=import-error,import-outside-toplevel

def do_trace(model, inp):
model_trace = torch.jit.trace(model, inp)
Expand All @@ -136,6 +144,50 @@ def do_trace(model, inp):
model = getattr(models.detection, model_name)()
elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
model = getattr(models.video, model_name)()
elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
config_dict = {
"bert_tiny": transformers.BertConfig(
num_hidden_layers=6,
hidden_size=512,
intermediate_size=2048,
num_attention_heads=8,
return_dict=False,
),
"bert_base": transformers.BertConfig(
num_hidden_layers=12,
hidden_size=768,
intermediate_size=3072,
num_attention_heads=12,
return_dict=False,
),
"bert_medium": transformers.BertConfig(
num_hidden_layers=12,
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
return_dict=False,
),
"bert_large": transformers.BertConfig(
num_hidden_layers=24,
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
return_dict=False,
),
}
configuration = config_dict[model_name]
model = transformers.BertModel(configuration)
input_name = "input_ids"
A = torch.randint(10000, input_shape)

model.eval()
scripted_model = torch.jit.trace(model, [A], strict=False)

input_name = "input_ids"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
return mod, params
else:
raise ValueError("Unsupported model in Torch model zoo.")

Expand Down
8 changes: 4 additions & 4 deletions src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode {

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod_) final {
using ScheduleAndUnvisitedBlocks = std::pair<tir::Schedule, Array<tir::BlockRV>>;
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/mod_, //
/*rand_state=*/ForkSeed(&this->rand_state_), //
/*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, //
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/mod_, //
/*rand_state=*/ForkSeed(&this->rand_state_), //
/*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags,
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);

std::vector<ScheduleAndUnvisitedBlocks> stack;
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/task_scheduler/task_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void TaskSchedulerNode::Tune() {

int running_tasks = tasks.size();
for (int task_id; (task_id = NextTaskId()) != -1;) {
LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name;
LOG(INFO) << "Scheduler picks Task #" << task_id + 1 << ": " << tasks[task_id]->task_name;
TuneContext task = tasks[task_id];
ICHECK(!task->is_stopped);
ICHECK(!task->runner_futures.defined());
Expand Down
12 changes: 8 additions & 4 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


@pytest.mark.skip("Integration test")
@pytest.mark.parametrize("model_name", ["resnet18"])
@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"])
def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
Expand All @@ -47,6 +47,9 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
input_shape = (1, 3, 300, 300)
elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
input_shape = (batch_size, 3, 3, 299, 299)
elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
seq_length = 128
input_shape = (batch_size, seq_length)
else:
raise ValueError("Unsupported model: " + model_name)
output_shape: Tuple[int, int] = (batch_size, 1000)
Expand All @@ -71,7 +74,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
work_dir=work_dir,
)
for i, sch in enumerate(schs):
print("-" * 10 + f" Part {i}/{len(schs)} " + "-" * 10)
print("-" * 10 + f" Part {i+1}/{len(schs)} " + "-" * 10)
if sch is None:
print("No valid schedule found!")
else:
Expand All @@ -80,5 +83,6 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)


if __name__ == """__main__""":
test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
# test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
# test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")