@@ -801,6 +801,49 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
801
801
return webui_url
802
802
803
803
804
def check_and_update_resources(resources):
    """Sanity check a resource dictionary and add sensible defaults.

    Args:
        resources: A dictionary mapping resource names to resource
            quantities, or None.

    Returns:
        A new resource dictionary with "CPU" and "GPU" entries filled in.

    Raises:
        Exception: If the number of requested GPUs exceeds the number
            allowed by CUDA_VISIBLE_DEVICES.
        ValueError: If a resource quantity is not an int or a float.
    """
    if resources is None:
        resources = {}
    # Copy so that the caller's dictionary is never mutated.
    resources = resources.copy()
    if "CPU" not in resources:
        # By default, use the number of hardware execution threads for the
        # number of cores.
        resources["CPU"] = psutil.cpu_count()

    # See if CUDA_VISIBLE_DEVICES has already been set.
    gpu_ids = ray.utils.get_cuda_visible_devices()

    # Check that the number of GPUs that the local scheduler wants doesn't
    # exceed the amount allowed by CUDA_VISIBLE_DEVICES.
    if ("GPU" in resources and gpu_ids is not None
            and resources["GPU"] > len(gpu_ids)):
        raise Exception("Attempting to start local scheduler with {} GPUs, "
                        "but CUDA_VISIBLE_DEVICES contains {}.".format(
                            resources["GPU"], gpu_ids))

    if "GPU" not in resources:
        # Try to automatically detect the number of GPUs.
        resources["GPU"] = _autodetect_num_gpus()
        # Don't use more GPUs than allowed by CUDA_VISIBLE_DEVICES.
        if gpu_ids is not None:
            resources["GPU"] = min(resources["GPU"], len(gpu_ids))

    # Check types with an explicit raise rather than assert, since asserts
    # are stripped when Python runs with -O, and a ValueError gives the
    # user an actionable message.
    for resource_name, resource_quantity in resources.items():
        if not isinstance(resource_quantity, (int, float)):
            raise ValueError("Resource quantities must be ints or floats, "
                             "but resource '{}' has quantity {!r}.".format(
                                 resource_name, resource_quantity))

    return resources
845
+
846
+
804
847
def start_local_scheduler (redis_address ,
805
848
node_ip_address ,
806
849
plasma_store_name ,
@@ -839,30 +882,7 @@ def start_local_scheduler(redis_address,
839
882
Return:
840
883
The name of the local scheduler socket.
841
884
"""
842
- if resources is None :
843
- resources = {}
844
- if "CPU" not in resources :
845
- # By default, use the number of hardware execution threads for the
846
- # number of cores.
847
- resources ["CPU" ] = psutil .cpu_count ()
848
-
849
- # See if CUDA_VISIBLE_DEVICES has already been set.
850
- gpu_ids = ray .utils .get_cuda_visible_devices ()
851
-
852
- # Check that the number of GPUs that the local scheduler wants doesn't
853
- # excede the amount allowed by CUDA_VISIBLE_DEVICES.
854
- if ("GPU" in resources and gpu_ids is not None
855
- and resources ["GPU" ] > len (gpu_ids )):
856
- raise Exception ("Attempting to start local scheduler with {} GPUs, "
857
- "but CUDA_VISIBLE_DEVICES contains {}." .format (
858
- resources ["GPU" ], gpu_ids ))
859
-
860
- if "GPU" not in resources :
861
- # Try to automatically detect the number of GPUs.
862
- resources ["GPU" ] = _autodetect_num_gpus ()
863
- # Don't use more GPUs than allowed by CUDA_VISIBLE_DEVICES.
864
- if gpu_ids is not None :
865
- resources ["GPU" ] = min (resources ["GPU" ], len (gpu_ids ))
885
+ resources = check_and_update_resources (resources )
866
886
867
887
print ("Starting local scheduler with the following resources: {}."
868
888
.format (resources ))
@@ -889,6 +909,7 @@ def start_raylet(redis_address,
889
909
node_ip_address ,
890
910
plasma_store_name ,
891
911
worker_path ,
912
+ resources = None ,
892
913
stdout_file = None ,
893
914
stderr_file = None ,
894
915
cleanup = True ):
@@ -913,6 +934,15 @@ def start_raylet(redis_address,
913
934
Returns:
914
935
The raylet socket name.
915
936
"""
937
+ static_resources = check_and_update_resources (resources )
938
+
939
+ # Format the resource argument in a form like 'CPU,1.0,GPU,0,Custom,3'.
940
+ resource_argument = "," .join ([
941
+ "{},{}" .format (resource_name , resource_value )
942
+ for resource_name , resource_value in zip (static_resources .keys (),
943
+ static_resources .values ())
944
+ ])
945
+
916
946
gcs_ip_address , gcs_port = redis_address .split (":" )
917
947
raylet_name = "/tmp/raylet{}" .format (random_name ())
918
948
@@ -927,7 +957,7 @@ def start_raylet(redis_address,
927
957
928
958
command = [
929
959
RAYLET_EXECUTABLE , raylet_name , plasma_store_name , node_ip_address ,
930
- gcs_ip_address , gcs_port , start_worker_command
960
+ gcs_ip_address , gcs_port , start_worker_command , resource_argument
931
961
]
932
962
pid = subprocess .Popen (command , stdout = stdout_file , stderr = stderr_file )
933
963
@@ -1437,6 +1467,7 @@ def start_ray_processes(address_info=None,
1437
1467
node_ip_address ,
1438
1468
object_store_addresses [i ].name ,
1439
1469
worker_path ,
1470
+ resources = resources [i ],
1440
1471
stdout_file = None ,
1441
1472
stderr_file = None ,
1442
1473
cleanup = cleanup )
0 commit comments