Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transport & Engine: factor out getcwd() & chdir() for compatibility with upcoming async transport #6594

Merged
merged 6 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
factor-out-cwd
  • Loading branch information
khsrali committed Oct 28, 2024
commit 4df223fd074c6ec15fd01dd42b38390862b3d328
7 changes: 6 additions & 1 deletion src/aiida/calculations/monitors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aiida.orm import CalcJobNode
from aiida.transports import Transport
from aiida.transports.util import StrPath


def always_kill(node: CalcJobNode, transport: Transport) -> str | None:
Expand All @@ -19,7 +20,11 @@
:returns: A string if the job should be killed, `None` otherwise.
"""
with tempfile.NamedTemporaryFile('w+') as handle:
transport.getfile('_aiidasubmit.sh', handle.name)
cwd = node.get_remote_workdir()
if cwd is None:
raise ValueError('The remote work directory cannot be None')

Check warning on line 25 in src/aiida/calculations/monitors/base.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/calculations/monitors/base.py#L23-L25

Added lines #L23 - L25 were not covered by tests

transport.getfile(StrPath(cwd).join('_aiidasubmit.sh'), handle.name)

Check warning on line 27 in src/aiida/calculations/monitors/base.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/calculations/monitors/base.py#L27

Added line #L27 was not covered by tests
handle.seek(0)
output = handle.read()

Expand Down
6 changes: 1 addition & 5 deletions src/aiida/cmdline/commands/cmd_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,7 @@ def _computer_create_temp_file(transport, scheduler, authinfo, computer):
file_content = f"Test from 'verdi computer test' on {datetime.datetime.now().isoformat()}"
workdir = authinfo.get_workdir().format(username=transport.whoami())

try:
transport.chdir(workdir)
except OSError:
transport.makedirs(workdir)
transport.chdir(workdir)
transport.makedirs(workdir, ignore_existing=True)

with tempfile.NamedTemporaryFile(mode='w+') as tempf:
fname = os.path.split(tempf.name)[1]
Expand Down
123 changes: 66 additions & 57 deletions src/aiida/engine/daemon/execmanager.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion src/aiida/engine/processes/calcjobs/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,6 @@ def _perform_dry_run(self):
with LocalTransport() as transport:
with SubmitTestFolder() as folder:
calc_info = self.presubmit(folder)
transport.chdir(folder.abspath)
khsrali marked this conversation as resolved.
Show resolved Hide resolved
upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True)
self.node.dry_run_info = { # type: ignore[attr-defined]
'folder': folder.abspath,
Expand Down
1 change: 0 additions & 1 deletion src/aiida/engine/processes/calcjobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ async def task_monitor_job(
async def do_monitor():
with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
transport.chdir(node.get_remote_workdir())
khsrali marked this conversation as resolved.
Show resolved Hide resolved
return monitors.process(node, transport)

try:
Expand Down
49 changes: 16 additions & 33 deletions src/aiida/orm/nodes/data/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,10 @@
transport = authinfo.get_transport()

with transport:
try:
transport.chdir(self.get_remote_path())
except OSError:
# If the transport OSError the directory no longer exists and was deleted
if not transport.isdir(self.get_remote_path()):
return True

return not transport.listdir()
return not transport.listdir(self.get_remote_path())

def getfile(self, relpath, destpath):
"""Connects to the remote folder and retrieves the content of a file.
Expand Down Expand Up @@ -96,22 +93,15 @@
authinfo = self.get_authinfo()

with authinfo.get_transport() as transport:
try:
full_path = os.path.join(self.get_remote_path(), relpath)
transport.chdir(full_path)
except OSError as exception:
if exception.errno in (2, 20): # directory not existing or not a directory
exc = OSError(
f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a '
'directory or has been deleted.'
)
exc.errno = exception.errno
raise exc from exception
else:
raise
full_path = os.path.join(self.get_remote_path(), relpath)
khsrali marked this conversation as resolved.
Show resolved Hide resolved
if not transport.isdir(full_path):
raise OSError(

Check warning on line 98 in src/aiida/orm/nodes/data/remote/base.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/nodes/data/remote/base.py#L98

Added line #L98 was not covered by tests
f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a '
'directory or has been deleted.'
)

