Skip to content

Commit

Permalink
Add type hints in all files of the library
Browse files Browse the repository at this point in the history
Start adding type hints

Black formatting and future imports to solve problem of forward reference

Fix more type hint issues and imports

Fix type annotations in population class and readd erroneously deleted variable in __init__ function

Fix some errors in node.py

Don't initialize class members

Make annotation more precise

Remove deprecated CGP in annotations

Fix initialization in genome

Updated hl_api: Removed Abstract
  • Loading branch information
Maximilian Schmidt authored and mschmidt87 committed Apr 30, 2020
1 parent 65413fc commit 436ba5c
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 154 deletions.
57 changes: 29 additions & 28 deletions gp/cartesian_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,50 +13,51 @@


from .node import InputNode, OutputNode
import gp
from typing import Callable, List, Optional, Dict


class CartesianGraph:
"""Class representing a particular Cartesian graph defined by a
Genome.
"""

def __init__(self, genome):
def __init__(self, genome: gp.Genome) -> None:
"""Init function.
Parameters
----------
genome: Genome
Genome defining graph connectivity and node operations.
"""
self._n_outputs = None
self._n_inputs = None
self._n_columns = None
self._n_rows = None
self._nodes = None
self._gnome = None
self._n_outputs: int
self._n_inputs: int
self._n_columns: int
self._n_rows: int
self._nodes: List

self.parse_genome(genome)
self._genome = genome

def __repr__(self):
def __repr__(self) -> str:
return "CartesianGraph(" + str(self._nodes) + ")"

def print_active_nodes(self):
def print_active_nodes(self) -> str:
"""Print a representation of all active nodes in the graph.
"""
return "CartesianGraph(" + str([node for node in self._nodes if node._active]) + ")"

def pretty_str(self):
def pretty_str(self) -> str:
"""Print a pretty representation of the Cartesian graph.
"""
n_characters = 24

def pretty_node_str(node):
def pretty_node_str(node: gp.Node) -> str:
s = node.pretty_str(n_characters)
assert len(s) == n_characters
return s

def empty_node_str():
def empty_node_str() -> str:
return " " * n_characters

s = "\n"
Expand Down Expand Up @@ -86,7 +87,7 @@ def empty_node_str():

return s

def parse_genome(self, genome):
def parse_genome(self, genome: gp.Genome) -> None:
if genome.dna is None:
raise RuntimeError("dna not initialized")

Expand Down Expand Up @@ -114,22 +115,22 @@ def parse_genome(self, genome):

self._determine_active_nodes()

def _hidden_column_idx(self, idx):
def _hidden_column_idx(self, idx: int) -> int:
return (idx - self._n_inputs) // self._n_rows

@property
def input_nodes(self):
def input_nodes(self) -> List[gp.Node]:
return self._nodes[: self._n_inputs]

@property
def hidden_nodes(self):
def hidden_nodes(self) -> List[gp.Node]:
return self._nodes[self._n_inputs : -self._n_outputs]

@property
def output_nodes(self):
def output_nodes(self) -> List[gp.Node]:
return self._nodes[-self._n_outputs :]

def _determine_active_nodes(self):
def _determine_active_nodes(self) -> Dict[int, set]:

# determine active nodes
active_nodes_by_hidden_column_idx = collections.defaultdict(
Expand All @@ -155,7 +156,7 @@ def _determine_active_nodes(self):

return active_nodes_by_hidden_column_idx

def determine_active_regions(self):
def determine_active_regions(self) -> List[int]:
"""Determine the active regions in the computational graph.
Returns
Expand All @@ -171,7 +172,7 @@ def determine_active_regions(self):

return active_regions

def __call__(self, x):
def __call__(self, x: List[float]) -> List[float]:
# store values of x in input nodes
for i, xi in enumerate(x):
assert isinstance(self._nodes[i], InputNode)
Expand All @@ -185,16 +186,16 @@ def __call__(self, x):

return [node._output for node in self.output_nodes]

def __getitem__(self, key):
def __getitem__(self, key: int) -> gp.Node:
return self._nodes[key]

def to_str(self):
def to_str(self) -> str:

self._format_output_str_of_all_nodes()
out_str = ", ".join(node.output_str for node in self.output_nodes)
return f"[{out_str}]"

def _format_output_str_of_all_nodes(self):
def _format_output_str_of_all_nodes(self) -> None:

for i, node in enumerate(self.input_nodes):
node.format_output_str(self)
Expand All @@ -204,8 +205,8 @@ def _format_output_str_of_all_nodes(self):
for node in active_nodes[hidden_column_idx]:
node.format_output_str(self)

def to_func(self):
"""Compile the function(s) represented by the graph.
def to_func(self) -> Callable:
"""Compile the function represented by the computational graph.
Generates a definition of the function in Python code and
executes the function definition to create a Callable.
Expand Down Expand Up @@ -270,7 +271,7 @@ def _f(x):

return locals()["_f"]

def to_torch(self):
def to_torch(self) -> torch.nn.Module:
"""Compile the function(s) represented by the graph to a Torch class.
Generates a definition of the Torch class in Python code and
Expand Down Expand Up @@ -347,7 +348,7 @@ def update_parameters_from_torch_class(self, torch_cls):
except AttributeError:
pass

def to_sympy(self, simplify=True):
def to_sympy(self, simplify: Optional[bool] = True) -> List[sympy.core.Expr]:
"""Compile the function(s) represented by the graph to a SymPy expression.
Generates one SymPy expression for each output node.
Expand All @@ -366,7 +367,7 @@ def to_sympy(self, simplify=True):
if sympy is None:
raise ModuleNotFoundError("No module named 'sympy' (extra requirement)")

def _validate_sympy_expr(expr):
def _validate_sympy_expr(expr: sympy.core.Expr) -> sympy.core.Expr:
"""Helper function that raises an exception upon encountering a SymPy
expression that can not be evaluated.
Expand Down
Loading

0 comments on commit 436ba5c

Please sign in to comment.