Skip to content

Commit ecb516e

Browse files
committed
Reimplement SSHRemoteIO with datalad_next.shell
This takes out all of the old remote shell implementation, and uses the new one. It does not touch the get/put implementations (yet). They can also be done with the new shell feature, but that is a different problem.
1 parent 69b0885 commit ecb516e

File tree

1 file changed

+78
-113
lines changed

1 file changed

+78
-113
lines changed

datalad_next/patches/replace_sshremoteio.py

+78-113
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,30 @@
1-
import logging
2-
31
from urllib.parse import urlparse
42
from urllib.request import unquote
53

64
from datalad.distributed.ora_remote import (
75
DEFAULT_BUFFER_SIZE,
86
IOBase,
97
RemoteError,
10-
RemoteCommandFailedError,
118
RIARemoteError,
129
contextmanager,
1310
functools,
1411
on_osx,
1512
sh_quote,
1613
ssh_manager,
1714
stat,
18-
subprocess,
1915
)
2016

2117
from datalad_next.exceptions import CapturedException
2218
from datalad_next.patches import apply_patch
19+
from datalad_next.runners import CommandError
20+
from datalad_next.shell import shell
2321

2422

2523
class SSHRemoteIO(IOBase):
2624
"""IO operation if the object tree is SSH-accessible
2725
2826
It doesn't even think about a windows server.
2927
"""
30-
31-
# output markers to detect possible command failure as well as end of output
32-
# from a particular command:
33-
REMOTE_CMD_FAIL = "ora-remote: end - fail"
34-
REMOTE_CMD_OK = "ora-remote: end - ok"
35-
3628
def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
3729
"""
3830
Parameters
@@ -43,6 +35,7 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
4335
"""
4436
parsed_url = urlparse(ssh_url)
4537

