Skip to content

Commit

Permalink
[mulitapi combiner] combine models in multiapi sdk (Azure#28718)
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft authored Mar 20, 2023
1 parent 9c908dd commit 10a4783
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 12 deletions.
127 changes: 118 additions & 9 deletions tools/azure-sdk-tools/packaging_tools/multiapi_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import argparse
from pathlib import Path
import shutil
from typing import Dict, Optional, List, Any, TypeVar, Callable
from typing import Dict, Optional, List, Any, TypeVar, Callable, Set

from jinja2 import PackageLoader, Environment

Expand All @@ -34,6 +34,10 @@ def modify_relative_imports(regex: str, file: str) -> str:
return file.replace(original_str, new_str)


def strip_version_from_docs(input: str) -> str:
return re.sub(r".v20[^.]*", "", input)


class VersionedObject:
"""An object that can be added / removed in an api version"""

Expand Down Expand Up @@ -77,6 +81,19 @@ def _combine_helper(
return objs


def _sort_models_helper(current: "ModelAndEnum", seen_model_names: Set[str]) -> List["ModelAndEnum"]:
if current.name in seen_model_names:
return []
ancestors: List["ModelAndEnum"] = [current]
for parent in current.parents:
if parent.name in seen_model_names:
continue
ancestors = _sort_models_helper(parent, seen_model_names) + ancestors
seen_model_names.add(parent.name)
seen_model_names.add(current.name)
return ancestors


class Parameter(VersionedObject):
def __init__(
self,
Expand Down Expand Up @@ -105,7 +122,7 @@ def __init__(
self._request_builder: Optional[str] = None

def source_code(self, async_mode: bool) -> str:
return inspect.getsource(self._get_op(self.api_versions[-1], async_mode))
return strip_version_from_docs(inspect.getsource(self._get_op(self.api_versions[-1], async_mode)))

@property
def request_builder_name(self) -> Optional[str]:
Expand Down Expand Up @@ -256,6 +273,9 @@ def _get_operation(code_model: "CodeModel", name: str) -> Operation:
get_names_by_api_version=_get_names_by_api_version,
)

def doc(self, async_mode: bool) -> str:
return strip_version_from_docs(self.generated_class(async_mode).__doc__)


class Client:
def __init__(self, code_model: "CodeModel") -> None:
Expand All @@ -279,6 +299,29 @@ def name(self) -> str:
return list(self.code_model.api_version_to_metadata.values())[-1]["client"]["name"]


class ModelAndEnum(VersionedObject):
def __init__(self, code_model: "CodeModel", name: str) -> None:
super().__init__(code_model, name)
self._parents: List["ModelAndEnum"] = []

@property
def generated_class(self):
folder_api_version = self.code_model.api_version_to_folder_api_version[self.api_versions[-1]]
module = importlib.import_module(f"{self.code_model.module_name}.{folder_api_version}.models")
return getattr(module, self.name)

@property
def source_code(self) -> str:
return strip_version_from_docs(inspect.getsource(self.generated_class))

@property
def parents(self) -> List["ModelAndEnum"]:
if not self._parents:
for parent in self.generated_class.__mro__[1 : len(self.generated_class.__mro__) - 2]:
self._parents.append(self.code_model.models[parent.__name__])
return self._parents


class CodeModel:
def __init__(self, pkg_path: Path):
self._root_of_code = pkg_path
Expand All @@ -297,6 +340,9 @@ def __init__(self, pkg_path: Path):
self.default_folder_api_version = self.api_version_to_folder_api_version[self.default_api_version]
self.module_name = pkg_path.stem.replace("-", ".")
self.operation_groups = self._combine_operation_groups()
self.models: Dict[str, ModelAndEnum] = {}
self.enums: List[ModelAndEnum] = []
self._combine_models_and_enums()
self.client = Client(self)

def get_root_of_code(self, async_mode: bool) -> Path:
Expand Down Expand Up @@ -343,6 +389,35 @@ def _get_operation_group(code_model: "CodeModel", name: str):
operation.combine_parameters()
return ogs

def _combine_models_and_enums(self) -> None:
def _get_model(code_model: "CodeModel", name: str) -> ModelAndEnum:
return ModelAndEnum(code_model, name)

def _get_names_by_api_version(api_version: str):
folder_api_version = self.api_version_to_folder_api_version[api_version]
module = importlib.import_module(f"{self.module_name}.{folder_api_version}.models")
return [m for m in dir(module) if m[0] != "_"]

models_and_enums = _combine_helper(
code_model=self,
sorted_api_versions=self.sorted_api_versions,
get_cls=_get_model,
get_names_by_api_version=_get_names_by_api_version,
)
for m in models_and_enums:
if hasattr(m.generated_class, "from_dict"):
self.models[m.name] = m
else:
self.enums.append(m)
self._sort_models()

def _sort_models(self) -> None:
seen_model_names: Set[str] = set()
sorted_models: Dict[str, ModelAndEnum] = {}
for model in self.models.values():
sorted_models.update({m.name: m for m in _sort_models_helper(model, seen_model_names)})
self.models = sorted_models


class Serializer:
def __init__(self, code_model: "CodeModel") -> None:
Expand Down Expand Up @@ -486,7 +561,9 @@ def serialize_client(self, async_mode: bool):

main_client_source = "class" + "class".join(split_main_client_source[1:])

client_initialization = re.search(r"((?s).*?) @classmethod", main_client_source).group(1)
client_initialization = strip_version_from_docs(
re.search(r"((?s).*?) @classmethod", main_client_source).group(1)
)

# TODO: switch to current file path
with open(f"{self.code_model.get_root_of_code(async_mode)}/_client.py", "w") as fd:
Expand Down Expand Up @@ -532,12 +609,43 @@ def serialize_general(self):
with open(f"{self.code_model.get_root_of_code(False)}/_validation.py", "w") as fd:
fd.write(self.env.get_template("validation.py.jinja2").render())

def serialize_models_folder(self):
# serialize init file
models_folder = self.code_model.get_root_of_code(False) / "models"
Path(models_folder).mkdir(parents=True, exist_ok=True)
with open(f"{models_folder}/__init__.py", "w") as fd:
fd.write(self.env.get_template("models_init.py.jinja2").render(code_model=self.code_model))
default_api_version = self.code_model.default_folder_api_version
default_models_folder_name = f"{self.code_model.module_name}.{default_api_version}.models"

# serialize models file
default_models_module = importlib.import_module(f"{default_models_folder_name}._models_py3")
imports = inspect.getsource(default_models_module).split("class")[0]
imports = modify_relative_imports(r"from (.*) import _serialization", imports)
with open(f"{models_folder}/_models.py", "w") as fd:
fd.write(self.env.get_template("models.py.jinja2").render(code_model=self.code_model, imports=imports))

# serialize enums file
default_enums_module = importlib.import_module(
f"{default_models_folder_name}.{self.code_model.client.generated_filename}_enums"
)
imports = inspect.getsource(default_enums_module).split("class")[0]
if self.code_model.enums:
with open(f"{models_folder}/_enums.py", "w") as fd:
fd.write(self.env.get_template("enums.py.jinja2").render(code_model=self.code_model, imports=imports))

# serialize patch file
with open(f"{models_folder}/_patch.py", "w") as wfd:
with open(f"{self.code_model.get_root_of_code(False)}/{default_api_version}/models/_patch.py", "r") as rfd:
wfd.write(rfd.read())

def remove_versioned_files(self):
root_of_code = self.code_model.get_root_of_code(False)
for api_version_folder_stem in self.code_model.api_version_to_folder_api_version.values():
api_version_folder = root_of_code / api_version_folder_stem
shutil.rmtree(api_version_folder / Path("operations"), ignore_errors=True)
shutil.rmtree(api_version_folder / Path("aio"), ignore_errors=True)
shutil.rmtree(api_version_folder / Path("models"), ignore_errors=True)
files_to_remove = [
"__init__.py",
"_configuration.py",
Expand All @@ -551,12 +659,13 @@ def remove_versioned_files(self):
for file in files_to_remove:
os.remove(f"{api_version_folder}/{file}")

# add empty init file so we can still see the models folder
with open(f"{api_version_folder}/__init__.py", "w") as f:
f.write("")

def remove_top_level_files(self, async_mode: bool):
top_level_files = [self.code_model.client.generated_filename, "_operations_mixin"]
top_level_files = [
self.code_model.client.generated_filename,
"_operations_mixin",
]
if not async_mode:
top_level_files.append("models")
for file in top_level_files:
os.remove(f"{self.code_model.get_root_of_code(async_mode)}/{file}.py")

Expand All @@ -571,8 +680,8 @@ def serialize(self):
self.serialize_client(async_mode=False)
self.serialize_client(async_mode=True)
self.serialize_general()
self.serialize_models_folder()
self.remove_old_code()
# self.serialize_models_file()


def get_args() -> argparse.Namespace:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ from {{ ".." if async_mode else "." }}_validation import api_version_validation

{{ getsource(generated_client._models_dict) }}

{{ getsource(generated_client.models) }}

{% for operation_group in operation_group_properties %}

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{{ imports }}

{% for enum in code_model.enums %}
{{ enum.source_code }}
{% endfor %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{{ imports }}

{% for model in code_model.models.values() %}
{{ model.source_code }}
{% endfor %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
{% if code_model.models %}
from ._models import (
{% for model in code_model.models.keys() %}
{{ model }},
{% endfor %}
)
{% endif %}

{% if code_model.enums %}
from ._enums import (
{% for enum in code_model.enums %}
{{ enum.name }},
{% endfor %}
)
{% endif %}

from ._patch import __all__ as _patch_all
from ._patch import * # pylint: disable=unused-wildcard-import
from ._patch import patch_sdk as _patch_sdk

__all__ = [
{% for model in code_model.models.keys() %}
"{{ model }}",
{% endfor %}
{% for enum in code_model.enums %}
"{{ enum.name }}",
{% endfor %}
]
__all__.extend([p for p in _patch_all if p not in __all__])
_patch_sdk()
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class {{ operation_group.name }}{{ "(" + operation_group.name.replace("Operations", "") + "ABC)" if operation_group.is_mixin else "" }}:
{% if not operation_group.is_mixin %}
"""
{{ operation_group.generated_class(async_mode).__doc__ }}
{{ operation_group.doc(async_mode) }}
"""
models = _models

Expand Down

0 comments on commit 10a4783

Please sign in to comment.