Skip to content

Commit 47ab920

Browse files
authored
Merge pull request #284 from pyiron/type_hints
Add type hints
2 parents f3956d6 + 5f370a3 commit 47ab920

File tree

13 files changed

+135
-118
lines changed

13 files changed

+135
-118
lines changed

.github/workflows/unittest-mpich.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ jobs:
4040
label: linux-64-py-3-9-mpich
4141
prefix: /usr/share/miniconda3/envs/my-env
4242

43-
- operating-system: ubuntu-latest
44-
python-version: 3.8
45-
label: linux-64-py-3-8-mpich
46-
prefix: /usr/share/miniconda3/envs/my-env
47-
4843
steps:
4944
- uses: actions/checkout@v2
5045
- uses: conda-incubator/setup-miniconda@v2.2.0

.github/workflows/unittest-openmpi.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ jobs:
4040
label: linux-64-py-3-9-openmpi
4141
prefix: /usr/share/miniconda3/envs/my-env
4242

43-
- operating-system: ubuntu-latest
44-
python-version: 3.8
45-
label: linux-64-py-3-8-openmpi
46-
prefix: /usr/share/miniconda3/envs/my-env
47-
4843
steps:
4944
- uses: actions/checkout@v2
5045
- uses: conda-incubator/setup-miniconda@v2.2.0

pympipool/__init__.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import shutil
3+
from typing import Optional
34
from ._version import get_versions
45
from pympipool.mpi.executor import PyMPIExecutor
56
from pympipool.shared.interface import SLURM_COMMAND
@@ -69,30 +70,30 @@ class Executor:
6970

7071
def __init__(
7172
self,
72-
max_workers=1,
73-
cores_per_worker=1,
74-
threads_per_core=1,
75-
gpus_per_worker=0,
76-
oversubscribe=False,
77-
init_function=None,
78-
cwd=None,
73+
max_workers: int = 1,
74+
cores_per_worker: int = 1,
75+
threads_per_core: int = 1,
76+
gpus_per_worker: int = 0,
77+
oversubscribe: bool = False,
78+
init_function: Optional[callable] = None,
79+
cwd: Optional[str] = None,
7980
executor=None,
80-
hostname_localhost=False,
81+
hostname_localhost: bool = False,
8182
):
8283
# Use __new__() instead of __init__(). This function is only implemented to enable auto-completion.
8384
pass
8485

8586
def __new__(
8687
cls,
87-
max_workers=1,
88-
cores_per_worker=1,
89-
threads_per_core=1,
90-
gpus_per_worker=0,
91-
oversubscribe=False,
92-
init_function=None,
93-
cwd=None,
88+
max_workers: int = 1,
89+
cores_per_worker: int = 1,
90+
threads_per_core: int = 1,
91+
gpus_per_worker: int = 0,
92+
oversubscribe: bool = False,
93+
init_function: Optional[callable] = None,
94+
cwd: Optional[str] = None,
9495
executor=None,
95-
hostname_localhost=False,
96+
hostname_localhost: bool = False,
9697
):
9798
"""
9899
Instead of returning a pympipool.Executor object this function returns either a pympipool.mpi.PyMPIExecutor,

pympipool/backend/serial.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from os.path import abspath
22
import sys
3+
from typing import Optional
34

45
from pympipool.shared.communication import (
56
interface_connect,
@@ -10,7 +11,7 @@
1011
from pympipool.shared.backend import call_funct, parse_arguments
1112

1213

13-
def main(argument_lst=None):
14+
def main(argument_lst: Optional[list[str]] = None):
1415
if argument_lst is None:
1516
argument_lst = sys.argv
1617
argument_dict = parse_arguments(argument_lst=argument_lst)

pympipool/flux/executor.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import Optional
23

34
import flux.job
45

@@ -56,14 +57,14 @@ class PyFluxExecutor(ExecutorBroker):
5657

5758
def __init__(
5859
self,
59-
max_workers=1,
60-
cores_per_worker=1,
61-
threads_per_core=1,
62-
gpus_per_worker=0,
63-
init_function=None,
64-
cwd=None,
65-
executor=None,
66-
hostname_localhost=False,
60+
max_workers: int = 1,
61+
cores_per_worker: int = 1,
62+
threads_per_core: int = 1,
63+
gpus_per_worker: int = 0,
64+
init_function: Optional[callable] = None,
65+
cwd: Optional[str] = None,
66+
executor: Optional[flux.job.FluxExecutor] = None,
67+
hostname_localhost: Optional[bool] = False,
6768
):
6869
super().__init__()
6970
self._set_process(
@@ -92,12 +93,12 @@ def __init__(
9293
class FluxPythonInterface(BaseInterface):
9394
def __init__(
9495
self,
95-
cwd=None,
96-
cores=1,
97-
threads_per_core=1,
98-
gpus_per_core=0,
99-
oversubscribe=False,
100-
executor=None,
96+
cwd: Optional[str] = None,
97+
cores: int = 1,
98+
threads_per_core: int = 1,
99+
gpus_per_core: int = 0,
100+
oversubscribe: bool = False,
101+
executor: Optional[flux.job.FluxExecutor] = None,
101102
):
102103
super().__init__(
103104
cwd=cwd,
@@ -109,7 +110,7 @@ def __init__(
109110
self._executor = executor
110111
self._future = None
111112

112-
def bootup(self, command_lst):
113+
def bootup(self, command_lst: list[str]):
113114
if self._oversubscribe:
114115
raise ValueError(
115116
"Oversubscribing is currently not supported for the Flux adapter."
@@ -129,7 +130,7 @@ def bootup(self, command_lst):
129130
jobspec.cwd = self._cwd
130131
self._future = self._executor.submit(jobspec)
131132

132-
def shutdown(self, wait=True):
133+
def shutdown(self, wait: bool = True):
133134
if self.poll():
134135
self._future.cancel()
135136
# The flux future objects are not instantly updated,

pympipool/mpi/executor.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from pympipool.shared.executorbase import (
24
execute_parallel_tasks,
35
ExecutorBroker,
@@ -51,12 +53,12 @@ class PyMPIExecutor(ExecutorBroker):
5153

5254
def __init__(
5355
self,
54-
max_workers=1,
55-
cores_per_worker=1,
56-
oversubscribe=False,
57-
init_function=None,
58-
cwd=None,
59-
hostname_localhost=False,
56+
max_workers: int = 1,
57+
cores_per_worker: int = 1,
58+
oversubscribe: bool = False,
59+
init_function: Optional[callable] = None,
60+
cwd: Optional[str] = None,
61+
hostname_localhost: bool = False,
6062
):
6163
super().__init__()
6264
self._set_process(

pympipool/shared/backend.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from typing import Optional
12
import inspect
23

34

4-
def call_funct(input_dict, funct=None, memory=None):
5+
def call_funct(
6+
input_dict: dict, funct: Optional[callable] = None, memory: Optional[dict] = None
7+
) -> callable:
58
"""
69
Call function from dictionary
710
@@ -30,7 +33,7 @@ def funct(*args, **kwargs):
3033
return funct(input_dict["fn"], *input_dict["args"], **input_dict["kwargs"])
3134

3235

33-
def parse_arguments(argument_lst):
36+
def parse_arguments(argument_lst: list[str]) -> dict:
3437
"""
3538
Simple function to parse command line arguments
3639
@@ -50,7 +53,9 @@ def parse_arguments(argument_lst):
5053
)
5154

5255

53-
def update_default_dict_from_arguments(argument_lst, argument_dict, default_dict):
56+
def update_default_dict_from_arguments(
57+
argument_lst: list[str], argument_dict: dict, default_dict: dict
58+
) -> dict:
5459
default_dict.update(
5560
{
5661
k: argument_lst[argument_lst.index(v) + 1]
@@ -61,7 +66,9 @@ def update_default_dict_from_arguments(argument_lst, argument_dict, default_dict
6166
return default_dict
6267

6368

64-
def _update_dict_delta(dict_input, dict_output, keys_possible_lst):
69+
def _update_dict_delta(
70+
dict_input: dict, dict_output: dict, keys_possible_lst: list
71+
) -> dict:
6572
return {
6673
k: v
6774
for k, v in dict_input.items()

pympipool/shared/communication.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, interface=None):
1818
self._process = None
1919
self._interface = interface
2020

21-
def send_dict(self, input_dict):
21+
def send_dict(self, input_dict: dict):
2222
"""
2323
Send a dictionary with instructions to a connected client process.
2424
@@ -42,7 +42,7 @@ def receive_dict(self):
4242
error_type = output["error_type"].split("'")[1]
4343
raise eval(error_type)(output["error"])
4444

45-
def send_and_receive_dict(self, input_dict):
45+
def send_and_receive_dict(self, input_dict: dict) -> dict:
4646
"""
4747
Combine both the send_dict() and receive_dict() function in a single call.
4848
@@ -66,7 +66,7 @@ def bind_to_random_port(self):
6666
"""
6767
return self._socket.bind_to_random_port("tcp://*")
6868

69-
def bootup(self, command_lst):
69+
def bootup(self, command_lst: list[str]):
7070
"""
7171
Boot up the client process to connect to the SocketInterface.
7272
@@ -75,7 +75,7 @@ def bootup(self, command_lst):
7575
"""
7676
self._interface.bootup(command_lst=command_lst)
7777

78-
def shutdown(self, wait=True):
78+
def shutdown(self, wait: bool = True):
7979
result = None
8080
if self._interface.poll():
8181
result = self.send_and_receive_dict(
@@ -96,9 +96,9 @@ def __del__(self):
9696

9797

9898
def interface_bootup(
99-
command_lst,
99+
command_lst: list[str],
100100
connections,
101-
hostname_localhost=False,
101+
hostname_localhost: bool = False,
102102
):
103103
"""
104104
Start interface for ZMQ communication
@@ -132,7 +132,7 @@ def interface_bootup(
132132
return interface
133133

134134

135-
def interface_connect(host, port):
135+
def interface_connect(host: str, port: str):
136136
"""
137137
Connect to an existing SocketInterface instance by providing the hostname and the port as strings.
138138
@@ -146,7 +146,7 @@ def interface_connect(host, port):
146146
return context, socket
147147

148148

149-
def interface_send(socket, result_dict):
149+
def interface_send(socket: zmq.Socket, result_dict: dict):
150150
"""
151151
Send results to a SocketInterface instance.
152152
@@ -157,7 +157,7 @@ def interface_send(socket, result_dict):
157157
socket.send(cloudpickle.dumps(result_dict))
158158

159159

160-
def interface_receive(socket):
160+
def interface_receive(socket: zmq.Socket):
161161
"""
162162
Receive instructions from a SocketInterface instance.
163163
@@ -167,7 +167,7 @@ def interface_receive(socket):
167167
return cloudpickle.loads(socket.recv())
168168

169169

170-
def interface_shutdown(socket, context):
170+
def interface_shutdown(socket: zmq.Socket, context: zmq.Context):
171171
"""
172172
Close the connection to a SocketInterface instance.
173173

0 commit comments

Comments
 (0)