38+
self.url = ssh_url
4639
# the connection to the remote
4740
# we don't open it yet, not yet clear if needed
4841
self.ssh = ssh_manager.get_connection(
@@ -52,29 +45,22 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
5245
self.ssh.open()
5346
# open a remote shell
5447
cmd = ['ssh'] + self.ssh._ssh_args + [self.ssh.sshri.as_str()]
55-
self.shell = subprocess.Popen(cmd,
56-
stderr=subprocess.DEVNULL,
57-
stdout=subprocess.PIPE,
58-
stdin=subprocess.PIPE)
59-
# swallow login message(s):
60-
self.shell.stdin.write(b"echo RIA-REMOTE-LOGIN-END\n")
61-
self.shell.stdin.flush()
62-
while True:
63-
line = self.shell.stdout.readline()
64-
if line == b"RIA-REMOTE-LOGIN-END\n":
65-
break
66-
# TODO: Same for stderr?
67-
68-
# make sure default is used when None was passed, too.
69-
self.buffer_size = buffer_size if buffer_size else DEFAULT_BUFFER_SIZE
48+
# we settle on `bash` as a shell. It should be around and then we
49+
# can count on it
50+
cmd.append('bash')
51+
self.servershell_context = shell(
52+
cmd,
53+
chunk_size=buffer_size,
54+
)
55+
self.servershell = self.servershell_context.__enter__()
7056

7157
# if the URL had a path, we try to 'cd' into it to make operations on
7258
# relative paths intelligible
7359
if parsed_url.path:
7460
# unquote path
7561
real_path = unquote(parsed_url.path)
7662
try:
77-
self._run(
63+
self.servershell(
7864
f'cd {sh_quote(real_path)}',
7965
check=True,
8066
)
@@ -84,22 +70,7 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
8470
CapturedException(e)
8571

8672
def close(self):
87-
# try exiting shell clean first
88-
self.shell.stdin.write(b"exit\n")
89-
self.shell.stdin.flush()
90-
exitcode = self.shell.wait(timeout=0.5)
91-
# be more brutal if it doesn't work
92-
if exitcode is None: # timed out
93-
# TODO: Theoretically terminate() can raise if not successful.
94-
# How to deal with that?
95-
self.shell.terminate()
96-
97-
def _append_end_markers(self, cmd):
98-
"""Append end markers to remote command"""
99-
100-
return cmd + " && printf '%s\\n' {} || printf '%s\\n' {}\n".format(
101-
sh_quote(self.REMOTE_CMD_OK),
102-
sh_quote(self.REMOTE_CMD_FAIL))
73+
self.servershell_context.__exit__(None, None, None)
10374

10475
def _get_download_size_from_key(self, key):
10576
"""Get the size of an annex object file from its key
@@ -110,7 +81,7 @@ def _get_download_size_from_key(self, key):
11081
Parameter
11182
---------
11283
key: str
113-
annex key of the file
84+
annex key of the file
11485
11586
Returns
11687
-------
@@ -154,38 +125,6 @@ def _get_download_size_from_key(self, key):
154125
else:
155126
raise RIARemoteError("invalid key: {}".format(key))
156127

157-
def _run(self, cmd, no_output=True, check=False):
158-
159-
# TODO: we might want to redirect stderr to stdout here (or have
160-
# additional end marker in stderr) otherwise we can't empty stderr
161-
# to be ready for next command. We also can't read stderr for
162-
# better error messages (RemoteError) without making sure there's
163-
# something to read in any case (it's blocking!).
164-
# However, if we are sure stderr can only ever happen if we would
165-
# raise RemoteError anyway, it might be okay.
166-
call = self._append_end_markers(cmd)
167-
self.shell.stdin.write(call.encode())
168-
self.shell.stdin.flush()
169-
170-
lines = []
171-
while True:
172-
line = self.shell.stdout.readline().decode()
173-
lines.append(line)
174-
if line == self.REMOTE_CMD_OK + '\n':
175-
# end reading
176-
break
177-
elif line == self.REMOTE_CMD_FAIL + '\n':
178-
if check:
179-
raise RemoteCommandFailedError(
180-
"{cmd} failed: {msg}".format(cmd=cmd,
181-
msg="".join(lines[:-1]))
182-
)
183-
else:
184-
break
185-
if no_output and len(lines) > 1:
186-
raise RIARemoteError("{}: {}".format(call, "".join(lines)))
187-
return "".join(lines[:-1])
188-
189128
@contextmanager
190129
def ensure_writeable(self, path):
191130
"""Context manager to get write permission on `path` and restore
@@ -215,12 +154,14 @@ def ensure_writeable(self, path):
215154
# needed.
216155
conversion = functools.partial(int, base=16)
217156

218-
output = self._run(f"stat {format_option} {path}",
219-
no_output=False, check=True)
157+
output = self.servershell(
158+
f"stat {format_option} {path}",
159+
check=True,
160+
).stdout.decode()
220161
mode = conversion(output)
221162
if not mode & stat.S_IWRITE:
222163
new_mode = oct(mode | stat.S_IWRITE)[-3:]
223-
self._run(f"chmod {new_mode} {path}")
164+
self.servershell(f"chmod {new_mode} {path}", check=True)
224165
changed = True
225166
else:
226167
changed = False
@@ -229,16 +170,23 @@ def ensure_writeable(self, path):
229170
finally:
230171
if changed:
231172
# restore original mode
232-
self._run("chmod {mode} {file}".format(mode=oct(mode)[-3:],
233-
file=path),
234-
check=False) # don't fail if path doesn't exist
235-
# anymore
173+
self.servershell(
174+
f"chmod {oct(mode)[-3:]} {path}",
175+
# don't fail if path doesn't exist anymore
176+
check=False,
177+
)
236178

237179
def mkdir(self, path):
238-
self._run('mkdir -p {}'.format(sh_quote(str(path))))
180+
self.servershell(
181+
f'mkdir -p {sh_quote(str(path))}',
182+
check=True,
183+
)
239184

240185
def symlink(self, target, link_name):
241-
self._run('ln -s {} {}'.format(sh_quote(str(target)), sh_quote(str(link_name))))
186+
self.servershell(
187+
f'ln -s {sh_quote(str(target))} {sh_quote(str(link_name))}',
188+
check=True,
189+
)
242190

243191
def put(self, src, dst, progress_cb):
244192
self.ssh.put(str(src), str(dst))
@@ -295,25 +243,38 @@ def get(self, src, dst, progress_cb):
295243

296244
def rename(self, src, dst):
297245
with self.ensure_writeable(dst.parent):
298-
self._run('mv {} {}'.format(sh_quote(str(src)), sh_quote(str(dst))))
246+
self.servershell(
247+
f'mv {sh_quote(str(src))} {sh_quote(str(dst))}',
248+
check=True,
249+
)
299250

300251
def remove(self, path):
301252
try:
302253
with self.ensure_writeable(path.parent):
303-
self._run('rm {}'.format(sh_quote(str(path))), check=True)
304-
except RemoteCommandFailedError as e:
305-
raise RIARemoteError(f"Unable to remove {path} "
306-
"or to obtain write permission in parent directory.") from e
254+
self.servershell(
255+
f'rm {sh_quote(str(path))}',
256+
check=True,
257+
)
258+
except CommandError as e:
259+
raise RIARemoteError(
260+
f"Unable to remove {path} "
261+
"or to obtain write permission in parent directory.") from e
307262

308263
def remove_dir(self, path):
309264
with self.ensure_writeable(path.parent):
310-
self._run('rmdir {}'.format(sh_quote(str(path))))
265+
self.servershell(
266+
f'rmdir {sh_quote(str(path))}',
267+
check=True,
268+
)
311269

312270
def exists(self, path):
313271
try:
314-
self._run('test -e {}'.format(sh_quote(str(path))), check=True)
272+
self.servershell(
273+
f'test -e {sh_quote(str(path))}',
274+
check=True,
275+
)
315276
return True
316-
except RemoteCommandFailedError:
277+
except CommandError:
317278
return False
318279

319280
def in_archive(self, archive_path, file_path):
@@ -324,14 +285,15 @@ def in_archive(self, archive_path, file_path):
324285
loc = str(file_path)
325286
# query 7z for the specific object location, keeps the output
326287
# lean, even for big archives
327-
cmd = '7z l {} {}'.format(
328-
sh_quote(str(archive_path)),
329-
sh_quote(loc))
288+
cmd = f'7z l {sh_quote(str(archive_path))} {sh_quote(loc)}'
330289

331290
# Note: Currently relies on file_path not showing up in case of failure
332291
# including non-existent archive. If need be could be more sophisticated
333292
# and called with check=True + catch CommandError
334-
out = self._run(cmd, no_output=False, check=False)
293+
out = self.servershell(
294+
cmd,
295+
check=False,
296+
).stdout.decode()
335297

336298
return loc in out
337299

@@ -373,10 +335,13 @@ def get_from_archive(self, archive, src, dst, progress_cb):
373335

374336
def read_file(self, file_path):
375337

376-
cmd = "cat {}".format(sh_quote(str(file_path)))
338+
cmd = f"cat {sh_quote(str(file_path))}"
377339
try:
378-
out = self._run(cmd, no_output=False, check=True)
379-
except RemoteCommandFailedError as e:
340+
out = self.servershell(
341+
cmd,
342+
check=True,
343+
).stdout.decode()
344+
except CommandError as e:
380345
# Currently we don't read stderr. All we know is, we couldn't read.
381346
# Try narrowing it down by calling a subsequent exists()
382347
if not self.exists(file_path):
@@ -397,13 +362,16 @@ def write_file(self, file_path, content, mode='w'):
397362
if not content.endswith('\n'):
398363
content += '\n'
399364

400-
cmd = "printf '%s' {} {} {}".format(
401-
sh_quote(content),
402-
mode,
403-
sh_quote(str(file_path)))
365+
# it really should read from stdin, but MIH cannot make it happen
366+
stdin = content.encode()
367+
cmd = f"head -c {len(stdin)} | cat {mode} {sh_quote(str(file_path))}"
404368
try:
405-
self._run(cmd, check=True)
406-
except RemoteCommandFailedError as e:
369+
self.servershell(
370+
cmd,
371+
check=True,
372+
stdin=[stdin],
373+
)
374+
except CommandError as e:
407375
raise RIARemoteError(f"Could not write to {file_path}") from e
408376

409377
def get_7z(self):
@@ -416,17 +384,14 @@ def get_7z(self):
416384
# to just call 7z and see whether it returns zero.
417385

418386
try:
419-
self._run("7z", check=True, no_output=False)
387+
self.servershell(
388+
"7z",
389+
check=True,
390+
)
420391
return True
421-
except RemoteCommandFailedError:
392+
except CommandError:
422393
return False
423394

424-
# try:
425-
# out = self._run("which 7z", check=True, no_output=False)
426-
# return out
427-
# except RemoteCommandFailedError:
428-
# return None
429-
430395

431396
# replace the whole class
432397
apply_patch('datalad.distributed.ora_remote', None, 'SSHRemoteIO', SSHRemoteIO)

0 commit comments

Comments
 (0)