Skip to content

Commit

Permalink
first worker
Browse files Browse the repository at this point in the history
  • Loading branch information
esythan committed Aug 10, 2022
1 parent 1797fde commit ea60e40
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 104 deletions.
217 changes: 126 additions & 91 deletions tools/static_ps_online_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,25 +243,27 @@ def run_online_worker(self):
logger.info(
"get_last_save_model last_day = {}, last_pass = {}, last_path = {}, xbox_base_key = {}".
format(last_day, last_pass, last_path, xbox_base_key))
if last_day != -1 and fleet.is_first_worker():
load_model(last_path, 0, self.hadoop_client)
fleet.barrier_worker()
if last_day != -1:
logger.info("going to load model {}".format(last_path))
begin = time.time()
fleet.load_model(last_path, 0)
end = time.time()
logger.info("load model cost {} min".format((end - begin) / 60.0))

day = self.start_day
infer_first = True
while int(day) <= int(self.end_day):
logger.info("training a new day {}, end_day = {}".format(
day, self.end_day))
if last_day != -1 and int(day) < last_day:
day = get_next_day(day)
continue
# base_model_saved = False

for pass_id in range(1, 1 + len(self.online_intervals)):
print(last_day, day, last_pass, pass_id)
if (last_day != -1 and int(day) == last_day) and (
last_pass != -1 and int(pass_id) <= last_pass):
continue
if self.save_first_base and fleet.is_first_worker():
if self.save_first_base:
self.save_first_base = False
last_base_day, last_base_path, tmp_xbox_base_key = \
get_last_save_xbox_base(self.save_model_path, self.hadoop_client)
Expand All @@ -270,6 +272,7 @@ def run_online_worker(self):
format(last_base_day, last_base_path,
tmp_xbox_base_key))
if int(day) > last_base_day:
logger.info("going to save first base model")
xbox_base_key = int(time.time())
save_xbox_model(self.save_model_path, day, -1,
self.exe, self.inference_feed_vars,
Expand All @@ -283,7 +286,9 @@ def run_online_worker(self):
client=self.hadoop_client)
elif int(day) == last_base_day:
xbox_base_key = tmp_xbox_base_key
fleet.barrier_worker()
logger.info("first base model exists")
else:
logger.info("first base model exists")

logger.info("training a new day = {} new pass = {}".format(
day, pass_id))
Expand All @@ -299,26 +304,23 @@ def run_online_worker(self):

infer_cost = 0
infer_metric_cost = 0
if infer_first:
infer_first = False
else:
logger.info("Day:{}, Pass: {}, Infering Dataset Begin.".
format(day, pass_id))
begin = time.time()
self.dataset_infer_loop(dataset, day, pass_id)
end = time.time()
infer_cost = (end - begin) / 60.0
logger.info("Infering Dataset Done, using time {} mins.".
format(infer_cost))
begin = time.time()
metric_str = get_global_metrics_str(
paddle.static.global_scope(), self.metric_list, "")
logger.info("Day:{}, Pass: {}, Infer Global Metric: {}".
format(day, pass_id, metric_str))
clear_metrics(paddle.static.global_scope(),
self.metric_list, self.metric_types)
end = time.time()
infer_metric_cost = (end - begin) / 60.0
logger.info("Day:{}, Pass: {}, Infering Dataset Begin.".format(
day, pass_id))
begin = time.time()
self.dataset_infer_loop(dataset, day, pass_id)
end = time.time()
infer_cost = (end - begin) / 60.0
logger.info("Infering Dataset Done, using time {} mins.".
format(infer_cost))
begin = time.time()
metric_str = get_global_metrics_str(
paddle.static.global_scope(), self.metric_list, "")
logger.info("Day:{}, Pass: {}, Infer Global Metric: {}".format(
day, pass_id, metric_str))
clear_metrics(paddle.static.global_scope(), self.metric_list,
self.metric_types)
end = time.time()
infer_metric_cost = (end - begin) / 60.0

logger.info("Day:{}, Pass: {}, Training Dataset Begin.".format(
day, pass_id))
Expand All @@ -333,11 +335,6 @@ def run_online_worker(self):
logger.info("Training Dataset Done, using time {} mins.".
format(train_cost))

begin = time.time()
dataset.release_memory()
end = time.time()
release_cost = (end - begin) / 60.0

begin = time.time()
metric_str = get_global_metrics_str(
paddle.static.global_scope(), self.metric_list, "")
Expand All @@ -347,6 +344,12 @@ def run_online_worker(self):
self.metric_types)
end = time.time()
metric_cost = (end - begin) / 60

begin = time.time()
dataset.release_memory()
end = time.time()
release_cost = (end - begin) / 60.0

