Skip to content

Commit

Permalink
review applied
Browse files Browse the repository at this point in the history
  • Loading branch information
khsrali committed Nov 25, 2024
1 parent 6e350e7 commit 03ccc30
Show file tree
Hide file tree
Showing 23 changed files with 483 additions and 486 deletions.
4 changes: 2 additions & 2 deletions src/aiida/calculations/monitors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from typing import Union

from aiida.orm import CalcJobNode
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport


def always_kill(node: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> str | None:
def always_kill(node: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> str | None:
"""Retrieve and inspect files in working directory of job to determine whether the job should be killed.
This particular implementation is just for demonstration purposes and will kill the job as long as there is a
Expand Down
16 changes: 7 additions & 9 deletions src/aiida/engine/daemon/execmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from aiida.schedulers.datastructures import JobState

if TYPE_CHECKING:
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 38 in src/aiida/engine/daemon/execmanager.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/engine/daemon/execmanager.py#L38

Added line #L38 was not covered by tests

REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found'

Expand Down Expand Up @@ -64,7 +64,7 @@ def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]:

async def upload_calculation(
node: CalcJobNode,
transport: Union['BlockingTransport', 'AsyncTransport'],
transport: Union['Transport', 'AsyncTransport'],
calc_info: CalcInfo,
folder: Folder,
inputs: Optional[MappingType[str, Any]] = None,
Expand Down Expand Up @@ -393,9 +393,7 @@ async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path):
await transport.put_async(folder.get_abs_path(filename), workdir.joinpath(filename))


def submit_calculation(
calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']
) -> str | ExitCode:
def submit_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> str | ExitCode:
"""Submit a previously uploaded `CalcJob` to the scheduler.
:param calculation: the instance of CalcJobNode to submit.
Expand Down Expand Up @@ -425,7 +423,7 @@ def submit_calculation(
return result


async def stash_calculation(calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> None:
async def stash_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> None:
"""Stash files from the working directory of a completed calculation to a permanent remote folder.
After a calculation has been completed, optionally stash files from the work directory to a storage location on the
Expand Down Expand Up @@ -491,7 +489,7 @@ async def stash_calculation(calculation: CalcJobNode, transport: Union['Blocking


async def retrieve_calculation(
calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport'], retrieved_temporary_folder: str
calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport'], retrieved_temporary_folder: str
) -> FolderData | None:
"""Retrieve all the files of a completed job calculation using the given transport.
Expand Down Expand Up @@ -556,7 +554,7 @@ async def retrieve_calculation(
return retrieved_files


def kill_calculation(calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> None:
def kill_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> None:
"""Kill the calculation through the scheduler
:param calculation: the instance of CalcJobNode to kill.
Expand Down Expand Up @@ -591,7 +589,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Union['BlockingTranspo

async def retrieve_files_from_list(
calculation: CalcJobNode,
transport: Union['BlockingTransport', 'AsyncTransport'],
transport: Union['Transport', 'AsyncTransport'],
folder: str,
retrieve_list: List[Union[str, Tuple[str, str, int], list]],
) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/engine/processes/calcjobs/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from aiida.plugins import BaseFactory

if t.TYPE_CHECKING:
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 19 in src/aiida/engine/processes/calcjobs/monitors.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/engine/processes/calcjobs/monitors.py#L19

Added line #L19 was not covered by tests

LOGGER = AIIDA_LOGGER.getChild(__name__)

Expand Down Expand Up @@ -124,7 +124,7 @@ def validate(self):

if any(required_parameter not in parameters for required_parameter in ('node', 'transport')):
correct_signature = (
"(node: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport'], **kwargs) str | None:"
"(node: CalcJobNode, transport: Union['Transport', 'AsyncTransport'], **kwargs) str | None:"
)
raise ValueError(
f'The monitor `{self.entry_point}` has an invalid function signature, it should be: {correct_signature}'
Expand Down Expand Up @@ -179,7 +179,7 @@ def monitors(self) -> collections.OrderedDict:
def process(
self,
node: CalcJobNode,
transport: Union['BlockingTransport', 'AsyncTransport'],
transport: Union['Transport', 'AsyncTransport'],
) -> CalcJobMonitorResult | None:
"""Call all monitors in order and return the result as one returns anything other than ``None``.
Expand Down
6 changes: 2 additions & 4 deletions src/aiida/engine/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from aiida.orm import AuthInfo

if TYPE_CHECKING:
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 21 in src/aiida/engine/transports.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/engine/transports.py#L21

Added line #L21 was not covered by tests

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,9 +54,7 @@ def loop(self) -> asyncio.AbstractEventLoop:
return self._loop

@contextlib.contextmanager
def request_transport(
self, authinfo: AuthInfo
) -> Iterator[Awaitable[Union['BlockingTransport', 'AsyncTransport']]]:
def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable[Union['Transport', 'AsyncTransport']]]:
"""Request a transport from an authinfo. Because the client is not allowed to
request a transport immediately they will instead be given back a future
that can be awaited to get the transport::
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/orm/authinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from aiida.orm import Computer, User
from aiida.orm.implementation import StorageBackend
from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 24 in src/aiida/orm/authinfos.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/authinfos.py#L24

Added line #L24 was not covered by tests

__all__ = ('AuthInfo',)

Expand Down Expand Up @@ -166,7 +166,7 @@ def get_workdir(self) -> str:
except KeyError:
return self.computer.get_workdir()

def get_transport(self) -> Union['BlockingTransport', 'AsyncTransport']:
def get_transport(self) -> Union['Transport', 'AsyncTransport']:
"""Return a fully configured transport that can be used to connect to the computer set for this instance."""
computer = self.computer
transport_type = computer.transport_type
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/orm/computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from aiida.orm import AuthInfo, User
from aiida.orm.implementation import StorageBackend
from aiida.schedulers import Scheduler
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 26 in src/aiida/orm/computers.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/computers.py#L26

Added line #L26 was not covered by tests

__all__ = ('Computer',)

Expand Down Expand Up @@ -622,7 +622,7 @@ def is_user_enabled(self, user: 'User') -> bool:
# Return False if the user is not configured (in a sense, it is disabled for that user)
return False

def get_transport(self, user: Optional['User'] = None) -> Union['BlockingTransport', 'AsyncTransport']:
def get_transport(self, user: Optional['User'] = None) -> Union['Transport', 'AsyncTransport']:
"""Return a Transport class, configured with all correct parameters.
The Transport is closed (meaning that if you want to run any operation with
it, you have to open it first (i.e., e.g. for a SSH transport, you have
Expand All @@ -646,7 +646,7 @@ def get_transport(self, user: Optional['User'] = None) -> Union['BlockingTranspo
authinfo = authinfos.AuthInfo.get_collection(self.backend).get(dbcomputer=self, aiidauser=user)
return authinfo.get_transport()

def get_transport_class(self) -> Union[Type['BlockingTransport'], Type['AsyncTransport']]:
def get_transport_class(self) -> Union[Type['Transport'], Type['AsyncTransport']]:
"""Get the transport class for this computer. Can be used to instantiate a transport instance."""
try:
return TransportFactory(self.transport_type)
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/orm/nodes/data/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def listdir_withattributes(self, path='.'):
:param relpath: If 'relpath' is specified, lists the content of the given subfolder.
:return: a list of dictionaries, where the documentation
is in :py:class:BlockingTransport.listdir_withattributes.
is in :py:class:Transport.listdir_withattributes.
"""
authinfo = self.get_authinfo()

Expand Down
6 changes: 3 additions & 3 deletions src/aiida/orm/nodes/process/calculation/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from aiida.parsers import Parser
from aiida.schedulers.datastructures import JobInfo, JobState
from aiida.tools.calculations import CalculationTools
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 29 in src/aiida/orm/nodes/process/calculation/calcjob.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/nodes/process/calculation/calcjob.py#L29

Added line #L29 was not covered by tests

__all__ = ('CalcJobNode',)

Expand Down Expand Up @@ -450,10 +450,10 @@ def get_authinfo(self) -> 'AuthInfo':

return computer.get_authinfo(self.user)

def get_transport(self) -> Union['BlockingTransport', 'AsyncTransport']:
def get_transport(self) -> Union['Transport', 'AsyncTransport']:
"""Return the transport for this calculation.
:return: Union['BlockingTransport', 'AsyncTransport'] configured
:return: Union['Transport', 'AsyncTransport'] configured
with the `AuthInfo` associated to the computer of this node
"""
return self.get_authinfo().get_transport()
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/orm/utils/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@

from aiida import orm
from aiida.orm.implementation import StorageBackend
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 24 in src/aiida/orm/utils/remote.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/utils/remote.py#L24

Added line #L24 was not covered by tests


def clean_remote(transport: Union['BlockingTransport', 'AsyncTransport'], path: str) -> None:
def clean_remote(transport: Union['Transport', 'AsyncTransport'], path: str) -> None:
"""Recursively remove a remote folder, with the given absolute path, and all its contents. The path should be
made accessible through the transport channel, which should already be open
:param transport: an open Union['BlockingTransport', 'AsyncTransport'] channel
:param transport: an open Union['Transport', 'AsyncTransport'] channel
:param path: an absolute path on the remote made available through the transport
"""
if not isinstance(path, str):
Expand Down
14 changes: 7 additions & 7 deletions src/aiida/plugins/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from aiida.schedulers import Scheduler
from aiida.tools.data.orbital import Orbital
from aiida.tools.dbimporters import DbImporter
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 45 in src/aiida/plugins/factories.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/plugins/factories.py#L45

Added line #L45 was not covered by tests


def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> NoReturn:
Expand Down Expand Up @@ -412,7 +412,7 @@ def StorageFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint
@overload
def TransportFactory(
entry_point_name: str, load: Literal[True] = True
) -> Union[Type['BlockingTransport'], Type['AsyncTransport']]: ...
) -> Union[Type['Transport'], Type['AsyncTransport']]: ...


@overload
Expand All @@ -421,25 +421,25 @@ def TransportFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint:

def TransportFactory(
entry_point_name: str, load: bool = True
) -> Union[EntryPoint, Type['BlockingTransport'], Type['AsyncTransport']]:
"""Return the Union['BlockingTransport', 'AsyncTransport'] sub class registered under the given entry point.
) -> Union[EntryPoint, Type['Transport'], Type['AsyncTransport']]:
"""Return the Union['Transport', 'AsyncTransport'] sub class registered under the given entry point.
:param entry_point_name: the entry point name.
:param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself.
:raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
"""
from inspect import isclass

from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

entry_point_group = 'aiida.transports'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (BlockingTransport, AsyncTransport)
valid_classes = (Transport, AsyncTransport)

if not load:
return entry_point

if isclass(entry_point) and (issubclass(entry_point, BlockingTransport) or issubclass(entry_point, AsyncTransport)):
if isclass(entry_point) and (issubclass(entry_point, Transport) or issubclass(entry_point, AsyncTransport)):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/schedulers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from aiida.schedulers.datastructures import JobInfo, JobResource, JobTemplate, JobTemplateCodeInfo

if t.TYPE_CHECKING:
from aiida.transports import AsyncTransport, BlockingTransport
from aiida.transports import AsyncTransport, Transport

Check warning on line 25 in src/aiida/schedulers/scheduler.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/schedulers/scheduler.py#L25

Added line #L25 was not covered by tests

__all__ = ('Scheduler', 'SchedulerError', 'SchedulerParsingError')

Expand Down Expand Up @@ -366,7 +366,7 @@ def transport(self):

return self._transport

def set_transport(self, transport: Union['BlockingTransport', 'AsyncTransport']):
def set_transport(self, transport: Union['Transport', 'AsyncTransport']):
"""Set the transport to be used to query the machine or to submit scripts.
This class assumes that the transport is open and active.
Expand Down
1 change: 0 additions & 1 deletion src/aiida/transports/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

__all__ = (
'Transport',
'BlockingTransport',
'SshTransport',
'AsyncTransport',
'convert_to_bool',
Expand Down
Loading

0 comments on commit 03ccc30

Please sign in to comment.