Skip to content

Commit

Permalink
Fixed a race condition when running jobs. (NVIDIA#624)
Browse files Browse the repository at this point in the history
  • Loading branch information
yhwen authored May 29, 2022
1 parent a0d75ea commit a6dcce6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 51 deletions.
55 changes: 25 additions & 30 deletions nvflare/private/fed/app/client/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Provides a command line interface for a federated client trainer."""

import argparse
import logging
import os
import sys
import threading
Expand Down Expand Up @@ -53,6 +54,12 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument("--startup", "-w", type=str, help="startup folder", required=True)
parser.add_argument("--token", "-t", type=str, help="token", required=True)
parser.add_argument("--ssid", "-d", type=str, help="ssid", required=True)
parser.add_argument("--run_number", "-n", type=str, help="run_number", required=True)
parser.add_argument("--client_name", "-c", type=str, help="client name", required=True)
parser.add_argument("--listen_port", "-p", type=str, help="listen port", required=True)
parser.add_argument("--sp_target", "-g", type=str, help="Sp target", required=True)

parser.add_argument(
"--fed_client", "-s", type=str, help="an aggregation server specification json file", required=True
Expand Down Expand Up @@ -103,25 +110,11 @@ def main():
thread = threading.Thread(target=check_parent_alive, args=(parent_pid, stop_event))
thread.start()

token_file = os.path.join(args.workspace, EngineConstant.CLIENT_TOKEN_FILE)
with open(token_file, "r") as f:
token = f.readline().strip()
ssid = f.readline().strip()
run_number = f.readline().strip()
client_name = f.readline().strip()
listen_port = f.readline().strip()
sp_target = f.readline().strip()
print(
"token is: {} ssid is: {} run_number is: {} client_name: {} listen_port: {}".format(
token, ssid, run_number, client_name, listen_port
)
)

startup = args.startup
app_root = os.path.join(
args.workspace,
WorkspaceConstants.WORKSPACE_PREFIX + str(run_number),
WorkspaceConstants.APP_PREFIX + client_name,
WorkspaceConstants.WORKSPACE_PREFIX + str(args.run_number),
WorkspaceConstants.APP_PREFIX + args.client_name,
)

app_log_config = os.path.join(app_root, config_folder, "log.config")
Expand All @@ -138,18 +131,20 @@ def main():
)
conf.configure()

log_file = os.path.join(args.workspace, run_number, "log.txt")
log_file = os.path.join(args.workspace, args.run_number, "log.txt")
add_logfile_handler(log_file)
logger = logging.getLogger("worker_process")
logger.info("Worker_process started.")

deployer = conf.base_deployer
federated_client = deployer.create_fed_client(args, sp_target)
federated_client = deployer.create_fed_client(args, args.sp_target)
federated_client.status = ClientStatus.STARTING

federated_client.token = token
federated_client.ssid = ssid
federated_client.client_name = client_name
federated_client.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, client_name, private=False)
federated_client.fl_ctx.set_prop(EngineConstant.FL_TOKEN, token, private=False)
federated_client.token = args.token
federated_client.ssid = args.ssid
federated_client.client_name = args.client_name
federated_client.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, args.client_name, private=False)
federated_client.fl_ctx.set_prop(EngineConstant.FL_TOKEN, args.token, private=False)
federated_client.fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True)

client_config_file_name = os.path.join(app_root, args.client_config)
Expand All @@ -158,10 +153,10 @@ def main():
)
conf.configure()

workspace = Workspace(args.workspace, client_name, config_folder)
workspace = Workspace(args.workspace, args.client_name, config_folder)
run_manager = ClientRunManager(
client_name=client_name,
run_num=run_number,
client_name=args.client_name,
run_num=args.run_number,
workspace=workspace,
client=federated_client,
components=conf.runner_config.components,
Expand All @@ -171,20 +166,20 @@ def main():
federated_client.run_manager = run_manager

with run_manager.new_context() as fl_ctx:
fl_ctx.set_prop(FLContextKey.CLIENT_NAME, client_name, private=False)
fl_ctx.set_prop(EngineConstant.FL_TOKEN, token, private=False)
fl_ctx.set_prop(FLContextKey.CLIENT_NAME, args.client_name, private=False)
fl_ctx.set_prop(EngineConstant.FL_TOKEN, args.token, private=False)
fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True)
fl_ctx.set_prop(FLContextKey.ARGS, args, sticky=True)
fl_ctx.set_prop(FLContextKey.APP_ROOT, app_root, private=True, sticky=True)
fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True)
fl_ctx.set_prop(FLContextKey.SECURE_MODE, secure_train, private=True, sticky=True)

client_runner = ClientRunner(config=conf.runner_config, run_num=run_number, engine=run_manager)
client_runner = ClientRunner(config=conf.runner_config, run_num=args.run_number, engine=run_manager)
run_manager.add_handler(client_runner)
fl_ctx.set_prop(FLContextKey.RUNNER, client_runner, private=True)

# Start the command agent
command_agent = CommandAgent(federated_client, int(listen_port), client_runner)
command_agent = CommandAgent(federated_client, int(args.listen_port), client_runner)
command_agent.start(fl_ctx)

federated_client.status = ClientStatus.STARTED
Expand Down
2 changes: 1 addition & 1 deletion nvflare/private/fed/client/client_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def start_app(
self.logger.info("Starting client app. rank: {}".format(self.rank))

open_port = get_open_ports(1)[0]
self._write_token_file(run_number, open_port)

self.client_executor.start_train(
self.client,
Expand All @@ -158,6 +157,7 @@ def start_app(
token,
resource_consumer,
resource_manager,
list(self.client.servers.values())[0]["target"],
)

return "Start the client app..."
Expand Down
15 changes: 15 additions & 0 deletions nvflare/private/fed/client/client_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def start_train(
token,
resource_consumer,
resource_manager,
target: str,
):
"""start_train method to start the FL client training.
Expand All @@ -70,6 +71,7 @@ def start_train(
token: token from resource manager
resource_consumer: resource consumer
resource_manager: resource manager
target: SP target location
"""
pass
Expand Down Expand Up @@ -185,6 +187,7 @@ def start_train(
token,
resource_consumer: ResourceConsumerSpec,
resource_manager: ResourceManagerSpec,
target: str,
):
if allocated_resource:
resource_consumer.consume(allocated_resource)
Expand All @@ -201,6 +204,18 @@ def start_train(
+ args.workspace
+ " -w "
+ self.startup
+ " -t "
+ client.token
+ " -d "
+ client.ssid
+ " -n "
+ run_number
+ " -c "
+ client.client_name
+ " -p "
+ str(listen_port)
+ " -g "
+ target
+ " -s fed_client.json "
" --set" + command_options + " print_conf=True"
)
Expand Down
40 changes: 20 additions & 20 deletions nvflare/private/fed/server/job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,27 +243,27 @@ def run(self, fl_ctx: FLContext):
if self.scheduler:
(ready_job, sites) = self.scheduler.schedule_job(job_candidates=approved_jobs, fl_ctx=fl_ctx)
if ready_job:
client_sites = {k: v for k, v in sites.items() if k != "server"}
try:
self.log_info(fl_ctx, f"Got the job:{ready_job.job_id} from the scheduler to run")
fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, ready_job.job_id)
run_number = self._deploy_job(ready_job, sites, fl_ctx)
job_manager.set_status(ready_job.job_id, RunStatus.DISPATCHED, fl_ctx)
self._start_run(
run_number=run_number,
job=ready_job,
client_sites=client_sites,
fl_ctx=fl_ctx,
)
with self.lock:
with self.lock:
client_sites = {k: v for k, v in sites.items() if k != "server"}
try:
self.log_info(fl_ctx, f"Got the job:{ready_job.job_id} from the scheduler to run")
fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, ready_job.job_id)
run_number = self._deploy_job(ready_job, sites, fl_ctx)
job_manager.set_status(ready_job.job_id, RunStatus.DISPATCHED, fl_ctx)
self._start_run(
run_number=run_number,
job=ready_job,
client_sites=client_sites,
fl_ctx=fl_ctx,
)
self.running_jobs[run_number] = ready_job
job_manager.set_status(ready_job.job_id, RunStatus.RUNNING, fl_ctx)
except Exception as e:
run_number = fl_ctx.get_prop(FLContextKey.JOB_RUN_NUMBER)
if run_number:
self._delete_run(run_number, list(client_sites.keys()), fl_ctx)
job_manager.set_status(ready_job.job_id, RunStatus.FAILED_TO_RUN, fl_ctx)
self.log_error(fl_ctx, f"Failed to run the Job ({ready_job.job_id}): {e}")
job_manager.set_status(ready_job.job_id, RunStatus.RUNNING, fl_ctx)
except Exception as e:
run_number = fl_ctx.get_prop(FLContextKey.JOB_RUN_NUMBER)
if run_number:
self._delete_run(run_number, list(client_sites.keys()), fl_ctx)
job_manager.set_status(ready_job.job_id, RunStatus.FAILED_TO_RUN, fl_ctx)
self.log_error(fl_ctx, f"Failed to run the Job ({ready_job.job_id}): {e}")

time.sleep(1.0)
else:
Expand Down

0 comments on commit a6dcce6

Please sign in to comment.