end_train = time.time()
total_cost = (end_train - begin_train) / 60
other_cost = total_cost - read_data_cost - train_cost - release_cost - metric_cost - infer_cost - infer_metric_cost
Expand Down Expand Up @@ -375,38 +378,53 @@ def run_online_worker(self):

dump_dataset.release_memory()

if fleet.is_first_worker():
if pass_id % self.checkpoint_per_pass == 0:
save_model(self.exe, self.save_model_path, day,
pass_id)
write_model_donefile(
if pass_id % self.checkpoint_per_pass == 0 and pass_id != len(
self.online_intervals):
begin = time.time()
save_model(self.exe, self.save_model_path, day, pass_id)
end = time.time()
save_cost = (end - begin) / 60.0
begin = time.time()
write_model_donefile(
output_path=self.save_model_path,
day=day,
pass_id=pass_id,
xbox_base_key=xbox_base_key,
client=self.hadoop_client)
end = time.time()
donefile_cost = (end - begin) / 60.0
log_str = "finished save checkpoint model epoch %d [save_model: %s min][donefile: %s min]" % (
pass_index, save_cost, donefile_cost)
logger.info(log_str)
if pass_id % self.save_delta_frequency == 0:
last_xbox_day, last_xbox_pass, last_xbox_path, _ = get_last_save_xbox(
self.save_model_path, self.hadoop_client)
if int(day) < last_xbox_day or int(
day) == last_xbox_day and int(
pass_id) <= last_xbox_pass:
logger.info("delta model exists")
else:
begin = time.time()
save_xbox_model(self.save_model_path, day, pass_id,
self.exe, self.inference_feed_vars,
self.inference_target_var,
self.hadoop_client) # 1 delta
end = time.time()
save_cost = (end - begin) / 60.0
begin = time.time()
write_xbox_donefile(
output_path=self.save_model_path,
day=day,
pass_id=pass_id,
xbox_base_key=xbox_base_key,
client=self.hadoop_client)
if pass_id % self.save_delta_frequency == 0:
last_xbox_day, last_xbox_pass, last_xbox_path, _ = get_last_save_xbox(
self.save_model_path, self.hadoop_client)
if int(day) < last_xbox_day or int(
day) == last_xbox_day and int(
pass_id) <= last_xbox_pass:
log_str = "delta model exists"
logger.info(log_str)
else:
save_xbox_model(self.save_model_path, day, pass_id,
self.exe, self.inference_feed_vars,
self.inference_target_var,
self.hadoop_client) # 1 delta
write_xbox_donefile(
output_path=self.save_model_path,
day=day,
pass_id=pass_id,
xbox_base_key=xbox_base_key,
client=self.hadoop_client,
hadoop_fs_name=self.hadoop_fs_name,
monitor_data=metric_str)
fleet.barrier_worker()
client=self.hadoop_client,
hadoop_fs_name=self.hadoop_fs_name,
monitor_data=metric_str)
end = time.time()
donefile_cost = (end - begin) / 60.0
log_str = "finished save delta model epoch %d [save_model: %s min][donefile: %s min]" % (
pass_index, save_cost, donefile_cost)
logger.info(log_str)

logger.info("shrink table")
begin = time.time()
Expand All @@ -415,37 +433,54 @@ def run_online_worker(self):
logger.info("shrink table done, cost %s min" % (
(end - begin) / 60.0))

if fleet.is_first_worker():
last_base_day, last_base_path, last_base_key = get_last_save_xbox_base(
self.save_model_path, self.hadoop_client)
logger.info(
"one epoch finishes, get_last_save_xbox, last_base_day = {}, last_base_path = {}, last_base_key = {}".
format(last_base_day, last_base_path, last_base_key))
next_day = get_next_day(day)
if int(next_day) <= last_base_day:
logger.info("batch model/base xbox model exists")
else:
xbox_base_key = int(time.time())
save_xbox_model(self.save_model_path, next_day, -1,
self.exe, self.inference_feed_vars,
self.inference_target_var,
self.hadoop_client)
write_xbox_donefile(
output_path=self.save_model_path,
day=next_day,
pass_id=-1,
xbox_base_key=xbox_base_key,
client=self.hadoop_client,
hadoop_fs_name=self.hadoop_fs_name,
monitor_data=metric_str)
save_batch_model(self.exe, self.save_model_path, next_day)
write_model_donefile(
output_path=self.save_model_path,
day=next_day,
pass_id=-1,
xbox_base_key=xbox_base_key,
client=self.hadoop_client)
fleet.barrier_worker()
last_base_day, last_base_path, last_base_key = get_last_save_xbox_base(
self.save_model_path, self.hadoop_client)
logger.info(
"one epoch finishes, get_last_save_xbox, last_base_day = {}, last_base_path = {}, last_base_key = {}".
format(last_base_day, last_base_path, last_base_key))
next_day = get_next_day(day)
if int(next_day) <= last_base_day:
xbox_base_key = last_base_key
logger.info("batch model/base xbox model exists")
else:
xbox_base_key = int(time.time())
begin = time.time()
save_xbox_model(self.save_model_path, next_day, -1, self.exe,
self.inference_feed_vars,
self.inference_target_var, self.hadoop_client)
end = time.time()
save_cost = (end - begin) / 60.0
begin = time.time()
write_xbox_donefile(
output_path=self.save_model_path,
day=next_day,
pass_id=-1,
xbox_base_key=xbox_base_key,
client=self.hadoop_client,
hadoop_fs_name=self.hadoop_fs_name,
monitor_data=metric_str)
end = time.time()
donefile_cost = (end - begin) / 60.0
log_str = "finished save base model day %d [save_model: %s min][donefile: %s min]" % (
next_day, save_cost, donefile_cost)
logger.info(log_str)

