Skip to content

Commit 63ad07b

Browse files
Add codegen_util.load_generated_function
Defines the new function load_generated_function which is meant to be a more user friendly version of `load_generated_package` (which will be particularly more friendly should we no longer re-export all sub-modules of generated packages). Also, explains the arguments of `load_generated_package` in its doc-string. As an example of the expected usage of this function, for `co` a `Codegen` object: ``` python3 generated_paths = co.generate_function() func = load_generated_function(co.name, generated_paths.function_dir) ``` Tested in `test/symforce_codegen_util_test.py`
1 parent 443997a commit 63ad07b

File tree

4 files changed

+73
-9
lines changed

4 files changed

+73
-9
lines changed

symforce/codegen/codegen_util.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,13 @@ def _load_generated_package_internal(name: str, path: Path) -> T.Tuple[T.Any, T.
543543
def load_generated_package(name: str, path: T.Openable) -> T.Any:
544544
"""
545545
Dynamically load generated package (or module).
546+
547+
name is the full name of the package or module to load (for example,
548+
"pkg.sub_pkg" for a package called "sub_pkg" inside of another package
549+
"pkg", or "pkg.sub_pkg.mod" for a module called "mod" inside of pkg.sub_pkg).
550+
551+
path is the path to the directory (or __init__.py) of the package, or the python
552+
file of the module.
546553
"""
547554
# NOTE(brad): We remove all possibly conflicting modules from the cache. This is
548555
# to ensure that when name is executed, it loads local modules (if any) rather
@@ -574,6 +581,25 @@ def load_generated_package(name: str, path: T.Openable) -> T.Any:
574581
return module
575582

576583

584+
def load_generated_function(func_name: str, path_to_package: T.Openable) -> T.Any:
585+
"""
586+
Loads the function with name func_name found inside the package located at
587+
path_to_package.
588+
589+
Preconditions:
590+
path_to_package is a python package with an `__init__.py` containing a module
591+
defined in `func_name.py` which in turn defines an attribute named `func_name`.
592+
Note: the precondition will be satisfied if the package was generated by
593+
`Codegen.generate_function` from a `Codegen` function with name `func_name`.
594+
"""
595+
pkg_path = Path(path_to_package)
596+
if pkg_path.name == "__init__.py":
597+
pkg_path = pkg_path.parent
598+
pkg_name = pkg_path.name
599+
func_module = load_generated_package(f"{pkg_name}.{func_name}", pkg_path / f"{func_name}.py")
600+
return getattr(func_module, func_name)
601+
602+
577603
def load_generated_lcmtype(
578604
package: str, type_name: str, lcmtypes_path: T.Union[str, Path]
579605
) -> T.Type:

test/symforce_codegen_util_test.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from symforce.codegen import codegen_util
1111
from symforce.test_util import TestCase
1212

13+
PKG_LOCATIONS = Path(__file__).parent / "test_data" / "codegen_util_test_data"
14+
RELATIVE_PATH = Path("example_pkg", "__init__.py")
15+
PACKAGE_NAME = "example_pkg"
16+
1317

1418
class SymforceCodegenUtilTest(TestCase):
1519
"""
@@ -22,32 +26,50 @@ def test_load_generated_package(self) -> None:
2226
codegen_util.load_generated_package
2327
"""
2428

25-
pkg_locations = Path(__file__).parent / "test_data" / "codegen_util_test_data"
26-
27-
relative_path = Path("example_pkg", "__init__.py")
28-
29-
package_name = "example_pkg"
30-
3129
pkg1 = codegen_util.load_generated_package(
32-
name=package_name, path=pkg_locations / "example_pkg_1" / relative_path
30+
name=PACKAGE_NAME, path=PKG_LOCATIONS / "example_pkg_1" / RELATIVE_PATH
3331
)
3432

3533
# Testing that the module was loaded correctly
3634
self.assertEqual(pkg1.package_id, 1)
3735
self.assertEqual(pkg1.sub_module.sub_module_id, 1)
3836

3937
# Testing that sys.modules was not polluted
40-
self.assertFalse(package_name in sys.modules)
38+
self.assertFalse(PACKAGE_NAME in sys.modules)
4139

4240
pkg2 = codegen_util.load_generated_package(
43-
name=package_name, path=pkg_locations / "example_pkg_2" / relative_path
41+
name=PACKAGE_NAME, path=PKG_LOCATIONS / "example_pkg_2" / RELATIVE_PATH
4442
)
4543

4644
# Testing that the module was loaded correctly when a module with the same name has
4745
# already been loaded
4846
self.assertEqual(pkg2.package_id, 2)
4947
self.assertEqual(pkg2.sub_module.sub_module_id, 2)
5048

49+
def test_load_generated_function(self) -> None:
50+
"""
51+
Tests:
52+
codegen_util.load_generated_function
53+
"""
54+
55+
func1 = codegen_util.load_generated_function(
56+
func_name="func", path_to_package=PKG_LOCATIONS / "example_pkg_1" / RELATIVE_PATH
57+
)
58+
59+
# Testing that the function was loaded correctly
60+
self.assertEqual(func1(), 1)
61+
62+
# Testing that sys.modules was not polluted
63+
self.assertFalse(PACKAGE_NAME in sys.modules)
64+
65+
func2 = codegen_util.load_generated_function(
66+
func_name="func", path_to_package=PKG_LOCATIONS / "example_pkg_2" / RELATIVE_PATH
67+
)
68+
69+
# Testing that the function was loaded correctly when a function of the same name
70+
# has already been loaded.
71+
self.assertEqual(func2(), 2)
72+
5173

5274
if __name__ == "__main__":
5375
TestCase.main()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
6+
7+
def func() -> int:
8+
return 1
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
6+
7+
def func() -> int:
8+
return 2

0 commit comments

Comments
 (0)