Skip to content

Commit 35e57e5

Browse files
weixiao-huang and huangweixiao
authored and committed
feat: add zmq_address_counter
Signed-off-by: huangweixiao <huangweixiao@msh.team>
1 parent 8d55398 commit 35e57e5

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

examples/offline_inference/rlhf_colocate.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@
3030

3131
import gc
3232
import os
33-
import uuid
3433

3534
import ray
3635
import torch
@@ -91,14 +90,26 @@ def __init__(self):
9190

9291
self.device_uuid = current_platform.get_device_uuid(0)
9392
self.zmq_context = zmq.Context()
93+
self.zmq_address_counter = 0
9494

9595
def report_device_id(self) -> str:
9696
return self.device_uuid
9797

98-
def get_zmq_handles(self):
99-
return {self.device_uuid: f"ipc:///tmp/rl-colocate-zmq-{uuid.uuid4()}.sock"}
98+
@property
99+
def _zmq_handle(self) -> str:
100+
suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
101+
return f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
100102

101-
def update_weights(self, zmq_handle: dict[str, str]):
103+
def get_zmq_handles(self) -> dict[str, str]:
104+
return {self.device_uuid: self._zmq_handle}
105+
106+
def update_weights(self):
107+
try:
108+
self._update_weights()
109+
finally:
110+
self.zmq_address_counter += 1
111+
112+
def _update_weights(self):
102113
# align size to avoid misaligned address
103114
align_size = 256
104115

@@ -112,7 +123,7 @@ def get_size(p: torch.Tensor) -> int:
112123
# use max_tensor_size * 2 as buffer size
113124
buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
114125
s = self.zmq_context.socket(zmq.REQ)
115-
s.bind(zmq_handle[self.device_uuid])
126+
s.bind(self._zmq_handle)
116127
handle = reduce_tensor(buffer)
117128

118129
offset = 0
@@ -234,7 +245,7 @@ def get_size(p: torch.Tensor) -> int:
234245

235246
print("Update the weights of the inference engines.")
236247
ray.get(
237-
[actor.update_weights.remote(zmq_handles) for actor in training_actors]
248+
[actor.update_weights.remote() for actor in training_actors]
238249
+ [
239250
llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
240251
for llm in inference_engines

0 commit comments

Comments
 (0)