238 changes: 134 additions & 104 deletions src/parallax/launch.py
@@ -3,10 +3,9 @@

This script is used to launch the Parallax server.
It will start the following services:
1.Executor with tp_rank=0 in the main process.
2.Executor with tp_rank>0, each tp_rank as a subprocess.
3.HTTP server as a subprocess.
4.P2P server as a thread in the main process.
1.Executor: each tp_rank runs as a subprocess.
2.HTTP server as a subprocess.
3.P2P server as a subprocess.

Example command:
python src/parallax/launch.py \
@@ -21,16 +20,13 @@
import multiprocessing
import os
import tempfile
import threading

from parallax.p2p.server import ServerState, launch_p2p_server
from parallax.server.executor import (
Executor,
run_executor_process,
stop_executor_process,
)
import time

from parallax.p2p.server import ServerState, launch_p2p_server_process, stop_p2p_server
from parallax.server.executor import run_executor_process, stop_executor_process
from parallax.server.http_server import launch_http_server, stop_http_server
from parallax.server.server_args import parse_args
from parallax.utils.shared_state import SharedState
from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port
from parallax_utils.ascii_anime import display_parallax_join
from parallax_utils.logging_config import get_logger, set_log_level
@@ -39,13 +35,55 @@
logger = get_logger("parallax.launch")


def _update_args_from_shared_state(args, shared_state: SharedState):
"""Update args with layer allocation from shared state"""
model_info = shared_state.get_model_info()
args.start_layer = model_info["block_start_index"]
args.end_layer = model_info["block_end_index"]
# Update model_path if provided and not already set
if model_info["model_name"] and args.model_path is None:
args.model_path = model_info["model_name"]
# Update tp_size if provided, otherwise keep current value
args.tp_size = model_info["tp_size"] or args.tp_size


def _stop_executor_processes(executor_subprocs):
"""Stop all executor processes"""
for executor_process in executor_subprocs:
if executor_process.is_alive():
logger.debug(f"Terminating executor process {executor_process.pid}")
stop_executor_process(executor_process)


def _wait_executors_check_layer_change(shared_state: SharedState, executor_subprocs):
"""Wait for executor processes and check if layer allocation changed.

Returns:
True if layer allocation changed (need to reload executors),
False if all executors exited normally.
"""
while any(proc.is_alive() for proc in executor_subprocs):
for proc in executor_subprocs:
if proc.is_alive():
proc.join(timeout=1.0) # Check every second

if shared_state.get_layer_allocation_changed():
return True

# Check race condition: layer allocation changed after all processes exited
return shared_state.get_layer_allocation_changed()


if __name__ == "__main__":
multiprocessing.set_start_method("spawn", force=True)

gradient_server = None
p2p_server_process = None
http_server_process = None
executor = None
executor_subprocs = []
# Shared state for layer allocation info (used when P2P server is in subprocess)
shared_state = SharedState.create()
shared_state.set_status(ServerState.JOINING.value)

