Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions sphinx_parser/src/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import builtins
import keyword
import os
from typing import Any

import yaml
from black import FileMode, format_str
Expand All @@ -9,7 +10,7 @@
predefined = ["description", "default", "data_type", "required", "alias", "units"]


def _find_alias(all_data: dict, head: list | None = None):
def _find_alias(all_data: dict, head: list | None = None) -> dict[str, str]:
"""
Find all aliases in the data structure.

Expand All @@ -31,13 +32,13 @@ def _find_alias(all_data: dict, head: list | None = None):
return results


def _replace_alias(all_data: dict):
def _replace_alias(all_data: dict) -> dict:
for key, value in _find_alias(all_data).items():
_set(all_data, key, _get(all_data, value))
return all_data


def _get(obj: dict, path: str, sep: str = "/"):
def _get(obj: dict, path: str, sep: str = "/") -> Any:
"""
Get a value from a nested dictionary.

Expand Down Expand Up @@ -69,13 +70,13 @@ def _set(obj: dict, path: str, value):
obj[last] = value


def _get_safe_parameter_name(name: str):
def _get_safe_parameter_name(name: str) -> str:
if keyword.iskeyword(name) or name in dir(builtins):
name = name + "_"
return name


def _get_docstring_line(data: dict, key: str):
def _get_docstring_line(data: dict, key: str) -> str:
"""
Get a single line for the docstring.

Expand Down Expand Up @@ -108,7 +109,12 @@ def _get_docstring_line(data: dict, key: str):
return line


def _get_docstring(all_data, description=None, indent=indent, predefined=predefined):
def _get_docstring(
all_data: dict,
description: str | None = None,
indent: str = indent,
predefined: list[str] = predefined,
) -> list[str]:
txt = [indent + '"""']
if description is not None:
txt.append(f"{indent}{description}\n")
Expand All @@ -123,7 +129,7 @@ def _get_docstring(all_data, description=None, indent=indent, predefined=predefi
return txt


def _get_input_arg(key, entry, indent=indent):
def _get_input_arg(key: str, entry: dict, indent: str = indent) -> str:
t = entry.get("data_type", "dict")
units = "".join(entry.get("units", "").split())
if not entry.get("required", False) and units != "":
Expand All @@ -136,7 +142,7 @@ def _get_input_arg(key, entry, indent=indent):
return t


def _rename_keys(data):
def _rename_keys(data: dict) -> dict:
d_1 = {_get_safe_parameter_name(key): value for key, value in data.items()}
d_2 = {
key: d
Expand All @@ -148,11 +154,11 @@ def _rename_keys(data):


def _get_function(
data,
data: dict,
function_name: list[str],
predefined=predefined,
is_kwarg=False,
):
predefined: list[str] = predefined,
is_kwarg: bool = False,
) -> str:
d = _rename_keys(data)
func = []
if is_kwarg:
Expand Down Expand Up @@ -187,7 +193,9 @@ def _get_function(
return "\n".join(result)


def _get_all_function_names(all_data, head="", predefined=predefined):
def _get_all_function_names(
all_data: dict, head: str = "", predefined: list[str] = predefined
) -> list[str]:
key_lst = []
for tag, data in all_data.items():
if tag not in predefined and data.get("data_type", "dict") == "dict":
Expand All @@ -196,7 +204,7 @@ def _get_all_function_names(all_data, head="", predefined=predefined):
return key_lst


def _get_class(all_data):
def _get_class(all_data: dict) -> str:
fnames = _get_all_function_names(all_data)
txt = ""
for name in fnames:
Expand All @@ -216,7 +224,7 @@ def _get_class(all_data):
return txt


def _get_file_content(yml_file_name="input_data.yml"):
def _get_file_content(yml_file_name: str = "input_data.yml") -> str:
file_location = os.path.join(os.path.dirname(__file__), yml_file_name)
with open(file_location, "r") as f:
file_content = f.read()
Expand Down Expand Up @@ -246,7 +254,7 @@ def _get_file_content(yml_file_name="input_data.yml"):
return file_content


def export_class(yml_file_name="input_data.yml", py_file_name="input.py"):
def export_class(yml_file_name: str = "input_data.yml", py_file_name: str = "input.py"):
file_content = _get_file_content(yml_file_name)
with open(os.path.join(os.path.dirname(__file__), "..", py_file_name), "w") as f:
f.write(file_content)
Expand Down
Loading