Skip to content

Commit

Permalink
Merge pull request ray-project#7 from amplab/tests
Browse files Browse the repository at this point in the history
Tests
  • Loading branch information
pcmoritz committed Mar 10, 2016
2 parents b63d201 + d2aa71d commit 2535d26
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 55 deletions.
37 changes: 17 additions & 20 deletions lib/orchpy/orchpy/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,39 @@

def cleanup():
global all_processes
for p, port in all_processes:
for p, address in all_processes:
if p.poll() is not None: # process has already terminated
print "Process at port " + str(port) + " has already terminated."
print "Process at address " + address + " has already terminated."
continue
print "Attempting to kill process at port " + str(port) + "."
print "Attempting to kill process at address " + address + "."
p.kill()
time.sleep(0.05) # is this necessary?
if p.poll() is not None:
print "Successfully killed process at port " + str(port) + "."
print "Successfully killed process at address " + address + "."
continue
print "Kill attempt failed, attempting to terminate process at port " + str(port) + "."
print "Kill attempt failed, attempting to terminate process at address " + address + "."
p.terminate()
time.sleep(0.05) # is this necessary?
if p.poll is not None:
print "Successfully terminated process at port " + str(port) + "."
print "Successfully terminated process at address " + address + "."
continue
print "Termination attempt failed, giving up."
all_processes = []

atexit.register(cleanup)

def start_scheduler(host, port):
scheduler_address = host + ":" + str(port)
p = subprocess.Popen([os.path.join(_services_path, "scheduler"), str(scheduler_address)])
all_processes.append((p, port))
def start_scheduler(scheduler_address):
p = subprocess.Popen([os.path.join(_services_path, "scheduler"), scheduler_address])
all_processes.append((p, scheduler_address))

def start_objstore(host, port):
objstore_address = host + ":" + str(port)
p = subprocess.Popen([os.path.join(_services_path, "objstore"), str(objstore_address)])
all_processes.append((p, port))
def start_objstore(objstore_address):
p = subprocess.Popen([os.path.join(_services_path, "objstore"), objstore_address])
all_processes.append((p, objstore_address))

def start_worker(test_path, host, scheduler_port, worker_port, objstore_port):
def start_worker(test_path, scheduler_address, objstore_address, worker_address):
p = subprocess.Popen(["python",
test_path,
"--ip_address=" + host,
"--scheduler_port=" + str(scheduler_port),
"--objstore_port=" + str(objstore_port),
"--worker_port=" + str(worker_port)])
all_processes.append((p, worker_port))
"--scheduler-address=" + scheduler_address,
"--objstore-address=" + objstore_address,
"--worker-address=" + worker_address])
all_processes.append((p, worker_address))
24 changes: 12 additions & 12 deletions lib/orchpy/orchpy/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,24 @@ def func_call(*args):
def get_arguments_for_execution(function, args, worker=global_worker):
arguments = []
# check the number of args
if len(args) != len(function.types) and function.types[-1] is not None:
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.types), len(args)))
elif len(args) < len(function.types) - 1 and function.types[-1] is None:
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.types) - 1, len(args)))
if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))

for (i, arg) in enumerate(args):
print "Pulling argument {} for function {}.".format(i, function.__name__)
if i < len(function.types) - 1:
expected_type = function.types[i]
elif i == len(function.types) - 1 and function.types[-1] is not None:
expected_type = function.types[-1]
elif function.types[-1] is None and len(function.types > 1):
expected_type = function.types[-2]
if i < len(function.arg_types) - 1:
expected_type = function.arg_types[i]
elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None:
expected_type = function.arg_types[-1]
elif function.arg_types[-1] is None and len(function.arg_types > 1):
expected_type = function.arg_types[-2]
else:
assert False, "This code should be unreachable."

argument = worker.get_object(arg) if type(arg) == orchpy.ObjRef else arg
if type(arg) == orchpy.ObjRef:
argument = worker.get_object(arg) if type(arg) == orchpy.lib.ObjRef else arg
if type(arg) == orchpy.lib.ObjRef:
# get the object from the local object store
# TODO(rkn): Do we know that it is already there? Maybe we should call pull(arg, worker).
argument = worker.get_object(arg)
Expand Down
61 changes: 54 additions & 7 deletions test/runtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def testObjStore(self):
worker1_port = new_worker_port()
worker2_port = new_worker_port()

services.start_scheduler(IP_ADDRESS, scheduler_port)
services.start_objstore(IP_ADDRESS, objstore1_port)
services.start_objstore(IP_ADDRESS, objstore2_port)
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
services.start_objstore(address(IP_ADDRESS, objstore1_port))
services.start_objstore(address(IP_ADDRESS, objstore2_port))

time.sleep(0.2)

