1
- import logging
2
-
3
1
from urllib .parse import urlparse
4
2
from urllib .request import unquote
5
3
6
4
from datalad .distributed .ora_remote import (
7
5
DEFAULT_BUFFER_SIZE ,
8
6
IOBase ,
9
7
RemoteError ,
10
- RemoteCommandFailedError ,
11
8
RIARemoteError ,
12
9
contextmanager ,
13
10
functools ,
14
11
on_osx ,
15
12
sh_quote ,
16
13
ssh_manager ,
17
14
stat ,
18
- subprocess ,
19
15
)
20
16
21
17
from datalad_next .exceptions import CapturedException
22
18
from datalad_next .patches import apply_patch
19
+ from datalad_next .runners import CommandError
20
+ from datalad_next .shell import shell
23
21
24
22
25
23
class SSHRemoteIO (IOBase ):
26
24
"""IO operation if the object tree is SSH-accessible
27
25
28
26
It doesn't even think about a windows server.
29
27
"""
30
-
31
- # output markers to detect possible command failure as well as end of output
32
- # from a particular command:
33
- REMOTE_CMD_FAIL = "ora-remote: end - fail"
34
- REMOTE_CMD_OK = "ora-remote: end - ok"
35
-
36
28
def __init__ (self , ssh_url , buffer_size = DEFAULT_BUFFER_SIZE ):
37
29
"""
38
30
Parameters
@@ -43,6 +35,7 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
43
35
"""
44
36
parsed_url = urlparse (ssh_url )
45
37
38
+ self .url = ssh_url
46
39
# the connection to the remote
47
40
# we don't open it yet, not yet clear if needed
48
41
self .ssh = ssh_manager .get_connection (
@@ -52,29 +45,22 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
52
45
self .ssh .open ()
53
46
# open a remote shell
54
47
cmd = ['ssh' ] + self .ssh ._ssh_args + [self .ssh .sshri .as_str ()]
55
- self .shell = subprocess .Popen (cmd ,
56
- stderr = subprocess .DEVNULL ,
57
- stdout = subprocess .PIPE ,
58
- stdin = subprocess .PIPE )
59
- # swallow login message(s):
60
- self .shell .stdin .write (b"echo RIA-REMOTE-LOGIN-END\n " )
61
- self .shell .stdin .flush ()
62
- while True :
63
- line = self .shell .stdout .readline ()
64
- if line == b"RIA-REMOTE-LOGIN-END\n " :
65
- break
66
- # TODO: Same for stderr?
67
-
68
- # make sure default is used when None was passed, too.
69
- self .buffer_size = buffer_size if buffer_size else DEFAULT_BUFFER_SIZE
48
+ # we settle on `bash` as a shell. It should be around and then we
49
+ # can count on it
50
+ cmd .append ('bash' )
51
+ self .servershell_context = shell (
52
+ cmd ,
53
+ chunk_size = buffer_size ,
54
+ )
55
+ self .servershell = self .servershell_context .__enter__ ()
70
56
71
57
# if the URL had a path, we try to 'cd' into it to make operations on
72
58
# relative paths intelligible
73
59
if parsed_url .path :
74
60
# unquote path
75
61
real_path = unquote (parsed_url .path )
76
62
try :
77
- self ._run (
63
+ self .servershell (
78
64
f'cd { sh_quote (real_path )} ' ,
79
65
check = True ,
80
66
)
@@ -84,22 +70,7 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
84
70
CapturedException (e )
85
71
86
72
def close(self):
    """Shut down the remote shell.

    Leaves the shell context that was entered in ``__init__``. The three
    ``None`` arguments signal that no exception is being propagated.
    """
    shell_context = self.servershell_context
    shell_context.__exit__(None, None, None)
103
74
104
75
def _get_download_size_from_key (self , key ):
105
76
"""Get the size of an annex object file from its key
@@ -110,7 +81,7 @@ def _get_download_size_from_key(self, key):
110
81
Parameter
111
82
---------
112
83
key: str
113
- annex key of the file
84
+ annex key of the file
114
85
115
86
Returns
116
87
-------
@@ -154,38 +125,6 @@ def _get_download_size_from_key(self, key):
154
125
else :
155
126
raise RIARemoteError ("invalid key: {}" .format (key ))
156
127
157
- def _run (self , cmd , no_output = True , check = False ):
158
-
159
- # TODO: we might want to redirect stderr to stdout here (or have
160
- # additional end marker in stderr) otherwise we can't empty stderr
161
- # to be ready for next command. We also can't read stderr for
162
- # better error messages (RemoteError) without making sure there's
163
- # something to read in any case (it's blocking!).
164
- # However, if we are sure stderr can only ever happen if we would
165
- # raise RemoteError anyway, it might be okay.
166
- call = self ._append_end_markers (cmd )
167
- self .shell .stdin .write (call .encode ())
168
- self .shell .stdin .flush ()
169
-
170
- lines = []
171
- while True :
172
- line = self .shell .stdout .readline ().decode ()
173
- lines .append (line )
174
- if line == self .REMOTE_CMD_OK + '\n ' :
175
- # end reading
176
- break
177
- elif line == self .REMOTE_CMD_FAIL + '\n ' :
178
- if check :
179
- raise RemoteCommandFailedError (
180
- "{cmd} failed: {msg}" .format (cmd = cmd ,
181
- msg = "" .join (lines [:- 1 ]))
182
- )
183
- else :
184
- break
185
- if no_output and len (lines ) > 1 :
186
- raise RIARemoteError ("{}: {}" .format (call , "" .join (lines )))
187
- return "" .join (lines [:- 1 ])
188
-
189
128
@contextmanager
190
129
def ensure_writeable (self , path ):
191
130
"""Context manager to get write permission on `path` and restore
@@ -215,12 +154,14 @@ def ensure_writeable(self, path):
215
154
# needed.
216
155
conversion = functools .partial (int , base = 16 )
217
156
218
- output = self ._run (f"stat { format_option } { path } " ,
219
- no_output = False , check = True )
157
+ output = self .servershell (
158
+ f"stat { format_option } { path } " ,
159
+ check = True ,
160
+ ).stdout .decode ()
220
161
mode = conversion (output )
221
162
if not mode & stat .S_IWRITE :
222
163
new_mode = oct (mode | stat .S_IWRITE )[- 3 :]
223
- self ._run (f"chmod { new_mode } { path } " )
164
+ self .servershell (f"chmod { new_mode } { path } " , check = True )
224
165
changed = True
225
166
else :
226
167
changed = False
@@ -229,16 +170,23 @@ def ensure_writeable(self, path):
229
170
finally :
230
171
if changed :
231
172
# restore original mode
232
- self ._run ("chmod {mode} {file}" .format (mode = oct (mode )[- 3 :],
233
- file = path ),
234
- check = False ) # don't fail if path doesn't exist
235
- # anymore
173
+ self .servershell (
174
+ f"chmod { oct (mode )[- 3 :]} { path } " ,
175
+ # don't fail if path doesn't exist anymore
176
+ check = False ,
177
+ )
236
178
237
179
def mkdir(self, path):
    """Create the directory `path` via ``mkdir -p`` in the remote shell.

    Parameters
    ----------
    path
      Directory to create; it is converted with ``str()`` and shell-quoted.
    """
    quoted_dir = sh_quote(str(path))
    # check=True: presumably a failing command raises CommandError, as
    # other methods of this class catch that exception for such calls
    self.servershell(f'mkdir -p {quoted_dir}', check=True)
239
184
240
185
def symlink(self, target, link_name):
    """Create a symbolic link `link_name` pointing to `target` (``ln -s``).

    Both arguments are converted with ``str()`` and shell-quoted before
    being placed into the command line.
    """
    quoted_target = sh_quote(str(target))
    quoted_link = sh_quote(str(link_name))
    self.servershell(f'ln -s {quoted_target} {quoted_link}', check=True)
242
190
243
191
def put (self , src , dst , progress_cb ):
244
192
self .ssh .put (str (src ), str (dst ))
@@ -295,25 +243,38 @@ def get(self, src, dst, progress_cb):
295
243
296
244
def rename(self, src, dst):
    """Move `src` to `dst` with ``mv``.

    The command runs inside ``ensure_writeable`` for the parent directory
    of `dst`.
    """
    with self.ensure_writeable(dst.parent):
        quoted_src = sh_quote(str(src))
        quoted_dst = sh_quote(str(dst))
        self.servershell(f'mv {quoted_src} {quoted_dst}', check=True)
299
250
300
251
def remove(self, path):
    """Delete the file at `path` with ``rm``.

    The command runs inside ``ensure_writeable`` for the parent directory
    of `path`.

    Raises
    ------
    RIARemoteError
      If a ``CommandError`` is raised while obtaining write permission on
      the parent directory or while running ``rm``.
    """
    try:
        with self.ensure_writeable(path.parent):
            quoted_path = sh_quote(str(path))
            self.servershell(f'rm {quoted_path}', check=True)
    except CommandError as cmd_error:
        msg = (
            f"Unable to remove {path} "
            "or to obtain write permission in parent directory."
        )
        raise RIARemoteError(msg) from cmd_error
307
262
308
263
def remove_dir(self, path):
    """Remove the directory `path` with ``rmdir``.

    The command runs inside ``ensure_writeable`` for the parent directory
    of `path`.
    """
    with self.ensure_writeable(path.parent):
        quoted_dir = sh_quote(str(path))
        self.servershell(f'rmdir {quoted_dir}', check=True)
311
269
312
270
def exists(self, path):
    """Report whether `path` exists, judged by ``test -e``.

    Returns
    -------
    bool
      False if running the test raises ``CommandError``, True otherwise.
    """
    probe = f'test -e {sh_quote(str(path))}'
    try:
        self.servershell(probe, check=True)
    except CommandError:
        return False
    return True
318
279
319
280
def in_archive (self , archive_path , file_path ):
@@ -324,14 +285,15 @@ def in_archive(self, archive_path, file_path):
324
285
loc = str (file_path )
325
286
# query 7z for the specific object location, keeps the output
326
287
# lean, even for big archives
327
- cmd = '7z l {} {}' .format (
328
- sh_quote (str (archive_path )),
329
- sh_quote (loc ))
288
+ cmd = f'7z l { sh_quote (str (archive_path ))} { sh_quote (loc )} '
330
289
331
290
# Note: Currently relies on file_path not showing up in case of failure
332
291
# including non-existent archive. If need be could be more sophisticated
333
292
# and called with check=True + catch CommandError
334
- out = self ._run (cmd , no_output = False , check = False )
293
+ out = self .servershell (
294
+ cmd ,
295
+ check = False ,
296
+ ).stdout .decode ()
335
297
336
298
return loc in out
337
299
@@ -373,10 +335,13 @@ def get_from_archive(self, archive, src, dst, progress_cb):
373
335
374
336
def read_file (self , file_path ):
375
337
376
- cmd = "cat {}" . format ( sh_quote (str (file_path )))
338
+ cmd = f "cat { sh_quote (str (file_path ))} "
377
339
try :
378
- out = self ._run (cmd , no_output = False , check = True )
379
- except RemoteCommandFailedError as e :
340
+ out = self .servershell (
341
+ cmd ,
342
+ check = True ,
343
+ ).stdout .decode ()
344
+ except CommandError as e :
380
345
# Currently we don't read stderr. All we know is, we couldn't read.
381
346
# Try narrowing it down by calling a subsequent exists()
382
347
if not self .exists (file_path ):
@@ -397,13 +362,16 @@ def write_file(self, file_path, content, mode='w'):
397
362
if not content .endswith ('\n ' ):
398
363
content += '\n '
399
364
400
- cmd = "printf '%s' {} {} {}" .format (
401
- sh_quote (content ),
402
- mode ,
403
- sh_quote (str (file_path )))
365
+ # it really should read from stdin, but MIH cannot make it happen
366
+ stdin = content .encode ()
367
+ cmd = f"head -c { len (stdin )} | cat { mode } { sh_quote (str (file_path ))} "
404
368
try :
405
- self ._run (cmd , check = True )
406
- except RemoteCommandFailedError as e :
369
+ self .servershell (
370
+ cmd ,
371
+ check = True ,
372
+ stdin = [stdin ],
373
+ )
374
+ except CommandError as e :
407
375
raise RIARemoteError (f"Could not write to { file_path } " ) from e
408
376
409
377
def get_7z (self ):
@@ -416,17 +384,14 @@ def get_7z(self):
416
384
# to just call 7z and see whether it returns zero.
417
385
418
386
try :
419
- self ._run ("7z" , check = True , no_output = False )
387
+ self .servershell (
388
+ "7z" ,
389
+ check = True ,
390
+ )
420
391
return True
421
- except RemoteCommandFailedError :
392
+ except CommandError :
422
393
return False
423
394
424
- # try:
425
- # out = self._run("which 7z", check=True, no_output=False)
426
- # return out
427
- # except RemoteCommandFailedError:
428
- # return None
429
-
430
395
431
396
# replace the whole class
432
397
apply_patch ('datalad.distributed.ora_remote' , None , 'SSHRemoteIO' , SSHRemoteIO )
0 commit comments