Skip to content

Commit 7f74579

Browse files
Add codegen_util.load_generated_function
Defines the new function load_generated_function which is meant to be a more user friendly version of `load_generated_package` (which will be particularly more friendly should we no longer re-export all sub-modules of generated packages). Also, explains the arguments of `load_generated_package` in its doc-string. As an example of the expected usage of this function, for `co` a `Codegen` object: ``` python3 generated_paths = co.generate_function() func = load_generated_function(co.name, generated_paths.function_dir) ``` Tested in `test/symforce_codegen_util_test.py`
1 parent 443997a commit 7f74579

File tree

4 files changed

+83
-9
lines changed

4 files changed

+83
-9
lines changed

symforce/codegen/codegen_util.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,13 @@ def _load_generated_package_internal(name: str, path: Path) -> T.Tuple[T.Any, T.
543543
def load_generated_package(name: str, path: T.Openable) -> T.Any:
544544
"""
545545
Dynamically load generated package (or module).
546+
547+
Args:
548+
name: The full name of the package or module to load (for example, "pkg.sub_pkg"
549+
for a package called "sub_pkg" inside of another package "pkg", or
550+
"pkg.sub_pkg.mod" for a module called "mod" inside of pkg.sub_pkg).
551+
path: The path to the directory (or __init__.py) of the package, or the python
552+
file of the module.
546553
"""
547554
# NOTE(brad): We remove all possibly conflicting modules from the cache. This is
548555
# to ensure that when name is executed, it loads local modules (if any) rather
@@ -574,6 +581,35 @@ def load_generated_package(name: str, path: T.Openable) -> T.Any:
574581
return module
575582

576583

584+
def load_generated_function(func_name: str, path_to_package: T.Openable) -> T.Callable:
585+
"""
586+
Returns the function with name func_name found inside the package located at
587+
path_to_package.
588+
589+
Example usage:
590+
591+
def my_func(...):
592+
...
593+
594+
my_codegen = Codegen.function(my_func, config=PythonConfig())
595+
codegen_data = my_codegen.generate_function(output_dir=output_dir)
596+
generated_func = load_generated_function("my_func", codegen_data.function_dir)
597+
generated_func(...)
598+
599+
Preconditions:
600+
path_to_package is a python package with an `__init__.py` containing a module
601+
defined in `func_name.py` which in turn defines an attribute named `func_name`.
602+
Note: the precondition will be satisfied if the package was generated by
603+
`Codegen.generate_function` from a `Codegen` function with name `func_name`.
604+
"""
605+
pkg_path = Path(path_to_package)
606+
if pkg_path.name == "__init__.py":
607+
pkg_path = pkg_path.parent
608+
pkg_name = pkg_path.name
609+
func_module = load_generated_package(f"{pkg_name}.{func_name}", pkg_path / f"{func_name}.py")
610+
return getattr(func_module, func_name)
611+
612+
577613
def load_generated_lcmtype(
578614
package: str, type_name: str, lcmtypes_path: T.Union[str, Path]
579615
) -> T.Type:

test/symforce_codegen_util_test.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from symforce.codegen import codegen_util
1111
from symforce.test_util import TestCase
1212

13+
PKG_LOCATIONS = Path(__file__).parent / "test_data" / "codegen_util_test_data"
14+
RELATIVE_PATH = Path("example_pkg", "__init__.py")
15+
PACKAGE_NAME = "example_pkg"
16+
1317

1418
class SymforceCodegenUtilTest(TestCase):
1519
"""
@@ -22,32 +26,50 @@ def test_load_generated_package(self) -> None:
2226
codegen_util.load_generated_package
2327
"""
2428

25-
pkg_locations = Path(__file__).parent / "test_data" / "codegen_util_test_data"
26-
27-
relative_path = Path("example_pkg", "__init__.py")
28-
29-
package_name = "example_pkg"
30-
3129
pkg1 = codegen_util.load_generated_package(
32-
name=package_name, path=pkg_locations / "example_pkg_1" / relative_path
30+
name=PACKAGE_NAME, path=PKG_LOCATIONS / "example_pkg_1" / RELATIVE_PATH
3331
)
3432

3533
# Testing that the module was loaded correctly
3634
self.assertEqual(pkg1.package_id, 1)
3735
self.assertEqual(pkg1.sub_module.sub_module_id, 1)
3836

3937
# Testing that sys.modules was not polluted
40-
self.assertFalse(package_name in sys.modules)
38+
self.assertFalse(PACKAGE_NAME in sys.modules)
4139

4240
pkg2 = codegen_util.load_generated_package(
43-
name=package_name, path=pkg_locations / "example_pkg_2" / relative_path
41+
name=PACKAGE_NAME, path=PKG_LOCATIONS / "example_pkg_2" / RELATIVE_PATH
4442
)
4543

4644
# Testing that the module was loaded correctly when a module with the same name has
4745
# already been loaded
4846
self.assertEqual(pkg2.package_id, 2)
4947
self.assertEqual(pkg2.sub_module.sub_module_id, 2)
5048

49+
def test_load_generated_function(self) -> None:
50+
"""
51+
Tests:
52+
codegen_util.load_generated_function
53+
"""
54+
55+
func1 = codegen_util.load_generated_function(
56+
func_name="func", path_to_package=PKG_LOCATIONS / "example_pkg_1" / RELATIVE_PATH
57+
)
58+
59+
# Testing that the function was loaded correctly
60+
self.assertEqual(func1(), 1)
61+
62+
# Testing that sys.modules was not polluted
63+
self.assertFalse(PACKAGE_NAME in sys.modules)
64+
65+
func2 = codegen_util.load_generated_function(
66+
func_name="func", path_to_package=PKG_LOCATIONS / "example_pkg_2" / RELATIVE_PATH
67+
)
68+
69+
# Testing that the function was loaded correctly when a function of the same name
70+
# has already been loaded.
71+
self.assertEqual(func2(), 2)
72+
5173

5274
if __name__ == "__main__":
5375
TestCase.main()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
6+
7+
def func() -> int:
8+
return 1
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
6+
7+
def func() -> int:
8+
return 2

0 commit comments

Comments
 (0)