Skip to content

Commit fb7cb16

Browse files
authored
Merge pull request #721 from Tauffer-Consulting/singularity-image-check
Improve Singularity check for image
2 parents 572a684 + 5597ef1 commit fb7cb16

File tree

1 file changed

+48
-13
lines changed

1 file changed

+48
-13
lines changed

spikeinterface/sorters/runsorter.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
from .sorterlist import sorter_dict
1212
from .utils import SpikeSortingError, has_nvidia
1313

14+
try:
15+
HAS_DOCKER = True
16+
import docker
17+
except ModuleNotFoundError:
18+
HAS_DOCKER = False
19+
1420
REGISTRY = 'spikeinterface'
1521

1622
SORTER_DOCKER_MAP = dict(
@@ -24,6 +30,7 @@
2430
ironclust='ironclust-compiled',
2531
kilosort='kilosort-compiled',
2632
kilosort2='kilosort2-compiled',
33+
kilosort2_5='kilosort2_5-compiled',
2734
kilosort3='kilosort3-compiled',
2835
waveclus='waveclus-compiled',
2936
)
@@ -223,19 +230,15 @@ def __init__(self, mode, container_image, volumes, extra_kwargs):
223230
'container_requires_gpu', None)
224231

225232
if mode == 'docker':
226-
import docker
233+
if not HAS_DOCKER:
234+
raise ModuleNotFoundError("No module named 'docker'")
227235
client = docker.from_env()
228236
if container_requires_gpu is not None:
229237
extra_kwargs.pop('container_requires_gpu')
230238
extra_kwargs["device_requests"] = [
231239
docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])]
232240

233-
# check if the image is already present locally
234-
repo_tags = []
235-
for image in client.images.list():
236-
repo_tags.extend(image.attrs['RepoTags'])
237-
238-
if container_image not in repo_tags:
241+
if self._get_docker_image(container_image) is None:
239242
print(f"Docker: pulling image {container_image}")
240243
client.images.pull(container_image)
241244

@@ -245,13 +248,36 @@ def __init__(self, mode, container_image, volumes, extra_kwargs):
245248
elif mode == 'singularity':
246249
from spython.main import Client
247250
# load local image file if it exists, otherwise search dockerhub
251+
sif_file = Client._get_filename(container_image)
252+
singularity_image = None
248253
if Path(container_image).exists():
249-
self.singularity_image = container_image
254+
singularity_image = container_image
255+
elif Path(sif_file).exists():
256+
singularity_image = sif_file
250257
else:
251-
print(f"Singularity: pulling image {container_image}")
252-
self.singularity_image = Client.pull(f'docker://{container_image}')
253-
254-
if not Path(self.singularity_image).exists():
258+
if HAS_DOCKER:
259+
docker_image = self._get_docker_image(container_image)
260+
if docker_image:
261+
print('Building singularity image from local docker image')
262+
# Save docker image as tar and build singularity image
263+
tmp_file = sif_file.replace('sif', 'tar').replace(':', '_')
264+
f = open(tmp_file, 'wb')
265+
try:
266+
for chunk in docker_image.save(chunk_size=100*1024*1024): # 100 MB
267+
f.write(chunk)
268+
singularity_image = Client.build(f'docker-archive://{tmp_file}', sif_file, sudo=False)
269+
except Exception as e:
270+
print(f'Failed to build singularity image from local: {e}')
271+
finally:
272+
# Clean up
273+
f.close()
274+
if os.path.exists(tmp_file):
275+
os.remove(tmp_file)
276+
if not singularity_image:
277+
print(f"Singularity: pulling image {container_image}")
278+
singularity_image = Client.pull(f'docker://{container_image}')
279+
280+
if not Path(singularity_image).exists():
255281
raise FileNotFoundError(f'Unable to locate container image {container_image}')
256282

257283
# bin options
@@ -263,7 +289,16 @@ def __init__(self, mode, container_image, volumes, extra_kwargs):
263289
# only nvidia at the moment
264290
options += ['--nv']
265291

266-
self.client_instance = Client.instance(self.singularity_image, start=False, options=options)
292+
self.client_instance = Client.instance(singularity_image, start=False, options=options)
293+
294+
@staticmethod
295+
def _get_docker_image(container_image):
296+
docker_client = docker.from_env(timeout=300)
297+
try:
298+
docker_image = docker_client.images.get(container_image)
299+
except docker.errors.ImageNotFound:
300+
docker_image = None
301+
return docker_image
267302

268303
def start(self):
269304
if self.mode == 'docker':

0 commit comments

Comments
 (0)