Skip to content

Commit

Permalink
Type checking: aiida/engine (+bug fixes) (#4669)
Browse files Browse the repository at this point in the history
Added type checking for the modules

* `aiida.engine`
* `aiida.manage.manager`

Move `aiida.orm` imports to top of file in `aiida.engine` module. This should be
fine as `aiida.orm` should not import anything from `aiida.engine` and this way
we don't need import guards specifically for type checking.
  • Loading branch information
chrisjsewell authored Jan 25, 2021
1 parent 0e1f39f commit 2ae0f42
Show file tree
Hide file tree
Showing 40 changed files with 1,166 additions and 855 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ repos:
files: >-
(?x)^(
aiida/common/progress_reporter.py|
aiida/engine/processes/calcjobs/calcjob.py|
aiida/manage/manager.py|
aiida/engine/.*py|
aiida/manage/database/delete/nodes.py|
aiida/tools/graph/graph_traversers.py|
aiida/tools/groups/paths.py|
Expand Down
2 changes: 1 addition & 1 deletion aiida/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
from .processes import *
from .utils import *

__all__ = (launch.__all__ + processes.__all__ + utils.__all__)
__all__ = (launch.__all__ + processes.__all__ + utils.__all__) # type: ignore[name-defined]
76 changes: 44 additions & 32 deletions aiida/engine/daemon/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,21 @@
import shutil
import socket
import tempfile
from typing import Any, Dict, Optional, TYPE_CHECKING

from aiida.manage.configuration import get_config, get_config_option
from aiida.manage.configuration.profile import Profile

if TYPE_CHECKING:
from circus.client import CircusClient

VERDI_BIN = shutil.which('verdi')
# Recent versions of virtualenv create the environment variable VIRTUAL_ENV
VIRTUALENV = os.environ.get('VIRTUAL_ENV', None)

# see https://github.com/python/typing/issues/182
JsonDictType = Dict[str, Any]


class ControllerProtocol(enum.Enum):
"""
Expand All @@ -33,13 +41,13 @@ class ControllerProtocol(enum.Enum):
TCP = 1


def get_daemon_client(profile_name=None):
def get_daemon_client(profile_name: Optional[str] = None) -> 'DaemonClient':
"""
Return the daemon client for the given profile or the current profile if not specified.
:param profile_name: the profile name, will use the current profile if None
:return: the daemon client
:rtype: :class:`aiida.engine.daemon.client.DaemonClient`
:raises aiida.common.MissingConfigurationError: if the configuration file cannot be found
:raises aiida.common.ProfileConfigurationError: if the given profile does not exist
"""
Expand All @@ -65,30 +73,30 @@ class DaemonClient: # pylint: disable=too-many-public-methods
_DAEMON_NAME = 'aiida-{name}'
_ENDPOINT_PROTOCOL = ControllerProtocol.IPC

def __init__(self, profile):
def __init__(self, profile: Profile):
"""
Construct a DaemonClient instance for a given profile
:param profile: the profile instance :class:`aiida.manage.configuration.profile.Profile`
"""
config = get_config()
self._profile = profile
self._SOCKET_DIRECTORY = None # pylint: disable=invalid-name
self._DAEMON_TIMEOUT = config.get_option('daemon.timeout') # pylint: disable=invalid-name
self._SOCKET_DIRECTORY: Optional[str] = None # pylint: disable=invalid-name
self._DAEMON_TIMEOUT: int = config.get_option('daemon.timeout') # pylint: disable=invalid-name

@property
def profile(self):
def profile(self) -> Profile:
return self._profile

@property
def daemon_name(self):
def daemon_name(self) -> str:
"""
Get the daemon name which is tied to the profile name
"""
return self._DAEMON_NAME.format(name=self.profile.name)

@property
def cmd_string(self):
def cmd_string(self) -> str:
"""
Return the command string to start the AiiDA daemon
"""
Expand All @@ -101,42 +109,42 @@ def cmd_string(self):
return f'{VERDI_BIN} -p {self.profile.name} devel run_daemon'

@property
def loglevel(self):
def loglevel(self) -> str:
return get_config_option('logging.circus_loglevel')

@property
def virtualenv(self):
def virtualenv(self) -> Optional[str]:
return VIRTUALENV

@property
def circus_log_file(self):
def circus_log_file(self) -> str:
return self.profile.filepaths['circus']['log']

@property
def circus_pid_file(self):
def circus_pid_file(self) -> str:
return self.profile.filepaths['circus']['pid']

@property
def circus_port_file(self):
def circus_port_file(self) -> str:
return self.profile.filepaths['circus']['port']

@property
def circus_socket_file(self):
def circus_socket_file(self) -> str:
return self.profile.filepaths['circus']['socket']['file']

@property
def circus_socket_endpoints(self):
def circus_socket_endpoints(self) -> Dict[str, str]:
return self.profile.filepaths['circus']['socket']

@property
def daemon_log_file(self):
def daemon_log_file(self) -> str:
return self.profile.filepaths['daemon']['log']

@property
def daemon_pid_file(self):
def daemon_pid_file(self) -> str:
return self.profile.filepaths['daemon']['pid']

def get_circus_port(self):
def get_circus_port(self) -> int:
"""
Retrieve the port for the circus controller, which should be written to the circus port file. If the
daemon is running, the port file should exist and contain the port to which the controller is connected.
Expand All @@ -158,7 +166,7 @@ def get_circus_port(self):

return port

def get_circus_socket_directory(self):
def get_circus_socket_directory(self) -> str:
"""
Retrieve the absolute path of the directory where the circus sockets are stored if the IPC protocol is
used and the daemon is running. If the daemon is running, the sockets file should exist and contain the
Expand Down Expand Up @@ -192,7 +200,7 @@ def get_circus_socket_directory(self):
self._SOCKET_DIRECTORY = socket_dir_path
return socket_dir_path

def get_daemon_pid(self):
def get_daemon_pid(self) -> Optional[int]:
"""
Get the daemon pid which should be written in the daemon pid file specific to the profile
Expand All @@ -207,15 +215,15 @@ def get_daemon_pid(self):
return None

@property
def is_daemon_running(self):
def is_daemon_running(self) -> bool:
"""
Return whether the daemon is running, which is determined by seeing if the daemon pid file is present
:return: True if daemon is running, False otherwise
"""
return self.get_daemon_pid() is not None

def delete_circus_socket_directory(self):
def delete_circus_socket_directory(self) -> None:
"""
Attempt to delete the directory used to store the circus endpoint sockets. Will not raise if the
directory does not exist
Expand Down Expand Up @@ -321,7 +329,7 @@ def get_tcp_endpoint(self, port=None):
return endpoint

@property
def client(self):
def client(self) -> 'CircusClient':
"""
Return an instance of the CircusClient with the endpoint defined by the controller endpoint, which
used the port that was written to the port file upon starting of the daemon
Expand All @@ -334,7 +342,7 @@ def client(self):
from circus.client import CircusClient
return CircusClient(endpoint=self.get_controller_endpoint(), timeout=self._DAEMON_TIMEOUT)

def call_client(self, command):
def call_client(self, command: JsonDictType) -> JsonDictType:
"""
Call the client with a specific command. Will check whether the daemon is running first
by checking for the pid file. When the pid is found yet the call still fails with a
Expand All @@ -358,47 +366,51 @@ def call_client(self, command):

return result

def get_status(self):
def get_status(self) -> JsonDictType:
"""
Get the daemon running status
:return: the client call response
If successful, will will contain 'status' key
"""
command = {'command': 'status', 'properties': {'name': self.daemon_name}}

return self.call_client(command)

def get_numprocesses(self):
def get_numprocesses(self) -> JsonDictType:
"""
Get the number of running daemon processes
:return: the client call response
If successful, will contain 'numprocesses' key
"""
command = {'command': 'numprocesses', 'properties': {'name': self.daemon_name}}

return self.call_client(command)

def get_worker_info(self):
def get_worker_info(self) -> JsonDictType:
"""
Get workers statistics for this daemon
:return: the client call response
If successful, will contain 'info' key
"""
command = {'command': 'stats', 'properties': {'name': self.daemon_name}}

return self.call_client(command)

def get_daemon_info(self):
def get_daemon_info(self) -> JsonDictType:
"""
Get statistics about this daemon itself
:return: the client call response
If successful, will contain 'info' key
"""
command = {'command': 'dstats', 'properties': {}}

return self.call_client(command)

def increase_workers(self, number):
def increase_workers(self, number: int) -> JsonDictType:
"""
Increase the number of workers
Expand All @@ -409,7 +421,7 @@ def increase_workers(self, number):

return self.call_client(command)

def decrease_workers(self, number):
def decrease_workers(self, number: int) -> JsonDictType:
"""
Decrease the number of workers
Expand All @@ -420,7 +432,7 @@ def decrease_workers(self, number):

return self.call_client(command)

def stop_daemon(self, wait):
def stop_daemon(self, wait: bool) -> JsonDictType:
"""
Stop the daemon
Expand All @@ -436,7 +448,7 @@ def stop_daemon(self, wait):

return result

def restart_daemon(self, wait):
def restart_daemon(self, wait: bool) -> JsonDictType:
"""
Restart the daemon
Expand Down
Loading

0 comments on commit 2ae0f42

Please sign in to comment.