Expand Down Expand Up @@ -110,8 +110,8 @@ def testCall(self):
worker1_port = new_worker_port()
worker2_port = new_worker_port()

services.start_scheduler(IP_ADDRESS, scheduler_port)
services.start_objstore(IP_ADDRESS, objstore_port)
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
services.start_objstore(address(IP_ADDRESS, objstore_port))

time.sleep(0.2)

Expand All @@ -125,17 +125,64 @@ def testCall(self):

test_dir = os.path.dirname(os.path.abspath(__file__))
test_path = os.path.join(test_dir, "testrecv.py")
services.start_worker(test_path, IP_ADDRESS, scheduler_port, worker2_port, objstore_port)
services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port))

time.sleep(0.2)

worker1.remote_call("print_string", ["hi"])
value_before = "test_string"
objref = worker1.remote_call("__main__.print_string", [value_before])

time.sleep(0.2)

# value_after = worker.pull(objref, worker1)
# self.assertEqual(value_before, value_after)

time.sleep(0.1)

reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS)

services.cleanup()

class WorkerTest(unittest.TestCase):

def testPushPull(self):
scheduler_port = new_scheduler_port()
objstore_port = new_objstore_port()
worker1_port = new_worker_port()

services.start_scheduler(address(IP_ADDRESS, scheduler_port))
services.start_objstore(address(IP_ADDRESS, objstore_port))

time.sleep(0.2)

worker1 = worker.Worker()
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)

for i in range(100):
value_before = i * 10 ** 6
objref = worker.push(value_before, worker1)
value_after = worker.pull(objref, worker1)
self.assertEqual(value_before, value_after)

for i in range(100):
value_before = i * 10 ** 6 * 1.0
objref = worker.push(value_before, worker1)
value_after = worker.pull(objref, worker1)
self.assertEqual(value_before, value_after)

for i in range(100):
value_before = "h" * i
objref = worker.push(value_before, worker1)
value_after = worker.pull(objref, worker1)
self.assertEqual(value_before, value_after)

for i in range(100):
value_before = [1] * i
objref = worker.push(value_before, worker1)
value_after = worker.pull(objref, worker1)
self.assertEqual(value_before, value_after)

services.cleanup()

if __name__ == '__main__':
unittest.main()
15 changes: 7 additions & 8 deletions test/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import types_pb2

parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.')
parser.add_argument("--ip_address", default="127.0.0.1", help="the IP address to use for both the scheduler and objstore")
parser.add_argument("--scheduler_port", default=10001, type=int, help="the scheduler's port")
parser.add_argument("--objstore_port", default=20001, type=int, help="the objstore's port")
parser.add_argument("--worker_port", default=40001, type=int, help="the worker's port")
parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address")
parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address")
parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address")

@worker.distributed([str], [str])
def print_string(string):
Expand All @@ -24,16 +23,16 @@ def print_string(string):
def handle_int(a, b):
return a + 1, b + 1

def connect_to_scheduler(host, port):
channel = implementations.insecure_channel(host, port)
def connect_to_scheduler(address):
channel = implementations.insecure_channel(address)
return orchestra_pb2.beta_create_Scheduler_stub(channel)

def address(host, port):
return host + ":" + str(port)

if __name__ == '__main__':
args = parser.parse_args()
scheduler_stub = connect_to_scheduler(args.ip_address, args.scheduler_port)
worker.connect(address(args.ip_address, args.scheduler_port), address(args.ip_address, args.objstore_port), address(args.ip_address, args.worker_port))
scheduler_stub = connect_to_scheduler(args.scheduler_address)
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address))
import IPython
IPython.embed()
12 changes: 4 additions & 8 deletions test/testrecv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import orchpy.worker as worker

parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.')
parser.add_argument("--ip_address", default="127.0.0.1", help="the IP address to use for both the scheduler and objstore")
parser.add_argument("--scheduler_port", default=10001, type=int, help="the scheduler's port")
parser.add_argument("--objstore_port", default=20001, type=int, help="the objstore's port")
parser.add_argument("--worker_port", default=40001, type=int, help="the worker's port")
parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address")
parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address")
parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address")

@worker.distributed([str], [str])
def print_string(string):
Expand All @@ -21,12 +20,9 @@ def print_string(string):
def handle_int(a, b):
return a + 1, b + 1

def address(host, port):
return host + ":" + str(port)

if __name__ == '__main__':
args = parser.parse_args()
worker.connect(address(args.ip_address, args.scheduler_port), address(args.ip_address, args.objstore_port), address(args.ip_address, args.worker_port))
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)

worker.global_worker.register_function(print_string)
worker.global_worker.register_function(handle_int)
Expand Down

0 comments on commit 2535d26

Please sign in to comment.