Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions python/paddle/v2/master/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import ctypes
import os

path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
lib = ctypes.cdll.LoadLibrary(path)
__lib__ = None


def get_c_lib():
global __lib__
if __lib__ is None:
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
__lib__ = ctypes.cdll.LoadLibrary(path)
return __lib__


class client(object):
Expand All @@ -11,8 +18,8 @@ class client(object):
"""

def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout_sec,
buf_size)
self.c = get_c_lib().paddle_new_etcd_master_client(
etcd_endpoints, timeout_sec, buf_size)

def request_save_model(self, trainer_id, block_ms):
"""request to save model
Expand All @@ -32,10 +39,11 @@ def request_save_model(self, trainer_id, block_ms):
saving the model, -1 if error happened.

"""
return lib.paddle_request_save_model(self.c, trainer_id, block_ms)
return get_c_lib().paddle_request_save_model(self.c, trainer_id,
block_ms)

def release(self):
lib.paddle_release_master_client(self.c)
get_c_lib().paddle_release_master_client(self.c)
self.c = None

def set_dataset(self, paths):
Expand All @@ -45,7 +53,7 @@ def set_dataset(self, paths):
for idx, path in enumerate(paths):
c_ptr = ctypes.c_char_p(path)
holder[idx] = c_ptr
lib.paddle_set_dataset(self.c, holder, len(paths))
get_c_lib().paddle_set_dataset(self.c, holder, len(paths))

def next_record(self):
"""gets next record for training
Expand All @@ -56,7 +64,7 @@ def next_record(self):
"""
p = ctypes.c_char_p()
ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret)
size = get_c_lib().paddle_next_record(self.c, ret)
if size < 0:
# Error
return None, size
Expand All @@ -67,5 +75,5 @@ def next_record(self):

record = ret.contents.value[:size]
# Memory created from C should be freed.
lib.mem_free(ret.contents)
get_c_lib().mem_free(ret.contents)
return record, 0