Changes from all commits
137 commits
7f41242
cholesky uses dtrsm now
Vaishaal Feb 21, 2018
32601f6
lambdapack executor
Vaishaal Feb 23, 2018
5065376
local pipelined jobrunner
Vaishaal Feb 25, 2018
cfdefd8
added region name
Vaishaal Feb 25, 2018
eecc1e2
profiling works
Vaishaal Feb 26, 2018
82aea7f
job runner works on lambda as well
Vaishaal Feb 26, 2018
be04ee0
make cholesky test smaller
Vaishaal Feb 26, 2018
5019ffa
async jobrunner but untested probably fails
Vaishaal Feb 27, 2018
66e0c2d
BROKEN DO NOT USE ASYNC SUPPORT FOR NUMPYWREN
Vaishaal Feb 27, 2018
bfe6c6b
made numpywren stuff async
Vaishaal Feb 28, 2018
88e415e
async jobrunner test passes
Vaishaal Mar 1, 2018
8207592
switched to redis for program state fixes everything
Vaishaal Mar 3, 2018
1dd6668
Big commit
Vaishaal Mar 9, 2018
8055c3f
traceback logging
Vaishaal Mar 10, 2018
c1f81ff
add fast cholesky test
Vaishaal Mar 10, 2018
a96aac4
tiered scheduling
Vaishaal Mar 13, 2018
e54822e
2 level priorities
Vaishaal Mar 15, 2018
a42e66a
added triangular solve
k-rl Mar 16, 2018
f96d950
Merge branch 'cholesky_speedup' into trisolve
k-rl Mar 16, 2018
695d85f
hardcode region to avoid timeout
Vaishaal Mar 19, 2018
f0f1b33
exclude modules site-packages
Vaishaal Mar 19, 2018
a6b13df
performance makes sense
Vaishaal Mar 20, 2018
159eff4
Merge remote-tracking branch 'origin/cholesky_speedup' into trisolve
Mar 20, 2018
f17cc88
caching and dynamic node fusion
Vaishaal Mar 21, 2018
91a2b78
small scale strong scaling experiments work
Vaishaal Mar 23, 2018
d7e1421
Fixed opcode
Mar 26, 2018
107622a
cleaned a lot of the lambdapack code
Vaishaal Mar 29, 2018
a63944b
failure tests pass
Vaishaal Mar 29, 2018
a6cee5b
fast cholesky test added
Vaishaal Mar 29, 2018
844ae69
fixed dumb bug
Vaishaal Mar 29, 2018
514d503
fixed bug in pipelining
Vaishaal Mar 29, 2018
2e5e624
enable autoscaling
Apr 1, 2018
fdc0ee9
untested indexing
k-rl Mar 28, 2018
47af2ba
Added transpose to BigMatrixView.
k-rl Mar 28, 2018
972155d
BigMatrixView transpose passes tests.
k-rl Mar 28, 2018
bbcafe5
Passing simple indexing tests.
k-rl Mar 28, 2018
77b9ecc
Changed indexing to submatrix API.
k-rl Apr 1, 2018
8605787
Added slicing tests and fixed slicing bugs.
k-rl Apr 1, 2018
ca4deec
Added another indexing test.
k-rl Apr 1, 2018
ff395f4
cleaned a lot of the lambdapack code
Vaishaal Mar 29, 2018
9f775b1
failure tests pass
Vaishaal Mar 29, 2018
8c916b0
fast cholesky test added
Vaishaal Mar 29, 2018
91144ee
fixed dumb bug
Vaishaal Mar 29, 2018
2024b9e
fixed bug in pipelining
Vaishaal Mar 29, 2018
07ef5e7
eager_bug_hunt
Vaishaal Apr 2, 2018
ff112dc
This commit fixes a rather hairy bug which I will summarize here
Vaishaal Apr 3, 2018
dacfe03
added redis security
Vaishaal Apr 3, 2018
e768f4d
dumb bug fixes to get failure test to pass correctly
Vaishaal Apr 3, 2018
0fe3368
all local tests pass
Vaishaal Apr 3, 2018
03b06d7
moved profile after clear
Vaishaal Apr 3, 2018
0b7ab4f
Merge branch 'cholesky_speedup' into indexing
k-rl Apr 3, 2018
0714028
Readded matrix multiply transpose optimization.
k-rl Apr 4, 2018
0287c48
Merge branch 'indexing' of github.com:Vaishaal/numpywren into indexing
k-rl Apr 4, 2018
3581b29
added back transpose in lambdapack
k-rl Apr 4, 2018
3a9514a
Merge pull request #6 from Vaishaal/indexing
k-rl Apr 4, 2018
120d8a9
Merge branch 'cholesky_speedup' into trisolve
k-rl Apr 4, 2018
96d2037
Small fixes.
k-rl Apr 4, 2018
acf4fa2
end to end 65k solve works again lots of dumb dumb bugs fixed
Vaishaal Apr 4, 2018
42600ce
Merge branch 'cholesky_speedup' of https://github.com/vaishaal/numpyw…
Vaishaal Apr 4, 2018
f47acfc
Fixed BigMatrixView __str__ bugs.
k-rl Apr 5, 2018
d4373f0
Merge branch 'cholesky_speedup' of github.com:Vaishaal/numpywren into…
k-rl Apr 5, 2018
8f78aff
Changed redis to use conditional atomic counter
Vaishaal Apr 5, 2018
c5c956f
Merge branch 'cholesky_speedup' of https://github.com/vaishaal/numpyw…
Vaishaal Apr 5, 2018
bdc9f27
autoscaling
Apr 5, 2018
2de2e28
merging
Apr 5, 2018
4859b79
Fix minor errors.
Apr 5, 2018
d4be7fa
Redis atomic conditional sum not working properly
Vaishaal Apr 5, 2018
238232e
set condition_key to 1
Vaishaal Apr 5, 2018
e711ec0
atomic try v4
Vaishaal Apr 5, 2018
b1012a1
Fixed indexing bugs.
k-rl Apr 5, 2018
449d49e
Merge branch 'cholesky_speedup' of github.com:Vaishaal/numpywren into…
k-rl Apr 5, 2018
95265dd
Correct size for scratch matrices.
k-rl Apr 5, 2018
20ff3a2
conditional increment works
Vaishaal Apr 5, 2018
120b080
more bug fixes
Vaishaal Apr 6, 2018
3d418a2
add optimizations.py
Vaishaal Apr 6, 2018
f41d84c
fix autoscaling for wide pipeline.
Apr 7, 2018
890c2b7
fix timeout bug
Apr 8, 2018
e53c56e
fixed subtle bug in post op
Vaishaal Apr 8, 2018
446388c
Fixed bug in pipelining so now we get 30 percent bump from pipelining our
Vaishaal Apr 8, 2018
89365d3
finished optimizations script
Vaishaal Apr 8, 2018
921fddf
update
Apr 8, 2018
6b20cad
Merge remote-tracking branch 'qifan/cholesky_speedup' into cholesky_s…
Vaishaal Apr 8, 2018
cc807f8
log flops
Vaishaal Apr 9, 2018
f0ff5c8
remove extra copies
Vaishaal Apr 9, 2018
50bfd36
optimizations somewhat work
Vaishaal Apr 9, 2018
7e81b5b
add an idle timeout
Vaishaal Apr 9, 2018
2cdda28
Added buggy tree reduce.
k-rl Apr 10, 2018
d452c4b
Merge remote-tracking branch 'origin/cholesky_speedup' into trisolve
k-rl Apr 10, 2018
63edd95
cleaned up optimization.py
Vaishaal Apr 10, 2018
61811f5
qifan job_runner qualms
Vaishaal Apr 10, 2018
eee44d1
failure experiments work
Vaishaal Apr 10, 2018
90af677
Added working reduce.
k-rl Apr 11, 2018
6076308
Merge branch 'cholesky_speedup' of github.com:Vaishaal/numpywren into…
k-rl Apr 11, 2018
fd9a35f
Re-add REDIS_IP env variable.
k-rl Apr 11, 2018
545b16f
shard the program so we can support larger programs
Vaishaal Apr 13, 2018
90b87c2
Made triangle solve DAG slightly smaller.
k-rl Apr 23, 2018
ae4c796
Merge branch 'trisolve' of github.com:Vaishaal/numpywren into trisolve
k-rl Apr 23, 2018
ab7c694
Reset default reduce width.
k-rl Apr 23, 2018
23c106d
Merge pull request #5 from Vaishaal/trisolve
k-rl Apr 23, 2018
7ce50f0
sped up shard program
Vaishaal Apr 25, 2018
77721f8
Merge branch 'cholesky_speedup' of https://github.com/vaishaal/numpyw…
Vaishaal Apr 25, 2018
dcb56c1
rudimentary config dumping working
Vaishaal May 25, 2018
86dbb3e
redis instance launches but doesn't launch redis
Vaishaal May 31, 2018
bf0e28b
iam hacking redis client works
Vaishaal May 31, 2018
d4bc84b
some untested control plane logic
Vaishaal Jun 2, 2018
6cffa15
a very rudimentary loops front end datastructure
Vaishaal Jun 6, 2018
892f47f
Added start of compilation backend and fixed tests.
k-rl Jun 7, 2018
88dcdbf
control plane cli utils work
Vaishaal Jun 7, 2018
97f1eb2
get rid of loops for this branch
Vaishaal Jun 7, 2018
a838843
More progress on compile backend.
k-rl Jun 8, 2018
7316e21
Non-working IR backend
k-rl Jun 12, 2018
b20d29c
Added compiler backend tests.
k-rl Jun 12, 2018
7dd46b2
More tests and bugfixes.
k-rl Jun 13, 2018
169cee1
More compiler backend test changes.
k-rl Jun 13, 2018
3114631
Added a way to find an expr's parents.
k-rl Jun 13, 2018
a289ba2
Updating lambdapack to use compiler backend.
k-rl Jun 14, 2018
37a2b3d
Starting to change job runner for compiler backend.
k-rl Jun 14, 2018
490b04c
Lambdapack/jobrunner converted to compiler backend (untested).
k-rl Jun 14, 2018
dfffc9b
Updated cholesky matrix inputs/outputs
k-rl Jun 14, 2018
803ca9d
backend compiler bug fixes
k-rl Jun 14, 2018
5104d15
Bug fixes.
k-rl Jun 14, 2018
6e73646
Got more cholesky tests passing.
k-rl Jun 15, 2018
fe89e53
Changing cholesky test for new api.
k-rl Jun 15, 2018
f1affe3
Merge branch 'redis_cleanup' into test_fixes
Vaishaal Jun 18, 2018
d3aa0c6
move compiler stuff to compile
Vaishaal Jun 21, 2018
b0209e2
dag enumeration so f**ing slow
Vaishaal Jun 22, 2018
58f20f0
basic compiler tests pass but cholesky gives non psd error
Vaishaal Jun 24, 2018
7c65658
test multi works
Vaishaal Jun 26, 2018
bc340d0
test lambda_single works but is slow
Vaishaal Jun 28, 2018
21a40a1
all of test cholesky passes only clean up needed now
Vaishaal Jun 28, 2018
7648491
control plane logic threaded in
Vaishaal Jun 29, 2018
5d63242
all tests except trisolve and gemm pass
Vaishaal Jul 4, 2018
7414355
added cloud init template
Vaishaal Jul 4, 2018
1563b9d
all tests except test_triangle_solve passes
Vaishaal Jul 5, 2018
30c5aa5
reshard tests pass
Vaishaal Jul 9, 2018
1766ef4
fixed bug in synchronous get/put block implementations
Vaishaal Jul 12, 2018
e4af32b
launch separate thread in synchronous get/put/delete block to avoid l…
Vaishaal Jul 12, 2018
303 changes: 303 additions & 0 deletions experiments/failures.py
@@ -0,0 +1,303 @@
import argparse
from numpywren import lambdapack as lp
import pywren
import concurrent.futures as fs
import hashlib
import numpy as np
from numpywren.matrix import BigMatrix
from numpywren.matrix_init import shard_matrix
from numpywren import job_runner
import numpywren.binops as binops
from pywren.serialize import serialize
import os
import time
import boto3
import redis
import pickle
import matplotlib
# so plots work in headless mode
matplotlib.use('Agg')
import seaborn as sns
from pylab import plt
import logging
import copy


REDIS_ADDR = os.environ.get("REDIS_ADDR", "")
REDIS_PASS = os.environ.get("REDIS_PASS", "")
REDIS_PORT = os.environ.get("REDIS_PORT", "9001")
INFO_FREQ = 5

''' OSDI numpywren optimization effectiveness experiments '''
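# High-level flow (inferred from the code below): shard a random vector, build a
# PSD matrix via a distributed gemm, compile its Cholesky factorization into a
# lambdapack program, launch pywren lambda workers, poll SQS/Redis for progress,
# autoscale according to --autoscale_policy, and periodically inject failures by
# setting per-worker Redis keys (job_runner.lambdapack_run_with_failures is
# assumed to watch these keys and kill its worker when they are set).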

def run_experiment(problem_size, shard_size, pipeline, priority, lru, eager, truncate, max_cores, start_cores, trial, launch_granularity, timeout, log_granularity, autoscale_policy, failure_percentage, max_failure_events, failure_time):
# set up logging
logger = logging.getLogger()
for key in logging.Logger.manager.loggerDict:
logging.getLogger(key).setLevel(logging.CRITICAL)
logger.setLevel(logging.DEBUG)
arg_bytes = pickle.dumps((problem_size, shard_size, pipeline, priority, lru, eager, truncate, max_cores, start_cores, trial, launch_granularity, timeout, log_granularity, autoscale_policy, failure_percentage, max_failure_events, failure_time))
arg_hash = hashlib.md5(arg_bytes).hexdigest()
log_file = "failure_experiments/{0}.log".format(arg_hash)
fh = logging.FileHandler(log_file)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)
logger.info("Logging to {0}".format(log_file))

X = np.random.randn(problem_size, 1)
pwex = pywren.default_executor()
shard_sizes = [shard_size, 1]
X_sharded = BigMatrix("cholesky_test_{0}_{1}".format(problem_size, shard_size), shape=X.shape, shard_sizes=shard_sizes, write_header=True)
shard_matrix(X_sharded, X)
print("Generating PSD matrix...")
XXT_sharded = binops.gemm(pwex, X_sharded, X_sharded.T, overwrite=False)
XXT_sharded.lambdav = problem_size*10
instructions, L_sharded, trailing = lp._chol(XXT_sharded)
pipeline_width = pipeline
if (priority):
num_priorities = 5
else:
num_priorities = 1
if (lru):
cache_size = 5
else:
cache_size = 0

REDIS_CLIENT = redis.StrictRedis(REDIS_ADDR, port=int(REDIS_PORT), password=REDIS_PASS, db=0, socket_timeout=5)

if (truncate is not None):
instructions = instructions[:truncate]
config = pwex.config

program = lp.LambdaPackProgram(instructions, executor=pywren.lambda_executor, pywren_config=config, num_priorities=num_priorities, eager=eager)
redis_env = {"REDIS_ADDR": os.environ.get("REDIS_ADDR", ""), "REDIS_PASS": os.environ.get("REDIS_PASS", "")}
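# redis_env is forwarded to every worker via pwex.map(..., extra_env=redis_env)
# so the lambda workers connect to the same Redis instance as this driver.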


done_counts = []
ready_counts = []
post_op_counts = []
not_ready_counts = []
running_counts = []
sqs_invis_counts = []
sqs_vis_counts = []
up_workers_counts = []
busy_workers_counts = []
times = []
flops = []
reads = []
writes = []
failure_times = []
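# All of the per-sample time series above, plus the experiment configuration,
# are collected in `exp` and pickled to failure_experiments/<arg_hash>.pickle
# at the end of the run.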
exp = {}
exp["redis_done_counts"] = done_counts
exp["redis_ready_counts"] = ready_counts
exp["redis_post_op_counts"] = post_op_counts
exp["redis_not_ready_counts"] = not_ready_counts
exp["redis_running_counts"] = running_counts
exp["sqs_invis_counts"] = sqs_invis_counts
exp["sqs_vis_counts"] = sqs_vis_counts
exp["busy_workers"] = busy_workers_counts
exp["up_workers"] = up_workers_counts
exp["times"] = times
exp["lru"] = lru
exp["priority"] = priority
exp["eager"] = eager
exp["truncate"] = truncate
exp["max_cores"] = max_cores
exp["problem_size"] = problem_size
exp["shard_size"] = shard_size
exp["pipeline"] = pipeline
exp["flops"] = flops
exp["reads"] = reads
exp["writes"] = writes
exp["trial"] = trial
exp["launch_granularity"] = launch_granularity
exp["log_granularity"] = log_granularity
exp["autoscale_policy"] = autoscale_policy
exp["failure_times"] = failure_times


logger.info("Longest Path: {0}".format(program.longest_path))
program.start()
t = time.time()
logger.info("Starting with {0} cores".format(start_cores))
failure_keys = ["{0}_failure_{1}_{2}".format(program.hash, i, 0) for i in range(start_cores)]
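# One failure key per worker, of the form <program_hash>_failure_<worker_idx>_<launch_time>
# (launch_time is 0 for the initial fleet); setting a key to 1 later signals
# that worker to fail (see the failure-injection block below).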
all_futures = pwex.map(lambda x: job_runner.lambdapack_run_with_failures(failure_keys[x], program, pipeline_width=pipeline_width, cache_size=cache_size, timeout=timeout), range(start_cores), extra_env=redis_env)
start_time = time.time()
last_run_time = start_time
last_failure = time.time()
num_failure_events = 0

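# Monitoring loop: every log_granularity seconds, sample the SQS queues
# (visible messages = waiting tasks, not-visible = in-flight tasks) and the
# Redis counters (busy/up workers, GFLOPs, GB read/written), then apply the
# autoscale policy and, when due, a failure-injection event.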
while(program.program_status() == lp.PS.RUNNING):
curr_time = int(time.time() - start_time)
max_pc = program.get_max_pc()
times.append(int(time.time()))
time.sleep(log_granularity)
waiting = 0
running = 0
for i, queue_url in enumerate(program.queue_urls):
client = boto3.client('sqs')
attrs = client.get_queue_attributes(QueueUrl=queue_url, AttributeNames=['ApproximateNumberOfMessages', 'ApproximateNumberOfMessagesNotVisible'])['Attributes']
waiting += int(attrs["ApproximateNumberOfMessages"])
running += int(attrs["ApproximateNumberOfMessagesNotVisible"])
sqs_invis_counts.append(running)
sqs_vis_counts.append(waiting)
busy_workers = REDIS_CLIENT.get("{0}_busy".format(program.hash))
if (busy_workers is None):
busy_workers = 0
else:
busy_workers = int(busy_workers)
up_workers = program.get_up()

if (up_workers is None):
up_workers = 0
else:
up_workers = int(up_workers)
up_workers_counts.append(up_workers)
busy_workers_counts.append(busy_workers)

logger.debug("Waiting: {0}, Currently Processing: {1}".format(waiting, running))
logger.debug("{2}: Up Workers: {0}, Busy Workers: {1}".format(up_workers, busy_workers, curr_time))
if ((curr_time % INFO_FREQ) == 0):
logger.info("Max PC is {0}".format(max_pc))
logger.info("Waiting: {0}, Currently Processing: {1}".format(waiting, running))
logger.info("{2}: Up Workers: {0}, Busy Workers: {1}".format(up_workers, busy_workers, curr_time))

#print("{5}: Not Ready: {0}, Ready: {1}, Running: {4}, Post OP: {2}, Done: {3}".format(not_ready_count, ready_count, post_op_count, done_count, running_count, curr_time))
current_gflops = program.get_flops()
if (current_gflops is None):
current_gflops = 0
else:
current_gflops = int(current_gflops)/1e9

flops.append(current_gflops)
current_gbytes_read = program.get_read()
if (current_gbytes_read is None):
current_gbytes_read = 0
else:
current_gbytes_read = int(current_gbytes_read)/1e9

reads.append(current_gbytes_read)
current_gbytes_write = program.get_write()
if (current_gbytes_write is None):
current_gbytes_write = 0
else:
current_gbytes_write = int(current_gbytes_write)/1e9
writes.append(current_gbytes_write)
#print("{0}: Total GFLOPS {1}, Total GBytes Read {2}, Total GBytes Write {3}".format(curr_time, current_gflops, current_gbytes_read, current_gbytes_write))

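# Autoscaling: the "dynamic" policy tops the fleet up to roughly one worker per
# pipeline_width waiting tasks (capped at max_cores, at most once every
# launch_granularity seconds); "constant_timeout" relaunches max_cores workers
# once 75% of the lambda timeout has elapsed since the last launch.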
time_since_launch = time.time() - last_run_time
if (autoscale_policy == "dynamic"):
if (time_since_launch > launch_granularity and up_workers < np.ceil(waiting*0.5/pipeline_width) and up_workers < max_cores):
cores_to_launch = int(min(np.ceil(waiting/pipeline_width) - up_workers, max_cores - up_workers))
logger.info("launching {0} new tasks....".format(cores_to_launch))
_failure_keys = ["{0}_failure_{1}_{2}".format(program.hash, i, curr_time) for i in range(cores_to_launch)]
new_futures = pwex.map(lambda x: job_runner.lambdapack_run_with_failures(_failure_keys[x], program, pipeline_width=pipeline_width, cache_size=cache_size, timeout=timeout), range(cores_to_launch), extra_env=redis_env)
last_run_time = time.time()
# track the new keys as well so dynamically launched workers can also be failure-injected
failure_keys += _failure_keys
# check if we OOM-erred
# [x.result() for x in all_futures]
all_futures.extend(new_futures)
elif (autoscale_policy == "constant_timeout"):
if (time_since_launch > (0.75*timeout)):
cores_to_launch = max_cores
logger.info("launching {0} new tasks....".format(cores_to_launch))
_failure_keys = ["{0}_failure_{1}_{2}".format(program.hash, i, curr_time) for i in range(cores_to_launch)]
new_futures = pwex.map(lambda x: job_runner.lambdapack_run_with_failures(_failure_keys[x], program, pipeline_width=pipeline_width, cache_size=cache_size, timeout=timeout), range(cores_to_launch), extra_env=redis_env)
last_run_time = time.time()
failure_keys += _failure_keys
# check if we OOM-erred
# [x.result() for x in all_futures]
all_futures.extend(new_futures)
else:
raise Exception("unknown autoscale policy")

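# Failure injection: at most max_failure_events times, every failure_time
# seconds pick failure_percentage of the launched workers at random and set
# their Redis failure keys to 1; the workers are assumed to poll these keys
# and abort when they flip.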
if ((time.time() - last_failure) > failure_time and num_failure_events < max_failure_events):
logging.info("Killing some jobs")
idxs = np.random.choice(len(failure_keys), int(failure_percentage*len(failure_keys)), replace=False)
num_failure_events += 1
last_failure = time.time()
failure_times.append(last_failure)
for i in idxs:
logging.info("Killing: job {0}".format(i))
REDIS_CLIENT.set(failure_keys[i], 1)

exp["all_futures"] = all_futures
for pc in range(program.num_inst_blocks):
run_count = REDIS_CLIENT.get("{0}_{1}_start".format(program.hash, pc))
if (run_count is None):
run_count = 0
else:
run_count = int(run_count)

if (run_count != 1):
logger.info("PC: {0}, Run Count: {1}".format(pc, run_count))

e = time.time()
logger.info(program.program_status())
logger.info("PROGRAM STATUS " + str(program.program_status()))
logger.info("PROGRAM HASH " + str(program.hash))
logger.info("Took {0} seconds".format(e - t))
exp["total_runtime"] = e - t
exp["num_failure_events"] = num_failure_events
# collect per-instruction-block profiling info
executor = fs.ThreadPoolExecutor(72)
futures = []
for i in range(0,program.num_inst_blocks,1):
futures.append(executor.submit(program.get_profiling_info, i))
res = fs.wait(futures)
profiled_blocks = [f.result() for f in futures]
serializer = serialize.SerializeIndependent()
byte_string = serializer([profiled_blocks])[0][0]
exp["profiled_block_pickle_bytes"] = byte_string

read, write, total_flops, bins, instructions, runtimes = lp.perf_profile(profiled_blocks, num_bins=100)
flop_rate = sum(total_flops)/max(bins)
exp["flop_rate"] = flop_rate
print("Average Flop rate of {0}".format(flop_rate))
# save other stuff
try:
os.mkdir("failure_experiments/")
except FileExistsError:
pass
exp_bytes = pickle.dumps(exp)
dump_path = "failure_experiments/{0}.pickle".format(arg_hash)
print("Dumping experiment pickle to {0}".format(dump_path))
with open(dump_path, "wb+") as f:
f.write(exp_bytes)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run OSDI optimization effectiveness experiments')
parser.add_argument("problem_size", type=int)
parser.add_argument("--shard_size", type=int, default=4096)
parser.add_argument('--truncate', type=int, default=None)
parser.add_argument('--max_cores', type=int, default=32)
parser.add_argument('--start_cores', type=int, default=32)
parser.add_argument('--pipeline', type=int, default=1)
parser.add_argument('--failure_percentage', type=float, default=0.25)
parser.add_argument('--max_failure_events', type=int, default=1)
parser.add_argument('--failure_time', type=int, default=60)
parser.add_argument('--timeout', type=int, default=200)
parser.add_argument('--autoscale_policy', type=str, default="dynamic")
parser.add_argument('--log_granularity', type=int, default=1)
parser.add_argument('--launch_granularity', type=int, default=10)
parser.add_argument('--trial', type=int, default=0)
parser.add_argument('--priority', action='store_true')
parser.add_argument('--lru', action='store_true')
parser.add_argument('--eager', action='store_true')
args = parser.parse_args()
run_experiment(args.problem_size, args.shard_size, args.pipeline, args.priority, args.lru, args.eager, args.truncate, args.max_cores, args.start_cores, args.trial, args.launch_granularity, args.timeout, args.log_granularity, args.autoscale_policy, args.failure_percentage, args.max_failure_events, args.failure_time)
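# Example invocation (illustrative values only):
#   python experiments/failures.py 65536 --shard_size 4096 --max_cores 32 \
#       --start_cores 32 --autoscale_policy dynamic --failure_percentage 0.25 \
#       --failure_time 60 --eager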


