3030
3131import gc
3232import os
33- import uuid
3433
3534import ray
3635import torch
@@ -91,14 +90,26 @@ def __init__(self):
9190
9291 self .device_uuid = current_platform .get_device_uuid (0 )
9392 self .zmq_context = zmq .Context ()
93+ self .zmq_address_counter = 0
9494
9595 def report_device_id (self ) -> str :
9696 return self .device_uuid
9797
98- def get_zmq_handles (self ):
99- return {self .device_uuid : f"ipc:///tmp/rl-colocate-zmq-{ uuid .uuid4 ()} .sock" }
98+ @property
99+ def _zmq_handle (self ) -> str :
100+ suffix = f"{ self .device_uuid } -{ self .zmq_address_counter } "
101+ return f"ipc:///tmp/rl-colocate-zmq-{ suffix } .sock"
100102
101- def update_weights (self , zmq_handle : dict [str , str ]):
103+ def get_zmq_handles (self ) -> dict [str , str ]:
104+ return {self .device_uuid : self ._zmq_handle }
105+
106+ def update_weights (self ):
107+ try :
108+ self ._update_weights ()
109+ finally :
110+ self .zmq_address_counter += 1
111+
112+ def _update_weights (self ):
102113 # align size to avoid misaligned address
103114 align_size = 256
104115
@@ -112,7 +123,7 @@ def get_size(p: torch.Tensor) -> int:
112123 # use max_tensor_size * 2 as buffer size
113124 buffer = torch .empty (max_tensor_size * 2 , dtype = torch .uint8 , device = "cuda:0" )
114125 s = self .zmq_context .socket (zmq .REQ )
115- s .bind (zmq_handle [ self .device_uuid ] )
126+ s .bind (self ._zmq_handle )
116127 handle = reduce_tensor (buffer )
117128
118129 offset = 0
@@ -234,7 +245,7 @@ def get_size(p: torch.Tensor) -> int:
234245
235246print ("Update the weights of the inference engines." )
236247ray .get (
237- [actor .update_weights .remote (zmq_handles ) for actor in training_actors ]
248+ [actor .update_weights .remote () for actor in training_actors ]
238249 + [
239250 llm .collective_rpc .remote ("update_weights_from_ipc" , args = (zmq_handles ,))
240251 for llm in inference_engines
0 commit comments