1
+ from jupyter_client .provisioning .local_provisioner import LocalProvisioner ;
2
+ from jupyter_client .connect import KernelConnectionInfo ;
3
+
4
+ from typing import Any
5
+ from typing import Dict
6
+ from typing import List
7
+ from typing import Optional
8
+ from traitlets import Unicode ;
9
+ from traitlets import Dict as tDict ;
10
+ from time import sleep ;
11
+ import re ;
12
+ import json ;
13
+ from subprocess import check_output , Popen , PIPE , DEVNULL , STDOUT ;
14
+
15
+ # custom exceptions
16
class NoSlurmFlagsFound(Exception):
    """Signals that no sbatch flags were configured for the Slurm job."""
18
class UnknownLoginnode(Exception):
    """Signals that no login node was configured."""
20
class UnknownUsername(Exception):
    """Signals that no username was configured."""
22
class NoSlurmJobID(Exception):
    """Signals that the Slurm job id could not be determined."""
24
class SSHConnectionError(Exception):
    """Error type for SSH connection problems.

    NOTE(review): nothing in the visible part of this file raises it.
    """
26
+
27
class RemoteSlurmProvisioner(LocalProvisioner):
    """Kernel provisioner that starts the kernel inside a remote Slurm job.

    A batch script is generated from ``default_batch_job`` and piped to
    ``sbatch`` over SSH on the configured login node. ``poll`` queries
    ``squeue`` over SSH and, once the job is reported as RUNNING, starts an
    SSH tunnel that forwards the five kernel ports from the compute node to
    localhost.
    """

    # Mapping of sbatch long-option name -> value; each entry is rendered as
    # a ``#SBATCH --<name>=<value>`` line in the batch script.
    sbatch_flags: dict = tDict(config=True)
    # Optional jump host, passed to ssh via ``-J``.
    proxyjump: str = Unicode(config=True)
    # Host on which ``sbatch`` and ``squeue`` are executed.
    loginnode: str = Unicode(config=True)
    # Must be non-empty for ``pre_launch`` to succeed.
    # NOTE(review): the value is not inserted into any ssh command built in
    # this class -- confirm whether that is intended.
    username: str = Unicode(config=True)

    # Template of the batch script. ``{KERNEL_CONNECTION_INFO}`` is filled in
    # by ``launch_kernel``; the other placeholders by ``pre_launch``.
    default_batch_job = """#!/bin/bash
#SBATCH -J jupyter_slurm_kernel
{SBATCH_JOB_FLAGS}

tmpfile=$(mktemp)
cat << EOF > $tmpfile
{KERNEL_CONNECTION_INFO}
EOF
connection_file=$tmpfile

{EXTRA_ENVIRONMENT}

{COMMAND}
"""

    def __init__(self, **kwargs):
        """Initialise the Slurm bookkeeping attributes, then the base class."""
        self.job_id = None
        self.job_state = None
        self.exec_node = None
        self.active_port_forwarding = False
        # Handle of the ssh process started for port forwarding (if any).
        self.ssh_tunnel_process = None

        super().__init__(**kwargs)

    def _ssh_base_command(self) -> List[str]:
        """Return ``self.ssh_command`` as an argv list.

        The command string contains a double space when no proxyjump is
        configured. Splitting on arbitrary whitespace (instead of on a single
        space) avoids handing an empty-string argument to ssh in that case.
        """
        return self.ssh_command.split()

    @staticmethod
    def _parse_job_id(sbatch_output: str) -> Optional[int]:
        """Extract the Slurm job id from the output of ``sbatch --parsable``.

        A line consisting only of ``<digits>`` or ``<digits>;<text>`` is
        preferred (searched from the last line upwards), so that digits in
        other output lines are not mistaken for the job id. If no such line
        exists, the first run of digits anywhere in the output is used.

        Returns:
            The job id, or ``None`` when the output contains no digits.
        """
        for line in reversed(sbatch_output.splitlines()):
            parsable = re.fullmatch(r"(\d+)(?:;\S*)?", line.strip())
            if parsable:
                return int(parsable.group(1))

        fallback = re.search(r"(\d+)", sbatch_output)
        return int(fallback.group(1)) if fallback else None

    async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
        """Validate the configuration and prepare the ssh/sbatch commands.

        Sets ``self.ssh_command``, ``self.sbatch_command`` and
        ``self.batch_job`` (with the connection-info placeholder still
        unfilled) and then delegates to the base class.

        Raises:
            NoSlurmFlagsFound: if ``sbatch_flags`` is empty.
            UnknownLoginnode: if ``loginnode`` is empty.
            UnknownUsername: if ``username`` is empty.
        """
        if not self.sbatch_flags:
            raise NoSlurmFlagsFound('Please provide sbatch flags to start the Slurm job with!')

        if not self.loginnode:
            raise UnknownLoginnode('Could not start Slurm job. Unknown loginnode!')

        if not self.username:
            loginnode = self.loginnode
            if self.proxyjump:
                loginnode = self.loginnode + f' (via {self.proxyjump})'
            raise UnknownUsername(f'Could not login to {loginnode}! Unknown username!')

        # Build sbatch job flags
        slurm_job_flags = ''.join(
            f'#SBATCH --{parameter}={value}\n'
            for parameter, value in self.sbatch_flags.items()
        )

        # build ssh command
        proxyjump = f'-J {self.proxyjump}' if self.proxyjump else ''
        self.ssh_command = f'ssh -tA {proxyjump} {self.loginnode}'

        # build sbatch command; the embedded double quotes are part of the
        # argument that is passed on to ssh
        self.sbatch_command = ['/bin/bash', '--login', '-c', '"sbatch --parsable"']

        # add extra environment variables into sbatch job
        # (values are written as-is, without shell quoting)
        kernel_env = getattr(self.kernel_spec, 'env', None) or {}
        extra_environment = ''.join(
            f'export {key}={val}\n' for key, val in kernel_env.items()
        )

        # finally build the Slurm sbatch job; the connection info is not
        # inserted here, so its placeholder is re-emitted for launch_kernel
        kernel_command = ' '.join(self.kernel_spec.argv)
        self.batch_job = self.default_batch_job.format(
            SBATCH_JOB_FLAGS=slurm_job_flags,
            EXTRA_ENVIRONMENT=extra_environment,
            COMMAND=kernel_command,
            KERNEL_CONNECTION_INFO='{KERNEL_CONNECTION_INFO}',
        )

        return await super().pre_launch(**kwargs)

    async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
        """Submit the batch job over SSH and record its Slurm job id.

        ``cmd`` is not used; the command that is run is the kernel spec's
        argv embedded in the batch script by ``pre_launch``.

        Returns:
            ``self.connection_info``.

        Raises:
            NoSlurmJobID: if no job id can be found in the sbatch output.
        """
        # the kernel connection info contains byte-strings which are not JSON valid
        kernel_connection_info = {
            key: val.decode('utf-8') if isinstance(val, bytes) else val
            for key, val in self.connection_info.items()
        }

        # json.dumps instead of str(dict) with swapped quotes: the latter
        # produces invalid JSON for True/False/None or values holding quotes.
        connection_json = json.dumps(kernel_connection_info)

        # Fills the connection info and turns a '{connection_file}'
        # placeholder in the kernel command into the shell variable set by
        # the batch script.
        self.batch_job = self.batch_job.format(
            KERNEL_CONNECTION_INFO=connection_json,
            connection_file='$connection_file',
        )
        self.log.debug('Final sbatch jobfile: %s', self.batch_job)

        run_command = self._ssh_base_command() + self.sbatch_command
        self.log.debug('Would run SSH command: %s', run_command)

        # The batch script is passed to sbatch on stdin.
        self.process = Popen(run_command, stdout=PIPE, stderr=DEVNULL, stdin=PIPE)
        child_process_out, _ = self.process.communicate(input=self.batch_job.encode())
        child_process_out = child_process_out.decode('utf-8', errors='replace').strip()

        self.log.debug('Submitted Slurm job! sbatch output: %s', child_process_out)

        # now parsing the output to fetch the Slurm job id
        self.job_id = self._parse_job_id(child_process_out)
        if self.job_id is None:
            # Previously this case was silently ignored, so a failed
            # submission was never reported.
            self.log.error(
                'No Slurm job id in sbatch output (ssh exit code %s)',
                self.process.returncode,
            )
            raise NoSlurmJobID("Could not fetch the Slurm job id!")

        self.log.info("Slurm job successfully submitted. Slurm job id: %s", self.job_id)

        return self.connection_info

    def _start_ssh_port_forwarding(self):
        """Start a background ssh process forwarding the kernel ports.

        Does nothing unless both ``self.exec_node`` and
        ``self.connection_info`` are set. Each of the five kernel ports is
        forwarded from 127.0.0.1 on the compute node to the same port
        locally, jumping via the login node (and the proxyjump host, if set).
        """
        if not (self.exec_node and self.connection_info):
            return

        port_forward = []
        for port_name in ('stdin_port', 'shell_port', 'iopub_port', 'hb_port', 'control_port'):
            port = self.connection_info[port_name]
            port_forward += ['-L', f'{port}:127.0.0.1:{port}']

        if self.proxyjump:
            jump_hosts = f'{self.proxyjump},{self.loginnode}'
        else:
            jump_hosts = self.loginnode

        # NOTE(review): StrictHostKeyChecking=no disables host key
        # verification for the compute node -- confirm this is acceptable.
        ssh_command = ['ssh', '-fNA', '-o', 'StrictHostKeyChecking=no', '-J', jump_hosts]
        ssh_command += port_forward
        ssh_command.append(self.exec_node)

        self.log.info('Starting SSH tunnel to forward kernel ports to localhost')
        self.log.debug('Using command: %s', ssh_command)

        # Output is discarded rather than piped because nothing reads it;
        # the handle is kept on the instance instead of being dropped.
        self.ssh_tunnel_process = Popen(ssh_command, stdout=DEVNULL, stderr=DEVNULL)
        # TODO: work with exceptions (the flag is set without checking that
        # the tunnel actually came up)
        self.active_port_forwarding = True

    def _get_slurm_job_state(self, job_id: int):
        """Query ``squeue`` over SSH for the state of ``job_id``.

        Updates ``self.job_state`` (and ``self.state``, kept for backward
        compatibility) and, when the state contains ``RUNNING``,
        ``self.exec_node``.

        Returns:
            ``[state, exec_node]``; ``state`` is an empty string when squeue
            printed nothing, and ``exec_node`` keeps its previous value when
            the job is not running.

        Raises:
            subprocess.CalledProcessError: if the ssh command exits non-zero.
        """
        check_command = self._ssh_base_command() + [
            '-T', '/bin/bash', '--login', '-c',
            f'"squeue -h -j {job_id} -o \'%T %B\' 2> /dev/null"',
        ]

        # Split on any whitespace so a trailing newline or padding does not
        # leak into the fields.
        fields = check_output(check_command).decode('utf-8').split()
        state = fields[0] if fields else ''
        self.job_state = state
        self.state = state

        if 'RUNNING' in state and len(fields) > 1:
            self.exec_node = fields[1]

        return [self.state, self.exec_node]

    async def poll(self) -> Optional[int]:
        """Report whether the Slurm job is running.

        Returns:
            ``None`` when the job is RUNNING on a known compute node (the
            port forwarding is started on the first such poll), otherwise
            ``0``.
        """
        if not self.job_id:
            return 0

        state, exec_node = self._get_slurm_job_state(self.job_id)
        if state != 'RUNNING' or not isinstance(exec_node, str):
            return 0

        if not self.active_port_forwarding:
            self.log.info(f'Slurm job is in state running on compute node {exec_node}')
            self._start_ssh_port_forwarding()

        return None

    def get_shutdown_wait_time(self, recommended: float = 5) -> float:
        """Delegate to the base class with the given ``recommended`` value."""
        return super().get_shutdown_wait_time(recommended)

    def get_stable_start_time(self, recommended: float = 10) -> float:
        """Delegate to the base class, always recommending 30 seconds.

        The ``recommended`` argument is ignored and replaced by 30.0.
        """
        recommended = 30.0
        return super().get_stable_start_time(recommended)

    async def send_signal(self, signum: int) -> None:
        """Handle signal 0 via ``poll``; pass any other signal to the base class."""
        if signum == 0:
            return await self.poll()
        return await super().send_signal(signum)
0 commit comments