diff --git a/nn/naming.py b/nn/naming.py index 58935e26..d98cf824 100644 --- a/nn/naming.py +++ b/nn/naming.py @@ -906,7 +906,9 @@ def get_base_extern_data_py_code_str(self) -> str: return "".join(code_lines) @classmethod - def get_base_extern_data_py_code_str_direct(cls, extern_data: Dict[str, Any]) -> str: + def get_base_extern_data_py_code_str_direct( + cls, extern_data: Dict[str, Any], *, other: Optional[Dict[str, Any]] = None + ) -> str: """ directly get serialized Python code via extern data """ @@ -914,12 +916,16 @@ def get_base_extern_data_py_code_str_direct(cls, extern_data: Dict[str, Any]) -> from returnn.util.pprint import pformat extern_data = dim_tags_proxy.collect_dim_tags_and_transform_config(extern_data) + other = dim_tags_proxy.collect_dim_tags_and_transform_config(other) code_lines = [ cls.ImportPyCodeStr, f"{dim_tags_proxy.py_code_str()}\n", f"extern_data = {pformat(extern_data)}\n", ] + if other: + for k, v in other.items(): + code_lines.append(f"{k} = {pformat(v)}\n") return "".join(code_lines) def get_ext_net_dict_py_code_str(