Skip to content

Commit

Permalink
copy-remote adopted with behaviour of asyncssh
Browse files Browse the repository at this point in the history
  • Loading branch information
khsrali committed Dec 4, 2024
1 parent cc0bc5c commit 3210c27
Showing 1 changed file with 82 additions and 21 deletions.
103 changes: 82 additions & 21 deletions src/aiida/transports/plugins/ssh_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ async def copy_async(
):
"""Copy a file or a folder from remote to remote.
:param remotesource: path to the remote source directory / file
:param remotedestination: path to the remote destination directory / file
:param remotesource: abs path to the remote source directory / file
:param remotedestination: abs path to the remote destination directory / file
:param dereference: follow symbolic links
:param recursive: copy recursively
:param preserve: preserve file attributes
Expand All @@ -528,27 +528,88 @@ async def copy_async(
if not remotesource:
raise ValueError('remotesource must be a non empty string')

try:
# SFTP.copy() supports remote copy only in very recent version OpenSSH 9.0 and later.
# For the older versions, it downloads the file and uploads it again!
# For performance reasons, we should check if the remote copy is supported, if so use
# self._sftp.mcopy() & self._sftp.copy() otherwise send a `cp` command to the remote machine.
# This is a temporary solution until the feature is implemented in asyncssh:
# See here: https://github.com/ronf/asyncssh/issues/724
if False:
# self._sftp._supports_copy_data:
try: # type: ignore[unreachable]
if self.has_magic(remotesource):
await self._sftp.mcopy(
remotesource,
remotedestination,
preserve=preserve,
recurse=recursive,
follow_symlinks=dereference,
)
else:
if not await self.path_exists_async(remotesource):
raise OSError(f'The remote path {remotesource} does not exist')
await self._sftp.copy(
remotesource,
remotedestination,
preserve=preserve,
recurse=recursive,
follow_symlinks=dereference,
)
except asyncssh.sftp.SFTPFailure as exc:
raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}')
else:
# I copy pasted the whole logic below from SshTransport class:

async def _exec_cp(cp_exe: str, cp_flags: str, src: str, dst: str):
"""Execute the ``cp`` command on the remote machine."""
# to simplify writing the above copy function
command = f'{cp_exe} {cp_flags} {escape_for_bash(src)} {escape_for_bash(dst)}'

retval, stdout, stderr = await self.exec_command_wait_async(command)

if retval == 0:
if stderr.strip():
self.logger.warning(f'There was nonempty stderr in the cp command: {stderr}')
else:
self.logger.error(
"Problem executing cp. Exit code: {}, stdout: '{}', " "stderr: '{}', command: '{}'".format(
retval, stdout, stderr, command
)
)
if 'No such file or directory' in str(stderr):
raise FileNotFoundError(f'Error while executing cp: {stderr}')

raise OSError(
'Error while executing cp. Exit code: {}, '
"stdout: '{}', stderr: '{}', "
"command: '{}'".format(retval, stdout, stderr, command)
)

cp_exe = 'cp'
cp_flags = '-f'

if recursive:
cp_flags += ' -r'

if preserve:
cp_flags += ' -p'

if dereference:
# use -L; --dereference is not supported on mac
cp_flags += ' -L'

if self.has_magic(remotesource):
await self._sftp.mcopy(
remotesource,
remotedestination,
preserve=preserve,
recurse=recursive,
follow_symlinks=dereference,
)
to_copy_list = await self.glob_async(remotesource)

if len(to_copy_list) > 1:
if not self.path_exists(remotedestination) or self.isfile(remotedestination):
raise OSError("Can't copy more than one file in the same destination file")

for file in to_copy_list:
await _exec_cp(cp_exe, cp_flags, file, remotedestination)

else:
if not await self.path_exists_async(remotesource):
raise OSError(f'The remote path {remotesource} does not exist')
await self._sftp.copy(
remotesource,
remotedestination,
preserve=preserve,
recurse=recursive,
follow_symlinks=dereference,
)
except asyncssh.sftp.SFTPFailure as exc:
raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}')
await _exec_cp(cp_exe, cp_flags, remotesource, remotedestination)

async def copyfile_async(
self,
Expand Down

0 comments on commit 3210c27

Please sign in to comment.