Skip to content

Commit

Permalink
Register mechanism for the Optimum CLI (#928)
Browse files Browse the repository at this point in the history
* [WIP] register mechanism for the CLI

* Register mechanism for the CLI

* [WORKING] Register mechanism for the CLI

* [DEBUG] Test add print

* [DEBUG] Remove print

* Add __init__.py to have a Python file inside optimum/command/register

* Fix test

* Trigger CI

* Add docstring

* Apply suggestions
  • Loading branch information
michaelbenayoun authored Mar 30, 2023
1 parent 9068f96 commit f934498
Show file tree
Hide file tree
Showing 16 changed files with 587 additions and 120 deletions.
18 changes: 5 additions & 13 deletions optimum/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from argparse import ArgumentParser


class BaseOptimumCLICommand(ABC):
@staticmethod
@abstractmethod
def register_subcommand(parser: ArgumentParser):
raise NotImplementedError()

@abstractmethod
def run(self):
raise NotImplementedError()
from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand
from .onnxruntime import ONNXRuntimeCommand, ONNXRuntimmeOptimizeCommand, ONNXRuntimmeQuantizeCommand
from .optimum_cli import register_optimum_cli_subcommand
140 changes: 140 additions & 0 deletions optimum/commands/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimum command-line interface base classes."""

from abc import ABC
from argparse import ArgumentParser, RawTextHelpFormatter
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Type


if TYPE_CHECKING:
from argparse import Namespace, _SubParsersAction


@dataclass(frozen=True)
class CommandInfo:
name: str
help: str
subcommand_class: Optional[Type["BaseOptimumCLICommand"]] = None
formatter_class: Type = RawTextHelpFormatter

@property
def is_subcommand_info(self):
return self.subcommand_class is not None

def is_subcommand_info_or_raise(self):
if not self.is_subcommand_info:
raise ValueError(f"The command info must define a subcommand_class attribute, but got: {self}.")


class BaseOptimumCLICommand(ABC):
COMMAND: CommandInfo
SUBCOMMANDS: Tuple[CommandInfo, ...] = ()

def __init__(
self,
subparsers: Optional["_SubParsersAction"],
args: Optional["Namespace"] = None,
command: Optional[CommandInfo] = None,
from_defaults_factory: bool = False,
):
"""
Initializes the instance.
Args:
subparsers (`Optional[_SubParsersAction]`):
The parent subparsers this command will create its parser on.
args (`Optional[Namespace]`, defaults to `None`):
The arguments that are going to be parsed by the CLI.
command (`Optional[CommandInfo]`, defaults to `None`):
The command info for this instance. This can be used to set the class attribute `COMMAND`.
from_defaults_factory (`bool`, defaults to `False`):
When setting the parser defaults, we create a second instance of self. By setting
`from_defaults_factory=True`, we do not do unnecessary actions for setting the defaults, such as
creating a parser.
"""
if command is not None:
self.COMMAND = command

if from_defaults_factory:
self.parser = None
self.subparsers = subparsers
else:
if subparsers is None:
raise ValueError(f"A subparsers instance is needed when from_defaults_factory=False, command: {self}.")
self.parser = subparsers.add_parser(self.COMMAND.name, help=self.COMMAND.help)
self.parse_args(self.parser)

def defaults_factory(args):
return self.__class__(self.subparsers, args, command=self.COMMAND, from_defaults_factory=True)

self.parser.set_defaults(func=defaults_factory)

for subcommand in self.SUBCOMMANDS:
if not isinstance(subcommand, CommandInfo):
raise ValueError(f"Subcommands must be instances of CommandInfo, but got {type(subcommand)} here.")
self.register_subcommand(subcommand)

self.args = args

@property
def subparsers(self):
"""
This property handles how subparsers are created, which are only needed when registering a subcommand.
If `self` does not have any subcommand, no subparsers should be created or it will mess with the command.
This property ensures that we create subparsers only if needed.
"""
subparsers = getattr(self, "_subparsers", None)
if subparsers is None:
if self.SUBCOMMANDS:
if self.parser is not None:
self._subparsers = self.parser.add_subparsers()
else:
self._subparsers = None
else:
self._subparsers = None
return self._subparsers

@subparsers.setter
def subparsers(self, subparsers: Optional["_SubParsersAction"]):
self._subparsers = subparsers

@property
def registered_subcommands(self):
if not hasattr(self, "_registered_subcommands"):
self._registered_subcommands = []
return self._registered_subcommands

@staticmethod
def parse_args(parser: "ArgumentParser"):
pass

def register_subcommand(self, command_info: CommandInfo):
command_info.is_subcommand_info_or_raise()
self.SUBCOMMANDS = self.SUBCOMMANDS + (command_info,)
self.registered_subcommands.append(command_info.subcommand_class(self.subparsers, command=command_info))

def run(self):
raise NotImplementedError()


class RootOptimumCLICommand(BaseOptimumCLICommand):
COMMAND = CommandInfo(name="root", help="optimum-cli root command")

def __init__(self, cli_name: str, usage: Optional[str] = None, args: Optional["Namespace"] = None):
self.parser = ArgumentParser(cli_name, usage=usage)
self.subparsers = self.parser.add_subparsers()
self.args = None
18 changes: 5 additions & 13 deletions optimum/commands/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,21 @@
# limitations under the License.

import platform
from argparse import ArgumentParser

import huggingface_hub
from transformers import __version__ as transformers_version
from transformers.utils import is_tf_available, is_torch_available

from ..version import __version__ as version
from . import BaseOptimumCLICommand


def info_command_factory(_):
return EnvironmentCommand()
from . import BaseOptimumCLICommand, CommandInfo


class EnvironmentCommand(BaseOptimumCLICommand):
COMMAND = CommandInfo(name="env", help="Get information about the environment used.")

@staticmethod
def register_subcommand(parser: ArgumentParser):
download_parser = parser.add_parser("env", help="Get information about the environment used.")
download_parser.set_defaults(func=info_command_factory)
def format_dict(d):
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"

def run(self):
pt_version = "not installed"
Expand Down Expand Up @@ -69,7 +65,3 @@ def run(self):
print(self.format_dict(info))

return info

@staticmethod
def format_dict(d):
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
39 changes: 2 additions & 37 deletions optimum/commands/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from argparse import ArgumentParser, RawTextHelpFormatter

from ...exporters.onnx.__main__ import parse_args_onnx
from .. import BaseOptimumCLICommand
from .base import ExportCommand
from .onnx import ONNXExportCommand
from .tflite import TFLiteExportCommand, parse_args_tflite


def onnx_export_factory(args):
return ONNXExportCommand(args)


def tflite_export_factory(_):
return TFLiteExportCommand(" ".join(sys.argv[3:]))


class ExportCommand(BaseOptimumCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
export_parser = parser.add_parser(
"export", help="Export PyTorch and TensorFlow models to several format (currently supported: onnx)."
)
export_sub_parsers = export_parser.add_subparsers()

onnx_parser = export_sub_parsers.add_parser(
"onnx", help="Export PyTorch and TensorFlow to ONNX.", formatter_class=RawTextHelpFormatter
)

parse_args_onnx(onnx_parser)
onnx_parser.set_defaults(func=onnx_export_factory)

tflite_parser = export_sub_parsers.add_parser("tflite", help="Export TensorFlow to TensorFlow Lite.")

parse_args_tflite(tflite_parser)
tflite_parser.set_defaults(func=tflite_export_factory)

def run(self):
raise NotImplementedError()
from .tflite import TFLiteExportCommand
38 changes: 38 additions & 0 deletions optimum/commands/export/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""optimum.exporters command-line interface base classes."""

from .. import BaseOptimumCLICommand, CommandInfo
from .onnx import ONNXExportCommand
from .tflite import TFLiteExportCommand


class ExportCommand(BaseOptimumCLICommand):
COMMAND = CommandInfo(
name="export",
help="Export PyTorch and TensorFlow models to several format.",
)
SUBCOMMANDS = (
CommandInfo(
name="onnx",
help="Export PyTorch and TensorFlow to ONNX.",
subcommand_class=ONNXExportCommand,
),
CommandInfo(
name="tflite",
help="Export TensorFlow to TensorFlow Lite.",
subcommand_class=TFLiteExportCommand,
),
)
18 changes: 13 additions & 5 deletions optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,24 @@
# limitations under the License.
"""Defines the command line for the export with ONNX."""

from ...exporters.onnx.__main__ import main_export
from typing import TYPE_CHECKING

from ...exporters.onnx.__main__ import main_export, parse_args_onnx
from ...utils import DEFAULT_DUMMY_SHAPES
from ..base import BaseOptimumCLICommand


if TYPE_CHECKING:
from argparse import ArgumentParser


class ONNXExportCommand:
def __init__(self, args):
self.args = args
class ONNXExportCommand(BaseOptimumCLICommand):
@staticmethod
def parse_args(parser: "ArgumentParser"):
return parse_args_onnx(parser)

def run(self):
# get the shapes to be used to generate dummy inputs
# Get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = getattr(self.args, input_name)
Expand Down
29 changes: 25 additions & 4 deletions optimum/commands/export/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@
"""Defines the command line for the export with TensorFlow Lite."""

import subprocess
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from ...exporters import TasksManager
from ...exporters.tflite import QuantizationApproach
from ..base import BaseOptimumCLICommand


def parse_args_tflite(parser):
if TYPE_CHECKING:
from argparse import ArgumentParser, Namespace, _SubParsersAction

from ..base import CommandInfo


def parse_args_tflite(parser: "ArgumentParser"):
required_group = parser.add_argument_group("Required arguments")
required_group.add_argument(
"-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
Expand Down Expand Up @@ -212,9 +221,21 @@ def parse_args_tflite(parser):
)


class TFLiteExportCommand:
def __init__(self, args_string):
self.args_string = args_string
class TFLiteExportCommand(BaseOptimumCLICommand):
def __init__(
self,
parser: "_SubParsersAction",
args: Optional["Namespace"] = None,
command: Optional["CommandInfo"] = None,
from_defaults_factory: bool = False,
):
super().__init__(parser, args, command=command, from_defaults_factory=from_defaults_factory)
# TODO: hack until TFLiteExportCommand does not use subprocess anymore.
self.args_string = " ".join(sys.argv[3:])

@staticmethod
def parse_args(parser: "ArgumentParser"):
return parse_args_tflite(parser)

def run(self):
full_command = f"python3 -m optimum.exporters.tflite {self.args_string}"
Expand Down
Loading

0 comments on commit f934498

Please sign in to comment.