try:
return transport.listdir()
return transport.listdir(full_path)
except OSError as exception:
if exception.errno in (2, 20): # directory not existing or not a directory
exc = OSError(
Expand All @@ -132,22 +122,15 @@
authinfo = self.get_authinfo()

with authinfo.get_transport() as transport:
try:
full_path = os.path.join(self.get_remote_path(), path)
transport.chdir(full_path)
except OSError as exception:
if exception.errno in (2, 20): # directory not existing or not a directory
exc = OSError(
f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a '
'directory or has been deleted.'
)
exc.errno = exception.errno
raise exc from exception
else:
raise
full_path = os.path.join(self.get_remote_path(), path)
if not transport.isdir(full_path):
raise OSError(

Check warning on line 127 in src/aiida/orm/nodes/data/remote/base.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/nodes/data/remote/base.py#L127

Added line #L127 was not covered by tests
f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a '
'directory or has been deleted.'
)

try:
return transport.listdir_withattributes()
return transport.listdir_withattributes(full_path)
except OSError as exception:
if exception.errno in (2, 20): # directory not existing or not a directory
exc = OSError(
Expand Down
5 changes: 1 addition & 4 deletions src/aiida/orm/utils/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,8 @@ def clean_remote(transport: Transport, path: str) -> None:
if not transport.is_open:
raise ValueError('the transport should already be open')

basedir, relative_path = os.path.split(path)

try:
transport.chdir(basedir)
transport.rmtree(relative_path)
transport.rmtree(path)
except OSError:
pass

Expand Down
7 changes: 4 additions & 3 deletions src/aiida/schedulers/plugins/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ class BashCliScheduler(Scheduler, metaclass=abc.ABCMeta):
def submit_job(self, working_directory: str, filename: str) -> str | ExitCode:
khsrali marked this conversation as resolved.
Show resolved Hide resolved
"""Submit a job.

:param working_directory: The absolute filepath to the working directory where the job is to be exectued.
:param working_directory: The absolute filepath to the working directory where the job is to be executed.
:param filename: The filename of the submission script relative to the working directory.
"""
self.transport.chdir(working_directory)
result = self.transport.exec_command_wait(self._get_submit_command(escape_for_bash(filename)))
result = self.transport.exec_command_wait(
self._get_submit_command(escape_for_bash(filename)), workdir=working_directory
)
return self._parse_submit_output(*result)

def get_jobs(
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/schedulers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def create_job_resource(cls, **kwargs):
def submit_job(self, working_directory: str, filename: str) -> str | ExitCode:
"""Submit a job.

:param working_directory: The absolute filepath to the working directory where the job is to be exectued.
:param working_directory: The absolute filepath to the working directory where the job is to be executed.
:param filename: The filename of the submission script relative to the working directory.
:returns:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/aiida/transports/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

from .plugins import *
from .transport import *
from .util import StrPath

__all__ = (
'SshTransport',
'Transport',
'convert_to_bool',
'parse_sshconfig',
'StrPath',
)

# fmt: on
45 changes: 31 additions & 14 deletions src/aiida/transports/plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,6 @@
###########################################################################
"""Local transport"""

###
### GP: a note on the local transport:
### I believe that we must not use os.chdir to keep track of the folder
### in which we are, since this may have very nasty side effects in other
### parts of code, and make things not thread-safe.
### we should instead keep track internally of the 'current working directory'
### in the exact same way as paramiko does already.

import contextlib
import errno
import glob
Expand All @@ -24,6 +16,7 @@
import shutil
import subprocess

from aiida.common.warnings import warn_deprecation
from aiida.transports import cli as transport_cli
from aiida.transports.transport import Transport, TransportInternalError

Expand Down Expand Up @@ -105,6 +98,10 @@ def chdir(self, path):
:param path: path to cd into
:raise OSError: if the directory does not have read attributes.
"""
warn_deprecation(
'`chdir()` is deprecated and will be removed in the next major version.',
version=3,
)
new_path = os.path.join(self.curdir, path)
if not os.path.isdir(new_path):
raise OSError(f"'{new_path}' is not a valid directory")
Expand All @@ -124,6 +121,10 @@ def normalize(self, path='.'):

def getcwd(self):
"""Returns the current working directory, emulated by the transport"""
warn_deprecation(
'`getcwd()` is deprecated and will be removed in the next major version.',
version=3,
)
return self.curdir

@staticmethod
Expand Down Expand Up @@ -695,11 +696,9 @@ def isfile(self, path):
return os.path.isfile(os.path.join(self.curdir, path))

@contextlib.contextmanager
def _exec_command_internal(self, command, **kwargs):
def _exec_command_internal(self, command, workdir=None, **kwargs):
"""Executes the specified command in bash login shell.

Before the command is executed, changes directory to the current
working directory as returned by self.getcwd().

For executing commands and waiting for them to finish, use
exec_command_wait.
Expand All @@ -710,6 +709,10 @@ def _exec_command_internal(self, command, **kwargs):

:param command: the command to execute. The command is assumed to be
already escaped using :py:func:`aiida.common.escaping.escape_for_bash`.
:param workdir: (optional, default=None) if set, the command will be executed
in the specified working directory.
if None, the command will be executed in the current working directory,
from DEPRECATED `self.getcwd()`.

:return: a tuple with (stdin, stdout, stderr, proc),
where stdin, stdout and stderr behave as file-like objects,
Expand All @@ -724,26 +727,40 @@ def _exec_command_internal(self, command, **kwargs):

command = bash_commmand + escape_for_bash(command)

if workdir:
cwd = workdir
else:
warn_deprecation(
khsrali marked this conversation as resolved.
Show resolved Hide resolved
'`getcwd()` is deprecated and will be removed in the next major version.'
'You should always pass `workdir` as an argument, instead.',
version=3,
)
cwd = self.getcwd()

with subprocess.Popen(
command,
shell=True,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=self.getcwd(),
cwd=cwd,
start_new_session=True,
) as process:
yield process

def exec_command_wait_bytes(self, command, stdin=None, **kwargs):
def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs):
"""Executes the specified command and waits for it to finish.

:param command: the command to execute
:param workdir: (optional, default=None) if set, the command will be executed
in the specified working directory.
if None, the command will be executed in the current working directory,
from DEPRECATED `self.getcwd()`.

:return: a tuple with (return_value, stdout, stderr) where stdout and stderr
are both bytes and the return_value is an int.
"""
with self._exec_command_internal(command) as process:
with self._exec_command_internal(command, workdir) as process:
if stdin is not None:
# Implicitly assume that the desired encoding is 'utf-8' if I receive a string.
# Also, if I get a StringIO, I just read it all in memory and put it into a BytesIO.
Expand Down
35 changes: 29 additions & 6 deletions src/aiida/transports/plugins/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aiida.cmdline.params import options
from aiida.cmdline.params.types.path import AbsolutePathOrEmptyParamType
from aiida.common.escaping import escape_for_bash
from aiida.common.warnings import warn_deprecation

from ..transport import Transport, TransportInternalError

Expand Down Expand Up @@ -586,6 +587,10 @@ def chdir(self, path):
Differently from paramiko, if you pass None to chdir, nothing
happens and the cwd is unchanged.
"""
warn_deprecation(
'`chdir()` is deprecated and will be removed in the next major version.',
version=3,
)
from paramiko.sftp import SFTPError

old_path = self.sftp.getcwd()
Expand Down Expand Up @@ -651,6 +656,10 @@ def getcwd(self):
this method will return None. But in __enter__ this is set explicitly,
so this should never happen within this class.
"""
warn_deprecation(
'`chdir()` is deprecated and will be removed in the next major version.',
version=3,
)
return self.sftp.getcwd()

def makedirs(self, path, ignore_existing=False):
Expand Down Expand Up @@ -1276,11 +1285,9 @@ def isfile(self, path):
return False
raise # Typically if I don't have permissions (errno=13)

def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):
def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1, workdir=None):
"""Executes the specified command in bash login shell.

Before the command is executed, changes directory to the current
working directory as returned by self.getcwd().

For executing commands and waiting for them to finish, use
exec_command_wait.
Expand All @@ -1291,6 +1298,10 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):
stderr on the same buffer (i.e., stdout).
Note: If combine_stderr is True, stderr will always be empty.
:param bufsize: same meaning of the one used by paramiko.
:param workdir: (optional, default=None) if set, the command will be executed
in the specified working directory.
if None, the command will be executed in the current working directory,
from DEPRECATED `self.getcwd()`, if that has a value.

:return: a tuple with (stdin, stdout, stderr, channel),
where stdin, stdout and stderr behave as file-like objects,
Expand All @@ -1300,7 +1311,13 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):
channel = self.sshclient.get_transport().open_session()
channel.set_combine_stderr(combine_stderr)

if self.getcwd() is not None:
if workdir is not None:
command_to_execute = f'cd {workdir} && ( {command} )'
elif self.getcwd() is not None:
warn_deprecation(
khsrali marked this conversation as resolved.
Show resolved Hide resolved
'`getcwd()` is deprecated and will be removed in the next major version.',
version=3,
)
escaped_folder = escape_for_bash(self.getcwd())
command_to_execute = f'cd {escaped_folder} && ( {command} )'
else:
Expand All @@ -1320,7 +1337,9 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):

return stdin, stdout, stderr, channel

def exec_command_wait_bytes(self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01):
def exec_command_wait_bytes(
self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir=None
):
"""Executes the specified command and waits for it to finish.

:param command: the command to execute
Expand All @@ -1330,14 +1349,18 @@ def exec_command_wait_bytes(self, command, stdin=None, combine_stderr=False, buf
self._exec_command_internal()
:param bufsize: same meaning of paramiko.
:param timeout: ssh channel timeout for stdout, stderr.
:param workdir: (optional, default=None) if set, the command will be executed
in the specified working directory.

:return: a tuple with (return_value, stdout, stderr) where stdout and stderr
are both bytes and the return_value is an int.
"""
import socket
import time

ssh_stdin, stdout, stderr, channel = self._exec_command_internal(command, combine_stderr, bufsize=bufsize)
ssh_stdin, stdout, stderr, channel = self._exec_command_internal(
command, combine_stderr, bufsize=bufsize, workdir=workdir
)

if stdin is not None:
if isinstance(stdin, str):
Expand Down
Loading
Loading