Skip to content

Commit

Permalink
Consolidate numpy member transforms to reduce function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
correctmost authored and jacobtylerwalls committed Sep 30, 2024
1 parent f19fc0a commit c7ea1e9
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 33 deletions.
19 changes: 10 additions & 9 deletions astroid/brain/brain_numpy_core_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import functools

from astroid.brain.brain_numpy_utils import (
attribute_looks_like_numpy_member,
infer_numpy_member,
attribute_name_looks_like_numpy_member,
infer_numpy_attribute,
)
from astroid.inference_tip import inference_tip
from astroid.manager import AstroidManager
Expand All @@ -25,10 +25,11 @@


def register(manager: AstroidManager) -> None:
for func_name, func_src in METHODS_TO_BE_INFERRED.items():
inference_function = functools.partial(infer_numpy_member, func_src)
manager.register_transform(
Attribute,
inference_tip(inference_function),
functools.partial(attribute_looks_like_numpy_member, func_name),
)
manager.register_transform(
Attribute,
inference_tip(functools.partial(infer_numpy_attribute, METHODS_TO_BE_INFERRED)),
functools.partial(
attribute_name_looks_like_numpy_member,
frozenset(METHODS_TO_BE_INFERRED.keys()),
),
)
31 changes: 16 additions & 15 deletions astroid/brain/brain_numpy_core_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from astroid import nodes
from astroid.brain.brain_numpy_utils import (
attribute_looks_like_numpy_member,
infer_numpy_member,
name_looks_like_numpy_member,
attribute_name_looks_like_numpy_member,
infer_numpy_attribute,
infer_numpy_name,
member_name_looks_like_numpy_member,
)
from astroid.brain.helpers import register_module_extender
from astroid.builder import parse
Expand Down Expand Up @@ -92,15 +93,15 @@ def register(manager: AstroidManager) -> None:
manager, "numpy.core.multiarray", numpy_core_multiarray_transform
)

for method_name, function_src in METHODS_TO_BE_INFERRED.items():
inference_function = functools.partial(infer_numpy_member, function_src)
manager.register_transform(
Attribute,
inference_tip(inference_function),
functools.partial(attribute_looks_like_numpy_member, method_name),
)
manager.register_transform(
Name,
inference_tip(inference_function),
functools.partial(name_looks_like_numpy_member, method_name),
)
method_names = frozenset(METHODS_TO_BE_INFERRED.keys())

manager.register_transform(
Attribute,
inference_tip(functools.partial(infer_numpy_attribute, METHODS_TO_BE_INFERRED)),
functools.partial(attribute_name_looks_like_numpy_member, method_names),
)
manager.register_transform(
Name,
inference_tip(functools.partial(infer_numpy_name, METHODS_TO_BE_INFERRED)),
functools.partial(member_name_looks_like_numpy_member, method_names),
)
19 changes: 10 additions & 9 deletions astroid/brain/brain_numpy_core_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from astroid import nodes
from astroid.brain.brain_numpy_utils import (
attribute_looks_like_numpy_member,
infer_numpy_member,
attribute_name_looks_like_numpy_member,
infer_numpy_attribute,
)
from astroid.brain.helpers import register_module_extender
from astroid.builder import parse
Expand Down Expand Up @@ -41,10 +41,11 @@ def register(manager: AstroidManager) -> None:
manager, "numpy.core.numeric", numpy_core_numeric_transform
)

for method_name, function_src in METHODS_TO_BE_INFERRED.items():
inference_function = functools.partial(infer_numpy_member, function_src)
manager.register_transform(
Attribute,
inference_tip(inference_function),
functools.partial(attribute_looks_like_numpy_member, method_name),
)
manager.register_transform(
Attribute,
inference_tip(functools.partial(infer_numpy_attribute, METHODS_TO_BE_INFERRED)),
functools.partial(
attribute_name_looks_like_numpy_member,
frozenset(METHODS_TO_BE_INFERRED.keys()),
),
)
39 changes: 39 additions & 0 deletions astroid/brain/brain_numpy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ def _get_numpy_version() -> tuple[str, str, str]:
return ("0", "0", "0")


def infer_numpy_name(
sources: dict[str, str], node: Name, context: InferenceContext | None = None
):
extracted_node = extract_node(sources[node.name])
return extracted_node.infer(context=context)


def infer_numpy_attribute(
sources: dict[str, str], node: Attribute, context: InferenceContext | None = None
):
extracted_node = extract_node(sources[node.attrname])
return extracted_node.infer(context=context)


# TODO: Deprecate and remove this function
def infer_numpy_member(src, node, context: InferenceContext | None = None):
node = extract_node(src)
return node.infer(context=context)
Expand Down Expand Up @@ -61,6 +76,29 @@ def _is_a_numpy_module(node: Name) -> bool:
)


def member_name_looks_like_numpy_member(
member_names: frozenset[str], node: Name
) -> bool:
"""
Returns True if the Name node's name matches a member name from numpy
"""
return node.name in member_names and node.root().name.startswith("numpy")


def attribute_name_looks_like_numpy_member(
member_names: frozenset[str], node: Attribute
) -> bool:
"""
Returns True if the Attribute node's name matches a member name from numpy
"""
return (
node.attrname in member_names
and isinstance(node.expr, Name)
and _is_a_numpy_module(node.expr)
)


# TODO: Deprecate and remove this function
def name_looks_like_numpy_member(member_name: str, node: Name) -> bool:
"""
Returns True if the Name is a member of numpy whose
Expand All @@ -69,6 +107,7 @@ def name_looks_like_numpy_member(member_name: str, node: Name) -> bool:
return node.name == member_name and node.root().name.startswith("numpy")


# TODO: Deprecate and remove this function
def attribute_looks_like_numpy_member(member_name: str, node: Attribute) -> bool:
"""
Returns True if the Attribute is a member of numpy whose
Expand Down

0 comments on commit c7ea1e9

Please sign in to comment.