-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_worker.py
74 lines (62 loc) · 2.45 KB
/
run_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import protobuf.roll_out_service_pb2_grpc
import gin
import grpc
import time
import misc.utility
from concurrent import futures
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
_MAX_MSG_LEN = 40 * 1024 * 1024
def main(config):
    """Start the worker and serve RollOutService RPCs until interrupted.

    Parses the gin config, builds the ES worker (optionally connected to a
    master via a ParameterSyncService stub), and exposes it on a gRPC server.

    Args:
        config: Parsed command-line namespace. Reads ``config``,
            ``worker_id``, ``log_dir``, ``master_address``,
            ``restore_model``, ``run_on_gke`` and ``port``.

    Raises:
        RuntimeError: If the gRPC server could not bind its port.
    """
    gin.parse_config_file(config.config)
    logger = misc.utility.create_logger(
        name='es_worker{}'.format(config.worker_id), log_dir=config.log_dir)

    channel = None
    if config.master_address is not None:
        logger.info('master_address: {}'.format(config.master_address))
        # Raise the client-side receive cap so messages up to _MAX_MSG_LEN
        # coming back from the master are accepted.
        channel = grpc.insecure_channel(
            config.master_address,
            [("grpc.max_receive_message_length", _MAX_MSG_LEN)])
        stub = protobuf.roll_out_service_pb2_grpc.ParameterSyncServiceStub(
            channel)
        worker = misc.utility.get_es_worker(
            logger=logger, load_model_path=config.restore_model, master=stub)
    else:
        worker = misc.utility.get_es_worker(
            logger=logger, load_model_path=config.restore_model)

    # Off GKE, offset the base port by worker ID so several workers can run
    # on one host (presumably each GKE worker gets its own address).
    port = config.port if config.run_on_gke else config.port + config.worker_id

    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=1),
        options=[("grpc.max_send_message_length", _MAX_MSG_LEN),
                 ("grpc.max_receive_message_length", _MAX_MSG_LEN)])
    # Start the RPC server.
    protobuf.roll_out_service_pb2_grpc.add_RollOutServiceServicer_to_server(
        worker, server)
    bound_port = server.add_insecure_port('[::]:{}'.format(port))
    if bound_port == 0:
        # Some grpcio versions report a failed bind by returning 0 instead of
        # raising; without this check the worker would serve nothing.
        raise RuntimeError('Failed to bind to port {}'.format(port))
    server.start()
    logger.info('Listening to port {} ...'.format(port))
    try:
        # server.start() does not block; keep the main thread alive.
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        logger.info('Worker quit.')
    finally:
        # Release the listening socket and the channel to the master.
        server.stop(0)
        if channel is not None:
            channel.close()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--port', help='Port to start the service.', type=int, default=20000)
    # Required: main() passes this straight to gin.parse_config_file, which
    # cannot handle a missing path.
    parser.add_argument(
        '--config', help='Path to the config file.', required=True)
    parser.add_argument(
        '--restore-model', help='Path to existing model file.', default=None)
    parser.add_argument(
        '--log-dir', help='Path to the log directory.', default='./log')
    parser.add_argument(
        '--worker-id', help='Worker ID.', type=int, default=0)
    parser.add_argument(
        '--master-address', help='Master address.')
    parser.add_argument(
        '--run-on-gke', help='Whether run this on GKE.', action='store_true')
    # Unrecognized arguments are tolerated and discarded (presumably so
    # launchers can pass extra flags); confirm this is intentional.
    args, _ = parser.parse_known_args()
    main(args)