Skip to content

Commit

Permalink
Fix more type hint issues and imports
Browse files Browse the repository at this point in the history
  • Loading branch information
mschmidt87 committed Apr 10, 2020
1 parent 40d18e7 commit fda27f2
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 69 deletions.
13 changes: 6 additions & 7 deletions gp/cartesian_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ def __init__(self, genome: gp.Genome) -> None:
genome: Genome
Genome defining graph connectivity and node operations.
"""
self._n_outputs = None
self._n_inputs = None
self._n_columns = None
self._n_rows = None
self._nodes = None
self._gnome = None
self._n_outputs: int = 0
self._n_inputs: int = 0
self._n_columns: int = 0
self._n_rows: int = 0
self._nodes: List = []

self.parse_genome(genome)
self._genome = genome
Expand Down Expand Up @@ -163,7 +162,7 @@ def determine_active_regions(self) -> List[int]:

return active_regions

def __call__(self, x: float) -> List[float]:
def __call__(self, x: List[float]) -> List[float]:
# store values of x in input nodes
for i, xi in enumerate(x):
assert isinstance(self._nodes[i], InputNode)
Expand Down
26 changes: 16 additions & 10 deletions gp/genome.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gp
import numpy as np

from typing import Generator, List, Optional
from typing import Generator, List, Optional, Tuple

from .primitives import Primitives

Expand Down Expand Up @@ -65,14 +65,14 @@ def __init__(
1 + self._primitives.max_arity
) # one function gene + multiple input genes

self._dna = None # stores dna as list of alleles for all regions
self._dna: List[int] = [] # stores dna as list of alleles for all regions

# constants used as identifiers for input and output nodes
self._id_input_node = -1
self._id_output_node = -2
self._non_coding_allele = None
self._non_coding_allele: int = 0

def __getitem__(self, key: int) -> None:
def __getitem__(self, key: int) -> int:
if self._dna is None:
raise RuntimeError("dna not initialized")
return self._dna[key]
Expand Down Expand Up @@ -126,7 +126,7 @@ def _create_input_region(self) -> list:
return region

def _create_random_hidden_region(
self, rng: np.random.RandomState, permissable_inputs: int
self, rng: np.random.RandomState, permissable_inputs: List[int]
) -> list:

# construct dna region consisting of function allele and
Expand All @@ -146,7 +146,7 @@ def _create_random_hidden_region(
return region

def _create_random_output_region(
self, rng: np.random.RandomState, permissable_inputs: int
self, rng: np.random.RandomState, permissable_inputs: List[int]
) -> list:

# fill region with identifier for output node and single
Expand Down Expand Up @@ -193,7 +193,7 @@ def randomize(self, rng: np.random.RandomState) -> None:
self._validate_dna(dna)
self._dna = dna

def _permissable_inputs(self, region_idx: int) -> list:
def _permissable_inputs(self, region_idx: int) -> List[int]:

assert not self._is_input_region(region_idx)

Expand Down Expand Up @@ -281,7 +281,9 @@ def _hidden_column_idx(self, region_idx: int) -> int:
assert hidden_column_idx < self._n_columns
return hidden_column_idx

def iter_input_regions(self, dna: Optional[List[int]] = None) -> Generator[int, list]:
def iter_input_regions(
self, dna: Optional[List[int]] = None
) -> Generator[Tuple[int, list], None, None]:
if dna is None:
dna = self.dna
for i in range(self._n_inputs):
Expand All @@ -291,7 +293,9 @@ def iter_input_regions(self, dna: Optional[List[int]] = None) -> Generator[int,
]
yield region_idx, region

def iter_hidden_regions(self, dna: Optional[List[int]] = None) -> Generator[int]:
def iter_hidden_regions(
self, dna: Optional[List[int]] = None
) -> Generator[Tuple[int, List[int]], None, None]:
if dna is None:
dna = self.dna
for i in range(self._n_hidden):
Expand All @@ -301,7 +305,9 @@ def iter_hidden_regions(self, dna: Optional[List[int]] = None) -> Generator[int]
]
yield region_idx, region

def iter_output_regions(self, dna: Optional[List[int]] = None) -> Generator[int]:
def iter_output_regions(
self, dna: Optional[List[int]] = None
) -> Generator[Tuple[int, List[int]], None, None]:
if dna is None:
dna = self.dna
for i in range(self._n_outputs):
Expand Down
28 changes: 15 additions & 13 deletions gp/hl_api.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import numpy as np
from typing import Optional, Callable
from .abstract_individual import AbstractIndividual
from .abstract_population import AbstractPopulation
from .ea import MuPlusLambda


def evolve(
pop,
objective,
ea,
max_generations,
min_fitness,
print_progress=False,
*,
callback=None,
label=None,
n_processes=1,
):
pop: AbstractPopulation,
objective: Callable[[AbstractIndividual], AbstractIndividual],
ea: MuPlusLambda,
max_generations: int,
min_fitness: float,
print_progress: Optional[bool] = False,
callback: Optional[Callable[[AbstractPopulation], None]] = None,
label: Optional[str] = None,
n_processes: int = 1,
) -> None:
"""
Evolves a population and returns the history of fitness of parents.
Expand Down Expand Up @@ -45,8 +48,7 @@ def evolve(
Returns
-------
dict
History of the evolution.
None
"""

ea.initialize_fitness_parents(pop, objective, label=label)
Expand Down
81 changes: 45 additions & 36 deletions gp/node.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import gp
from typing import Any, List
from __future__ import annotations

from typing import List, Union

from cgp_graph import CGPGraph

primitives_dict = {} # maps string of class names to classes


def register(cls: gp.CGPNode) -> None:
def register(cls: CGPNode) -> None:
"""Register a primitive in the global dictionary of primitives
Parameters
Expand All @@ -25,12 +28,15 @@ class Node:
"""Base class for primitive functions used in Cartesian computational graphs.
"""

_arity = None
_active = False
_inputs = None
_output = None
_idx = None
_is_parameter = False
__name__ = "CGPNode"
_arity: int = 0
_active: bool = False
_inputs: List[int] = []
_output: float = 0.0
_idx: int = 0
_is_parameter: bool = False
_output_str: str = ""
_parameter_str: str = ""

def __init__(self, idx: int, inputs: List[int]) -> None:
"""Init function.
Expand All @@ -47,12 +53,15 @@ def __init__(self, idx: int, inputs: List[int]) -> None:

assert idx not in inputs

def __init_subclass__(cls: gp.CGPNode, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
register(cls)
# def __init_subclass__(cls: CGPNode, **kwargs: Any) -> None:
# super().__init_subclass__(**kwargs)
# register(cls)

def __call__(self, x: float, graph: CGPGraph) -> None:
raise NotImplementedError

@property
def arity(self) -> int:
def arity(self) -> Union[None, int]:
return self._arity

@property
Expand All @@ -64,13 +73,13 @@ def inputs(self) -> List[int]:
return self._inputs[: self._arity]

@property
def idx(self) -> int:
def idx(self) -> Union[None, int]:
return self._idx

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(idx: {self.idx}, active: {self._active}, "
f"arity: {self._arity}, inputs {self._inputs}, output {self._outputs})"
f"arity: {self._arity}, inputs {self._inputs}, output {self._output})"
)

def pretty_str(self, n: int) -> str:
Expand Down Expand Up @@ -114,15 +123,15 @@ def pretty_str(self, n: int) -> str:
return s.ljust(n)

@property
def output(self) -> int:
def output(self) -> float:
return self._output

def activate(self) -> None:
"""Set node to active.
"""
self._active = True

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
"""Format output string of the node.
"""
raise NotImplementedError()
Expand All @@ -144,7 +153,7 @@ def output_str(self) -> str:

@property
def output_str_torch(self) -> str:
return self.output_str
return self._output_str

@property
def is_parameter(self) -> bool:
Expand All @@ -164,10 +173,10 @@ class Add(Node):
def __init__(self, idx: int, inputs: List[int]) -> None:
super().__init__(idx, inputs)

def __call__(self, x: float, graph: gp.CGPGraph) -> None:
def __call__(self, x: float, graph: CGPGraph) -> None:
self._output = graph[self._inputs[0]].output + graph[self._inputs[1]].output

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
self._output_str = (
f"({graph[self._inputs[0]].output_str} + {graph[self._inputs[1]].output_str})"
)
Expand All @@ -182,10 +191,10 @@ class Sub(Node):
def __init__(self, idx: int, inputs: List[int]) -> None:
super().__init__(idx, inputs)

def __call__(self, x: float, graph: gp.CGPGraph) -> None:
def __call__(self, x: float, graph: CGPGraph) -> None:
self._output = graph[self._inputs[0]].output - graph[self._inputs[1]].output

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
self._output_str = (
f"({graph[self._inputs[0]].output_str} - {graph[self._inputs[1]].output_str})"
)
Expand All @@ -200,10 +209,10 @@ class Mul(Node):
def __init__(self, idx: int, inputs: List[int]) -> None:
super().__init__(idx, inputs)

def __call__(self, x: float, graph: gp.CGPGraph) -> None:
def __call__(self, x: float, graph: CGPGraph) -> None:
self._output = graph[self._inputs[0]].output * graph[self._inputs[1]].output

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
self._output_str = (
f"({graph[self._inputs[0]].output_str} * {graph[self._inputs[1]].output_str})"
)
Expand All @@ -218,11 +227,11 @@ class Div(Node):
def __init__(self, idx: int, inputs: List[int]) -> None:
super().__init__(idx, inputs)

def __call__(self, x: float, graph: gp.CGPGraph) -> None:
def __call__(self, x: float, graph: CGPGraph) -> None:

self._output = graph[self._inputs[0]].output / graph[self._inputs[1]].output

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
self._output_str = (
f"({graph[self._inputs[0]].output_str} / {graph[self._inputs[1]].output_str})"
)
Expand All @@ -240,13 +249,13 @@ def __init__(self, idx: int, inputs: List[int]) -> None:

self._output = 1.0

def __call__(self, x: float, graph: gp.CGPGraph) -> None:
def __call__(self, x: float, graph: CGPGraph) -> None:
pass

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
self._output_str = f"{self._output}"

def format_output_str_torch(self, graph: gp.CGPGraph) -> None:
def format_output_str_torch(self, graph: CGPGraph) -> None:
self._output_str = f"self._p{self._idx}.expand(x.shape[0])"

def format_parameter_str(self) -> None:
Expand All @@ -264,13 +273,13 @@ class InputNode(Node):
def __init__(self, idx: int, inputs: List[int]) -> None:
super().__init__(idx, inputs)

def __call__(self, x: float, graph: gp.CGPGraph) -> None:
def __call__(self, x: float, graph: CGPGraph) -> None:
assert False

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
self._output_str = f"x[{self._idx}]"

def format_output_str_torch(self, graph: gp.CGPGraph) -> None:
def format_output_str_torch(self, graph: CGPGraph) -> None:
self._output_str = f"x[:, {self._idx}]"


Expand All @@ -283,10 +292,10 @@ class OutputNode(Node):
def __init__(self, idx: int, inputs: List[int]) -> None:
super().__init__(idx, inputs)

def __call__(self, x: float, graph: gp.CGPGraph) -> None:
def __call__(self, x: float, graph: CGPGraph) -> None:
self._output = graph[self._inputs[0]].output

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
self._output_str = f"{graph[self._inputs[0]].output_str}"


Expand All @@ -299,10 +308,10 @@ class Pow(Node):
def __init__(self, idx: int, inputs: List[int]) -> None:
super().__init__(idx, inputs)

def __call__(self, x: float, graph: gp.CGPGraph) -> None:
def __call__(self, x: float, graph: CGPGraph) -> None:
self._output = graph[self._inputs[0]].output ** graph[self._inputs[1]].output

def format_output_str(self, graph: gp.CGPGraph) -> None:
def format_output_str(self, graph: CGPGraph) -> None:
self._output_str = (
f"({graph[self._inputs[0]].output_str} ** {graph[self._inputs[1]].output_str})"
)
Loading

0 comments on commit fda27f2

Please sign in to comment.