@@ -340,6 +340,55 @@ async def _start(self):
             await self._close()
             raise RuntimeError(f"Cluster failed to start: {e}") from e

+    def _spec_name_to_worker_names(self, spec_name):
+        """Convert a spec name to the set of worker names it represents.
+
+        For regular workers, returns the one-element set ``{str(spec_name)}``.
+        For grouped workers, returns ``{str(spec_name) + suffix for suffix in spec["group"]}``.
+
+        Parameters
+        ----------
+        spec_name : int or str
+            The spec name (key in ``worker_spec``)
+
+        Returns
+        -------
+        set of str
+            The worker names that the scheduler knows about
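+
+        Examples
+        --------
+        Illustrative sketch only, assuming a hypothetical ``worker_spec``
+        of ``{0: {"group": ["-0", "-1"]}, 3: {}}``:
+
+        >>> cluster._spec_name_to_worker_names(0)  # doctest: +SKIP
+        {'0-0', '0-1'}
+        >>> cluster._spec_name_to_worker_names(3)  # doctest: +SKIP
+        {'3'}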
358+ """
359+ spec = self .worker_spec .get (spec_name )
360+ if spec and "group" in spec :
361+ return {str (spec_name ) + suffix for suffix in spec ["group" ]}
362+ return {str (spec_name )}
363+
+    def _worker_name_to_spec_name(self, worker_name):
+        """Convert a worker name to its spec name.
+
+        For regular workers, returns the worker name (converted to int if numeric).
+        For grouped workers, extracts the prefix before the first "-".
+
+        Parameters
+        ----------
+        worker_name : str or int
+            The worker name from the scheduler
+
+        Returns
+        -------
+        int or str
+            The spec name (key in ``worker_spec``)
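+
+        Examples
+        --------
+        Illustrative sketch only (hypothetical worker names):
+
+        >>> cluster._worker_name_to_spec_name("0-1")  # doctest: +SKIP
+        0
+        >>> cluster._worker_name_to_spec_name("3")  # doctest: +SKIP
+        3
+        >>> cluster._worker_name_to_spec_name("alice")  # doctest: +SKIP
+        'alice'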
+        """
+        worker_name_str = str(worker_name)
+        if "-" in worker_name_str:
+            spec_name = worker_name_str.split("-")[0]
+            # Convert to int if numeric to match worker_spec keys
+            if spec_name.isdigit():
+                return int(spec_name)
+            return spec_name
+        # Try to convert to int if numeric
+        if worker_name_str.isdigit():
+            return int(worker_name_str)
+        return worker_name
+
     def _correct_state(self):
         if self._correct_state_waiting:
             # If people call this frequently, we only want to run it once
@@ -356,7 +405,29 @@ async def _correct_state_internal(self) -> None:
         to_close = set(self.workers) - set(self.worker_spec)
         if to_close:
             if self.scheduler.status == Status.running:
-                await self.scheduler_comm.retire_workers(workers=list(to_close))
+                # For grouped workers, we need to retire the actual worker names
+                # that the scheduler knows about, not the spec names
+                actual_workers_to_retire = []
+                active_worker_names = {
+                    str(w["name"])
+                    for w in self.scheduler_info.get("workers", {}).values()
+                }
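+                # e.g. (illustrative): {"0-0", "0-1", "3"} when spec 0 is a
+                # two-member group and spec 3 is a regular worker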
+
+                for spec_name in to_close:
+                    # Get all worker names for this spec (handles both regular and grouped)
+                    expected_worker_names = self._spec_name_to_worker_names(
+                        spec_name
+                    )
+                    # Only retire workers that actually exist in the scheduler
+                    for worker_name in expected_worker_names:
+                        if worker_name in active_worker_names:
+                            actual_workers_to_retire.append(worker_name)
+
+                if actual_workers_to_retire:
+                    await self.scheduler_comm.retire_workers(
+                        workers=actual_workers_to_retire
+                    )
+
             tasks = [
                 asyncio.create_task(self.workers[w].close())
                 for w in to_close
@@ -397,6 +468,11 @@ async def _correct_state_internal(self) -> None:

     def _update_worker_status(self, op, msg):
         if op == "remove":
+            # The worker may already be gone from scheduler_info; if so,
+            # defer to the base implementation
+            if msg not in self.scheduler_info.get("workers", {}):
+                super()._update_worker_status(op, msg)
+                return
+
             removed_worker_name = self.scheduler_info["workers"][msg]["name"]

             # Closure to handle removal of a worker from the cluster
@@ -408,35 +484,26 @@ def f():
                 if removed_worker_name in active_workers:
                     return

-                # Build mapping from individual worker names to their worker spec names
-                # - For non-grouped workers: worker name == spec name (1:1)
-                # - For grouped workers: multiple workers map to one spec entry
-                worker_to_spec = {}
-                for worker_spec_name, spec in self.worker_spec.items():
-                    if "group" not in spec:
-                        worker_to_spec[worker_spec_name] = worker_spec_name
-                    else:
-                        grouped_workers = {
-                            str(worker_spec_name) + suffix: worker_spec_name
-                            for suffix in spec["group"]
-                        }
-                        worker_to_spec.update(grouped_workers)
-
-                # Find and remove the worker spec entry
-                # Note: For grouped workers, we remove the entire spec when ANY worker dies.
-                # This assumes that partial failure means the whole group is compromised
+                # Convert worker name to spec name using helper method
+                worker_spec_name = self._worker_name_to_spec_name(removed_worker_name)
+
+                # Check if this is a grouped worker
+                spec = self.worker_spec.get(worker_spec_name)
+                is_grouped = spec and "group" in spec
+
+                # Close and remove the worker object
+                if worker_spec_name in self.workers:
+                    self._futures.add(
+                        asyncio.ensure_future(self.workers[worker_spec_name].close())
+                    )
+                    del self.workers[worker_spec_name]
+
+                # Only remove the spec for grouped workers
+                # For grouped workers: when ANY worker dies, the whole group is compromised
                 # (e.g., in HPC systems, if one process in a multi-process job fails, the
                 # entire job allocation is typically lost).
-                worker_spec_name = worker_to_spec.get(removed_worker_name)
-                if worker_spec_name and worker_spec_name in self.worker_spec:
-                    # Close and remove the worker object
-                    if worker_spec_name in self.workers:
-                        self._futures.add(
-                            asyncio.ensure_future(
-                                self.workers[worker_spec_name].close()
-                            )
-                        )
-                        del self.workers[worker_spec_name]
+                # For regular workers: keep the spec so the cluster can recreate them
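+                # Illustrative example (hypothetical names): losing "0-1" from
+                # grouped spec 0 removes spec 0 entirely, while losing regular
+                # worker "3" leaves spec 3 in place to be recreated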
+                if is_grouped and worker_spec_name in self.worker_spec:
                     del self.worker_spec[worker_spec_name]

                 delay = parse_timedelta(
@@ -541,24 +608,57 @@ def _memory_per_worker(self) -> int:
         )

     def scale(self, n=0, memory=None, cores=None):
-        if memory is not None:
-            n = max(n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker())))
+        # For grouped workers, n represents the number of workers, but the
+        # memory/cores calculations represent the number of specs (since
+        # _memory_per_worker and _threads_per_worker return values for the
+        # entire MultiWorker spec)
+        if self.new_spec and "group" in self.new_spec:
+            workers_per_spec = len(self.new_spec["group"])
+
+            # Convert n from number of workers to number of specs
+            target_specs_from_n = int(math.ceil(n / workers_per_spec)) if n > 0 else 0
+
+            # memory/cores calculations already give us the number of specs
+            if memory is not None:
+                target_specs_from_memory = int(
+                    math.ceil(parse_bytes(memory) / self._memory_per_worker())
+                )
+                target_specs = max(target_specs_from_n, target_specs_from_memory)
+            else:
+                target_specs = target_specs_from_n
+
+            if cores is not None:
+                target_specs_from_cores = int(
+                    math.ceil(cores / self._threads_per_worker())
+                )
+                target_specs = max(target_specs, target_specs_from_cores)
+        else:
+            # For regular workers, everything is in terms of workers (which equals specs)
+            if memory is not None:
+                n = max(
+                    n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker()))
+                )
+
+            if cores is not None:
+                n = max(n, int(math.ceil(cores / self._threads_per_worker())))
+
+            target_specs = n
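+        # Worked example (illustrative): with a group of 4 workers per spec,
+        # scale(n=10) targets ceil(10 / 4) == 3 specs, i.e. 12 workers;
+        # grouped scaling always rounds up to whole groups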

-        if cores is not None:
-            n = max(n, int(math.ceil(cores / self._threads_per_worker())))
+        if len(self.worker_spec) > target_specs:
+            # Build the set of spec names that have launched at least one worker
+            launched_spec_names = set()
+            for worker_info in self.scheduler_info.get("workers", {}).values():
+                spec_name = self._worker_name_to_spec_name(worker_info["name"])
+                launched_spec_names.add(spec_name)

-        if len(self.worker_spec) > n:
-            not_yet_launched = set(self.worker_spec) - {
-                v["name"] for v in self.scheduler_info["workers"].values()
-            }
-            while len(self.worker_spec) > n and not_yet_launched:
+            not_yet_launched = set(self.worker_spec) - launched_spec_names
+            while len(self.worker_spec) > target_specs and not_yet_launched:
                 del self.worker_spec[not_yet_launched.pop()]

-        while len(self.worker_spec) > n:
+        while len(self.worker_spec) > target_specs:
             self.worker_spec.popitem()

         if self.status not in (Status.closing, Status.closed):
-            while len(self.worker_spec) < n:
+            while len(self.worker_spec) < target_specs:
                 self.worker_spec.update(self.new_worker_spec())

         self.loop.add_callback(self._correct_state)
@@ -597,17 +697,9 @@ def _supports_scaling(self):
         return bool(self.new_spec)

     async def scale_down(self, workers):
-        # We may have groups, if so, map worker addresses to job names
+        # Convert worker names to spec names (handles both regular and grouped workers)
         if not all(w in self.worker_spec for w in workers):
-            mapping = {}
-            for name, spec in self.worker_spec.items():
-                if "group" in spec:
-                    for suffix in spec["group"]:
-                        mapping[str(name) + suffix] = name
-                else:
-                    mapping[name] = name
-
-            workers = {mapping.get(w, w) for w in workers}
+            workers = {self._worker_name_to_spec_name(w) for w in workers}
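+            # e.g. (illustrative): {"0-0", "0-1", "3"} -> {0, 3}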

         for w in workers:
             if w in self.worker_spec: