Skip to content

Commit

Permalink
[microTVM] Add method to query template info without creating a proje…
Browse files Browse the repository at this point in the history
…ct (apache#8950)

Add info() method to TemplateProject class so it's possible to query all
available options for a given template project without creating a new
one. This is necessary because TVMC will query the available options for
a given template project to show them to the user so the user can use
them to finally create a new project dir.

That is also useful in general to query the available options for any
project type. For example, one can query all boards available on the
Zephyr platform with:

import tvm.micro.project as project_api

template = project_api.TemplateProject.from_directory(ZEPHYR_TEMPLATE_DIR)
boards = template.info()["project_options"][8]["choices"]

where 8 element refers to the "zephyr_board" option.

Signed-off-by: Gustavo Romero <gustavo.romero@linaro.org>
  • Loading branch information
gromero authored and ylc committed Jan 13, 2022
1 parent 4a0e472 commit b643981
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions python/tvm/micro/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,16 @@ class TemplateProject:
"""Defines a glue interface to interact with a template project through the API Server."""

@classmethod
def from_directory(cls, template_project_dir, options):
return cls(client.instantiate_from_dir(template_project_dir), options)
def from_directory(cls, template_project_dir):
return cls(client.instantiate_from_dir(template_project_dir))

def __init__(self, api_client, options):
def __init__(self, api_client):
self._api_client = api_client
self._options = options
self._info = self._api_client.server_info_query(__version__)
if not self._info["is_template"]:
raise NotATemplateProjectError()

def generate_project(self, graph_executor_factory, project_dir):
def generate_project(self, graph_executor_factory, project_dir, options):
"""Generate a project given GraphRuntimeFactory."""
model_library_dir = utils.tempdir()
model_library_format_path = model_library_dir.relpath("model.tar")
Expand All @@ -112,10 +111,13 @@ def generate_project(self, graph_executor_factory, project_dir):
model_library_format_path=model_library_format_path,
standalone_crt_dir=get_standalone_crt_dir(),
project_dir=project_dir,
options=self._options,
options=options,
)

return GeneratedProject.from_directory(project_dir, self._options)
return GeneratedProject.from_directory(project_dir, options)

def info(self):
return self._info


def generate_project(
Expand Down Expand Up @@ -147,5 +149,5 @@ def generate_project(
GeneratedProject :
A class that wraps the generated project and which can be used to further interact with it.
"""
template = TemplateProject.from_directory(str(template_project_dir), options)
return template.generate_project(module, str(generated_project_dir))
template = TemplateProject.from_directory(str(template_project_dir))
return template.generate_project(module, str(generated_project_dir), options)

0 comments on commit b643981

Please sign in to comment.