begin = time.time()
save_batch_model(self.exe, self.save_model_path, next_day)
end = time.time()
save_cost = (end - begin) / 60.0
begin = time.time()
write_model_donefile(
output_path=self.save_model_path,
day=next_day,
pass_id=-1,
xbox_base_key=xbox_base_key,
client=self.hadoop_client)
end = time.time()
donefile_cost = (end - begin) / 60.0
log_str = "finished save batch model day %d [save_model: %s min][donefile: %s min]" % (
next_day, save_cost, donefile_cost)
logger.info(log_str)
day = get_next_day(day)

def dataset_train_loop(self, cur_dataset, day, pass_index,
Expand Down
19 changes: 6 additions & 13 deletions tools/utils/static_ps/flow_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,6 @@ def get_online_pass_interval(split_interval, split_per_pass,
return online_pass_interval


def load_model(model_path, mode, client):
    """Load a model via fleet, downloading the remote dnn_plugin dir first when required.

    Args:
        model_path: Path of the saved model; checked with ``is_local``.
        mode: Load mode passed straight through to ``fleet.load_model``.
            When it equals 1 or 2 and ``model_path`` is not local, the
            ``dnn_plugin`` directory is downloaded beforehand.
        client: Object providing ``download(src, dst)``
            (presumably a hadoop/afs client; verify against caller).
    """
    plugin_dir = "./dnn_plugin"
    is_remote = not is_local(model_path)
    if is_remote and mode in (1, 2):
        # Recreate the local directory from scratch so no files from an
        # earlier download are left behind.
        if os.path.exists(plugin_dir):
            shutil.rmtree(plugin_dir)
        os.mkdir(plugin_dir)
        remote_plugin = model_path + "/dnn_plugin/"
        client.download(remote_plugin, plugin_dir)
    # The fleet load itself always runs, whatever the mode or path type.
    fleet.load_model(model_path, mode)


def save_model(exe, output_path, day, pass_id, mode=0):
"""
Args:
Expand Down Expand Up @@ -180,7 +170,7 @@ def write_model_donefile(output_path,
suffix_name = "/%s/0/" % day
model_path = output_path.rstrip("/") + suffix_name

if fleet.worker_index() == 0:
if fleet.is_first_worker():
donefile_path = output_path + "/" + donefile_name
content = "%s\t%lu\t%s\t%s\t%d" % (day, xbox_base_key, \
model_path, pass_id, 0)
Expand Down Expand Up @@ -245,6 +235,7 @@ def write_model_donefile(output_path,
else:
logger.info("not write %s because %s/%s already "
"exists" % (donefile_name, day, pass_id))
fleet.barrier_worker()


def get_last_save_model(output_path, client):
Expand Down Expand Up @@ -403,8 +394,9 @@ def save_xbox_model(output_path, day, pass_id, exe, feed_vars, target_vars,
model_path, [feed.name for feed in feed_vars],
target_vars,
mode=mode)
if not is_local(model_path):
if not is_local(model_path) and fleet.is_first_worker():
client.upload("./dnn_plugin", model_path)
fleet.barrier_worker()


def write_xbox_donefile(output_path,
Expand Down Expand Up @@ -444,7 +436,7 @@ def write_xbox_donefile(output_path,
if donefile_name is None:
donefile_name = "xbox_base_done.txt"

if fleet.worker_index() == 0:
if fleet.is_first_worker():
donefile_path = output_path + "/" + donefile_name
xbox_str = _get_xbox_str(
model_path=model_path,
Expand Down Expand Up @@ -512,6 +504,7 @@ def write_xbox_donefile(output_path,
with open(donefile_path, "w") as f:
f.write(pre_content + "\n")
f.write(xbox_str + "\n")
fleet.barrier_worker()


def _get_xbox_str(model_path,
Expand Down

0 comments on commit ea60e40

Please sign in to comment.