
Commit fbdadc6

Use the Jupyter kernel provisioner to control the Jupyter kernel flow with Slurm
1 parent: fb3723a

File tree

7 files changed: +415 −404 lines

bin/slurmkernel

Lines changed: 191 additions & 195 deletions
Large diffs are not rendered by default.

setup.py

Lines changed: 6 additions & 1 deletion
@@ -22,5 +22,10 @@
     include_package_data=True,
     packages=['slurm_jupyter_kernel'],
     scripts=['bin/slurmkernel'],
-    install_requires=['pexpect>=4.8.0', 'ipython>=8.5.0', 'jupyter_client>=7.3.5']
+    install_requires=['pexpect>=4.8.0', 'ipython>=8.5.0', 'jupyter_client>=7.3.5'],
+    entry_points={
+        'jupyter_client.kernel_provisioners': [
+            'remote-slurm-provisioner = slurm_jupyter_kernel.provisioner:RemoteSlurmProvisioner',
+        ]
+    }
 );
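For context, a kernel spec opts into the provisioner registered above through its kernel.json metadata. The following is a minimal sketch and not part of this commit: the entry point name and the traitlet names (loginnode, username, proxyjump, sbatch_flags) come from the diff, while the spec name and all values are invented for illustration.

# Illustrative only: write a kernel.json that selects the provisioner
# registered under the entry point name 'remote-slurm-provisioner'.
import json

kernel_spec = {
    "argv": ["ipython", "kernel", "-f", "{connection_file}"],
    "display_name": "Slurm IPython (example)",
    "language": "python",
    "metadata": {
        "kernel_provisioner": {
            "provisioner_name": "remote-slurm-provisioner",
            "config": {
                # traitlets declared on RemoteSlurmProvisioner; placeholder values
                "loginnode": "login.example.org",
                "username": "jdoe",
                "sbatch_flags": {"partition": "debug", "time": "01:00:00"},
            },
        }
    },
}

with open("kernel.json", "w") as handle:
    json.dump(kernel_spec, handle, indent=2)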

slurm_jupyter_kernel/kernel_scripts/ijulia.sh

Lines changed: 3 additions & 3 deletions
@@ -14,6 +14,6 @@ echo -e '#!/bin/bash\n$1\n"$@"' > $2/ijulia_wrapper.sh
 chmod +x $2/ijulia_wrapper.sh
 
 # kernel settings
-KERNEL_LANGUAGE=julia
-KERNEL_CMD=$2/ijulia_wrapper.sh julia -i --color=yes --project=$2 $3/packages/IJulia/AQu2H/src/kernel.jl {connection_file}
-KERNEL_ENVIRONMENT=$4
+LANGUAGE=julia
+ARGV=$2/ijulia_wrapper.sh julia -i --color=yes --project=$2 $3/packages/IJulia/AQu2H/src/kernel.jl {connection_file}
+ENVIRONMENT=$4

slurm_jupyter_kernel/kernel_scripts/ipython.sh

Lines changed: 2 additions & 2 deletions
@@ -13,5 +13,5 @@ echo -e '#!/bin/bash\n$1\nsource $2/ipython_venv/bin/activate\n"$@"' > $2/ipytho
 chmod +x $2/ipython_venv/ipy_wrapper.sh
 
 # kernel settings
-KERNEL_LANGUAGE=python
-KERNEL_CMD=$2/ipython_venv/ipy_wrapper.sh ipython kernel -f {connection_file}
+LANGUAGE=python
+ARGV=$2/ipython_venv/ipy_wrapper.sh ipython kernel -f {connection_file}

slurm_jupyter_kernel/provisioner.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@

from jupyter_client.provisioning.local_provisioner import LocalProvisioner;
from jupyter_client.connect import KernelConnectionInfo;

from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from traitlets import Unicode;
from traitlets import Dict as tDict;
from time import sleep;
import re;
import json;
from subprocess import check_output, Popen, PIPE, DEVNULL, STDOUT;

# custom exceptions
class NoSlurmFlagsFound (Exception):
    pass;
class UnknownLoginnode (Exception):
    pass;
class UnknownUsername (Exception):
    pass;
class NoSlurmJobID (Exception):
    pass;
class SSHConnectionError (Exception):
    pass;

class RemoteSlurmProvisioner(LocalProvisioner):

    sbatch_flags: dict = tDict(config=True);
    proxyjump: str = Unicode(config=True);
    loginnode: str = Unicode(config=True);
    username: str = Unicode(config=True);

    default_batch_job = """#!/bin/bash
#SBATCH -J jupyter_slurm_kernel
{SBATCH_JOB_FLAGS}

tmpfile=$(mktemp)
cat << EOF > $tmpfile
{KERNEL_CONNECTION_INFO}
EOF
connection_file=$tmpfile

{EXTRA_ENVIRONMENT}

{COMMAND}
""";

    def __init__(self, **kwargs):

        self.job_id = None;
        self.job_state = None;
        self.exec_node = None;
        self.active_port_forwarding = False;

        super().__init__(**kwargs);

    async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:

        if not self.sbatch_flags:
            raise NoSlurmFlagsFound('Please provide sbatch flags to start the Slurm job with!');

        if not self.loginnode:
            raise UnknownLoginnode('Could not start Slurm job. Unknown loginnode!');

        if not self.username:
            loginnode = self.loginnode;
            if self.proxyjump:
                loginnode = self.loginnode + f' (via {self.proxyjump})';
            raise UnknownUsername(f'Could not login to {loginnode}! Unknown username!');

        # Build sbatch job flags
        slurm_job_flags = '';
        for parameter, value in self.sbatch_flags.items():
            slurm_job_flags += f'#SBATCH --{parameter}={value}\n';

        # build ssh command
        proxyjump = '';
        if self.proxyjump:
            proxyjump = f'-J {self.proxyjump}';
        self.ssh_command = f'ssh -tA {proxyjump} {self.loginnode}';

        # build sbatch command
        self.sbatch_command = ['/bin/bash', '--login', '-c', '"sbatch --parsable"'];

        # add extra environment variables into sbatch job
        extra_environment = '';
        try:
            if len(self.kernel_spec.env) >= 1:
                for key, val in self.kernel_spec.env.items():
                    extra_environment += f'export {key}={val}\n';
        except:
            pass;

        # finally build the Slurm sbatch job
        kernel_command = ' '.join(self.kernel_spec.argv);
        self.batch_job = self.default_batch_job.format(SBATCH_JOB_FLAGS=slurm_job_flags, EXTRA_ENVIRONMENT=extra_environment, COMMAND=kernel_command, KERNEL_CONNECTION_INFO='{KERNEL_CONNECTION_INFO}');

        return await super().pre_launch(**kwargs)

    async def launch_kernel (self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:

        # kernel connection info is now available - add it to the Slurm batch job
        kernel_connection_info = {};
        # the kernel connection info contains byte-strings which are not JSON valid
        for key, val in self.connection_info.items():
            if isinstance(val, bytes):
                kernel_connection_info[key] = val.decode('utf-8');
            else:
                kernel_connection_info[key] = val;

        kernel_connection_info = str(kernel_connection_info).replace("'", '"');

        self.batch_job = self.batch_job.format(KERNEL_CONNECTION_INFO=kernel_connection_info, connection_file='$connection_file');
        self.log.debug('Final sbatch jobfile: ' + str(self.batch_job));

        run_command = self.ssh_command.split(' ') + self.sbatch_command;
        self.log.debug('Would run SSH command: ' + str(run_command));

        self.process = Popen(run_command, stdout=PIPE, stderr=DEVNULL, stdin=PIPE);
        child_process_out, child_process_err = self.process.communicate(input=self.batch_job.encode());
        child_process_out = child_process_out.decode('utf-8').strip();

        self.log.debug('Submitted Slurm job! sbatch output: ' + str(child_process_out));

        # now parsing the output to fetch the Slurm job id
        self.job_id = re.search(r"(\d+)", child_process_out, re.IGNORECASE);
        if self.job_id:
            try:
                self.job_id = self.job_id.group(1);
                self.job_id = int(self.job_id);
                self.log.info("Slurm job successfully submitted. Slurm job id: " + str(self.job_id));
            except:
                raise NoSlurmJobID("Could not fetch the Slurm job id!");

        return self.connection_info;

    def _start_ssh_port_forwarding (self):

        if self.exec_node:
            if self.connection_info:

                port_forward = ' ';
                port_forward = port_forward.join(['-L {{{kport}}}:127.0.0.1:{{{kport}}}'.format(kport=kport) for kport in [ 'stdin_port', 'shell_port', 'iopub_port', 'hb_port', 'control_port' ]]);
                # replace needed ports
                port_forward = port_forward.format(**self.connection_info);

                proxy_jump = '';
                if self.proxyjump:
                    proxy_jump = f'-J {self.proxyjump},{self.loginnode}';
                else:
                    proxy_jump = f'-J {self.loginnode}';

                ssh_command = ['ssh', '-fNA', '-o', 'StrictHostKeyChecking=no'] + proxy_jump.split(' ') + port_forward.split(' ');
                ssh_command.append(self.exec_node);

                self.log.info('Starting SSH tunnel to forward kernel ports to localhost');
                self.log.debug('Using command: ' + str(ssh_command));

                ssh_tunnel_process = Popen(ssh_command, stdout=PIPE, stderr=STDOUT);
                # TODO: work with exceptions
                self.active_port_forwarding = True;

    def _get_slurm_job_state (self, job_id: int):

        check_command = self.ssh_command.split(' ') + ['-T', '/bin/bash', '--login', '-c', f'"squeue -h -j {self.job_id} -o \'%T %B\' 2> /dev/null"'];

        squeue_output = check_output(check_command);
        squeue_output = squeue_output.decode('utf-8').split(' ');
        self.state = squeue_output[0].strip();

        if 'RUNNING' in self.state:
            self.exec_node = squeue_output[1];
            self.exec_node = self.exec_node.strip();

        return [self.state, self.exec_node];

    async def poll(self) -> Optional[int]:

        # 0 = polling
        result = 0;
        if self.job_id:
            state, exec_node = self._get_slurm_job_state(self.job_id);
            if state == 'RUNNING':

                if isinstance(exec_node, str):
                    if self.active_port_forwarding == False:
                        self.log.info(f'Slurm job is in state running on compute node {exec_node}');
                        self._start_ssh_port_forwarding();
                result = None;

        return result;

    def get_shutdown_wait_time(self, recommended: float = 5) -> float:

        #recommended = 30.0;
        return super().get_shutdown_wait_time(recommended);

    def get_stable_start_time(self, recommended: float = 10) -> float:

        recommended = 30.0;
        return super().get_stable_start_time(recommended)

    async def send_signal(self, signum: int) -> None:

        if signum == 0:
            return await self.poll();
        else:
            return await super().send_signal(signum);
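Taken together, the provisioner drives the remote kernel lifecycle: pre_launch() renders the sbatch script, launch_kernel() submits it over SSH and records the Slurm job id, and poll() watches squeue until the job is RUNNING and then opens the SSH port forwards. Below is a minimal sketch, not part of this commit, of how jupyter_client would exercise it; 'slurm_ipython' is a hypothetical kernel spec name whose kernel.json selects remote-slurm-provisioner as in the earlier example.

# Illustrative only; 'slurm_ipython' is a placeholder kernel spec name.
import asyncio
from jupyter_client.manager import AsyncKernelManager

async def main():
    km = AsyncKernelManager(kernel_name='slurm_ipython')
    # start_kernel() invokes the provisioner's pre_launch()/launch_kernel()
    await km.start_kernel()
    # poll() returns None once the Slurm job reports RUNNING
    print(await km.provisioner.poll())
    await km.shutdown_kernel()

asyncio.run(main())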

slurm_jupyter_kernel/script_template.py

Lines changed: 4 additions & 4 deletions
@@ -79,7 +79,7 @@ def choose_template (self):
     def use (self, loginnode=None, user=None, proxyjump=None, dry_run=False):
 
         ssh_options = {};
-        kernel_specs = ['KERNEL_LANGUAGE', 'KERNEL_DISPLAYNAME', 'KERNEL_CMD', 'KERNEL_ENVIRONMENT'];
+        kernel_specs = ['LANGUAGE', 'DISPLAYNAME', 'ARGV', 'ENVIRONMENT'];
 
         # first of all: check running ssh-agent
         try:
@@ -214,18 +214,18 @@ def use (self, loginnode=None, user=None, proxyjump=None, dry_run=False):
         # return kernel information
         set_kernel_specs = { key.lower(): val for key, val in set_kernel_specs.items() };
         set_kernel_specs['loginnode'] = loginnode;
-        set_kernel_specs['user'] = user;
+        set_kernel_specs['username'] = user;
         set_kernel_specs['proxyjump'] = proxy;
 
         # get kernel name
-        if not 'kernel_displayname' in set_kernel_specs.keys():
+        if not 'displayname' in set_kernel_specs.keys():
             while True:
                 kernel_displayname = input(f'Display Name of the new Jupyter kernel (will be shown in e.g. JupyterLab): ');
                 if kernel_displayname == '':
                     print(f'{Color.F_LightRed}Please enter a valid Jupyter kernel name!{Color.F_Default}');
                     continue;
 
-                set_kernel_specs['kernel_displayname'] = kernel_displayname;
+                set_kernel_specs['displayname'] = kernel_displayname;
                 break;
 
         print('Please specify the Slurm job parameter to start the job with (comma-separated, e.g. "account=hpc,time=00:00:00"):');