try:
args = parse_args()
set_log_level(args.log_level)
@@ -72,7 +110,8 @@
# only launch http server on head node
if args.start_layer == 0:
http_server_process = launch_http_server(args)
launch_p2p_server(
# Launch P2P server as subprocess
p2p_server_process = launch_p2p_server_process(
initial_peers=args.initial_peers,
scheduler_addr=args.scheduler_addr,
relay_servers=args.relay_servers,
@@ -93,26 +132,34 @@
max_sequence_length=args.max_sequence_length,
param_mem_ratio=args.param_mem_ratio,
kvcache_mem_ratio=args.kvcache_mem_ratio,
shared_state=shared_state.dict, # Pass dict to subprocess
log_level=args.log_level,
)
if gradient_server is not None:
gradient_server.status = ServerState.READY

# For each tp_rank > 0, create a subprocess and run executor
for tp_rank in range(1, args.tp_size):
# Launch all executor processes (including tp_rank=0)
for tp_rank in range(args.tp_size):
args_copy = argparse.Namespace(**vars(args))
args_copy.tp_rank = tp_rank
proc = multiprocessing.Process(
target=run_executor_process,
args=(args_copy,),
args=(
args_copy,
shared_state.dict, # Pass dict to subprocess
),
)
proc.start()
executor_subprocs.append(proc)
# Launch executor with tp_rank=0 in the main process
args.tp_rank = 0
executor = Executor.create_from_args(args)
executor.run_loop()

time.sleep(2) # Give executors time to start
shared_state.set_status(ServerState.READY.value)

# Wait for all executor processes
for proc in executor_subprocs:
proc.join()
else:
gradient_server = launch_p2p_server(
# Launch P2P server as subprocess (with scheduler)
# Pass dict to subprocess (multiprocessing requires serializable objects)
p2p_server_process = launch_p2p_server_process(
initial_peers=args.initial_peers,
scheduler_addr=args.scheduler_addr,
relay_servers=args.relay_servers,
@@ -133,18 +180,34 @@
max_sequence_length=args.max_sequence_length,
param_mem_ratio=args.param_mem_ratio,
kvcache_mem_ratio=args.kvcache_mem_ratio,
shared_state=shared_state.dict, # Pass dict to subprocess
log_level=args.log_level,
)
args.start_layer = gradient_server.block_start_index
args.end_layer = gradient_server.block_end_index
# Only read model_name from scheduler if model_path is not set, so we can use local path as model_path
if args.model_path is None:
args.model_path = gradient_server.model_name
args.tp_size = gradient_server.tp_size

# Wait for layer allocation from scheduler (via shared state)
logger.debug("Waiting for layer allocation from scheduler...")
max_wait_time = 300 # 5 minutes
wait_start = time.time()
while True:
model_info = shared_state.get_model_info()
if (
model_info["block_start_index"] is not None
and model_info["block_end_index"] is not None
and model_info["model_name"] is not None
):
break
if time.time() - wait_start > max_wait_time:
logger.error("Timeout waiting for layer allocation from scheduler")
raise RuntimeError("Failed to get layer allocation from scheduler")
time.sleep(1)

# Get layer allocation from shared state
_update_args_from_shared_state(args, shared_state)

logger.debug(
f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}, "
f"model: {args.model_path}"
)
gradient_server.status = ServerState.INITIALIZING

if args.log_level != "DEBUG":
display_parallax_join(args.model_path)
@@ -157,100 +220,67 @@
# Main execution loop with layer reallocation support
while True:
try:
# For each tp_rank > 0, create a subprocess and run executor
for tp_rank in range(1, args.tp_size):
# Launch all executor processes (including tp_rank=0)
executor_subprocs = []
for tp_rank in range(args.tp_size):
args_copy = argparse.Namespace(**vars(args))
args_copy.tp_rank = tp_rank
proc = multiprocessing.Process(
target=run_executor_process,
args=(args_copy,),
args=(
args_copy,
shared_state.dict, # Pass dict to subprocess
),
)
proc.start()
executor_subprocs.append(proc)
# Launch executor with tp_rank=0 in the main process
args.tp_rank = 0
executor = Executor.create_from_args(args, gradient_server=gradient_server)
if gradient_server is not None:
gradient_server.status = ServerState.READY

executor.run_loop()

# Check if layer allocation changed (executor exited due to reallocation)
if gradient_server is not None and gradient_server._layer_allocation_changed:
logger.warning(
"Layer allocation changed! Reloading executor with new layers..."
)

# shutdown all executor processes
thread_pool = []
for executor_process in executor_subprocs:
t = threading.Thread(
target=stop_executor_process, args=(executor_process,)
)
t.start()
thread_pool.append(t)
executor.shutdown()
for t in thread_pool:
t.join()

if args.start_layer == 0:
http_server_process = stop_http_server(http_server_process)
if gradient_server.block_start_index == 0:
http_server_process = launch_http_server(args)

# Update args with new layer allocation
args.start_layer = gradient_server.block_start_index
args.end_layer = gradient_server.block_end_index
if gradient_server.model_name:
args.model_path = gradient_server.model_name

# Wait for executors and restart if layer allocation changes
if _wait_executors_check_layer_change(shared_state, executor_subprocs):
logger.warning("Layer allocation changed! Stopping executors to reload...")
# Reset flag and set status to INITIALIZING
shared_state.update(
_layer_allocation_changed=False,
status=ServerState.INITIALIZING.value,
)
_stop_executor_processes(executor_subprocs)
_update_args_from_shared_state(args, shared_state)
logger.info(
f"Creating new executor with layers [{args.start_layer}, {args.end_layer})"
f"Reloading executor with layers [{args.start_layer}, {args.end_layer})"
)
continue

gradient_server._layer_allocation_changed = False
continue # Create new executor in next iteration
else:
break # Normal exit
# All processes exited normally
break
except KeyboardInterrupt:
logger.debug("Received interrupt signal, shutting down...")
break
except Exception as e:
logger.exception(f"Executor error: {e}")
# If layer allocation changed, try to reload
if gradient_server is not None and gradient_server._layer_allocation_changed:
logger.info("Attempting to reload executor after error...")
if executor is not None:
executor.shutdown()
continue
else:
raise
# Shutdown all executor processes on error
for proc in executor_subprocs:
if proc.is_alive():
stop_executor_process(proc)
raise
except KeyboardInterrupt:
logger.debug("Received interrupt signal, shutting down...")
except Exception as e:
logger.exception(e)
finally:
thread_pool = []

# Shutdown http server
if http_server_process is not None:
t = threading.Thread(target=stop_http_server, args=(http_server_process,))
t.start()
thread_pool.append(t)

# Shutdown gradient server
if gradient_server is not None:
gradient_server.shutdown()
# Shutdown all processes
logger.debug("Shutting down all processes...")

# Shutdown executor subprocesses
for executor_process in executor_subprocs:
t = threading.Thread(target=stop_executor_process, args=(executor_process,))
t.start()
thread_pool.append(t)
if executor_process.is_alive():
stop_executor_process(executor_process)

# Shutdown executor main process
if executor is not None:
executor.shutdown()
# Shutdown P2P server subprocess
if p2p_server_process is not None:
stop_p2p_server(p2p_server_process)

# Shutdown http server
if http_server_process is not None:
stop_http_server(http_server_process)

for t in thread_pool:
t.join()
logger.debug("All processes shut down.")