From e23661c375a1554d658c8399c5eb6699bf9a977a Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 15 Dec 2016 19:22:23 -0800 Subject: [PATCH] Task table Redis module (#125) * Task table redis module implementation * Publish tasks and take in individual fields as args, not task object * Scheduling state integer has width 1, error on illegal put * Unit tests for task table and more documentation * Task table subscribe, fix publish topics and address Philipp and Alexey's comments * Helper function to create prefixed strings * Factor out the table prefixes in the test cases --- src/common/redis_module/Makefile | 2 +- src/common/redis_module/ray_redis_module.c | 271 ++++++++++++++++++++- src/common/redis_module/runtest.py | 63 ++++- 3 files changed, 322 insertions(+), 14 deletions(-) diff --git a/src/common/redis_module/Makefile b/src/common/redis_module/Makefile index d39424b53865..61e2332afadc 100644 --- a/src/common/redis_module/Makefile +++ b/src/common/redis_module/Makefile @@ -17,7 +17,7 @@ SHOBJ_CFLAGS += -I../thirdparty all: ray_redis_module.so .c.xo: - $(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@ + $(CC) -I. -I.. 
$(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@ ray_redis_module.xo: redismodule.h diff --git a/src/common/redis_module/ray_redis_module.c b/src/common/redis_module/ray_redis_module.c index fd972c227011..b06b8cf38d21 100644 --- a/src/common/redis_module/ray_redis_module.c +++ b/src/common/redis_module/ray_redis_module.c @@ -23,20 +23,29 @@ #define OBJECT_INFO_PREFIX "OI:" #define OBJECT_LOCATION_PREFIX "OL:" #define OBJECT_SUBSCRIBE_PREFIX "OS:" +#define TASK_PREFIX "TT:" #define CHECK_ERROR(STATUS, MESSAGE) \ if ((STATUS) == REDISMODULE_ERR) { \ return RedisModule_ReplyWithError(ctx, (MESSAGE)); \ } -RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx, - const char *prefix, - RedisModuleString *keyname, - int mode) { +RedisModuleString *CreatePrefixedString(RedisModuleCtx *ctx, + const char *prefix, + RedisModuleString *keyname) { size_t length; const char *value = RedisModule_StringPtrLen(keyname, &length); RedisModuleString *prefixed_keyname = RedisModule_CreateStringPrintf( ctx, "%s%*.*s", prefix, length, length, value); + return prefixed_keyname; +} + +RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx, + const char *prefix, + RedisModuleString *keyname, + int mode) { + RedisModuleString *prefixed_keyname = + CreatePrefixedString(ctx, prefix, keyname); RedisModuleKey *key = RedisModule_OpenKey(ctx, prefixed_keyname, mode); RedisModule_FreeString(ctx, prefixed_keyname); return key; @@ -233,6 +242,7 @@ int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx, return RedisModule_ReplyWithError(ctx, "data_size must be integer"); } + /* Set the fields in the object info table. 
*/ RedisModuleKey *key; key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_READ | REDISMODULE_WRITE); @@ -252,7 +262,9 @@ int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx, RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "hash", new_hash, NULL); RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "data_size", data_size, NULL); + RedisModule_CloseKey(key); + /* Add the location in the object location table. */ RedisModuleKey *table_key; table_key = OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, object_id, REDISMODULE_READ | REDISMODULE_WRITE); @@ -260,7 +272,12 @@ int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx, /* Sets are not implemented yet, so we use ZSETs instead. */ RedisModule_ZsetAdd(table_key, 0.0, manager, NULL); - /* Inform subscribers. */ + /* Build the PUBLISH topic and message for object table subscribers. The + * topic is a string in the format "OBJECT_LOCATION_PREFIX:<object ID>". The + * message is a string in the format: "MANAGERS <manager ID> ... <manager ID>". */ + RedisModuleString *publish_topic = + CreatePrefixedString(ctx, OBJECT_LOCATION_PREFIX, object_id); const char *MANAGERS = "MANAGERS"; RedisModuleString *publish = RedisModule_CreateString(ctx, MANAGERS, strlen(MANAGERS)); @@ -277,17 +294,15 @@ int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx, RedisModule_StringAppendBuffer(ctx, publish, val, size); } while (RedisModule_ZsetRangeNext(table_key)); - RedisModuleCallReply *reply; - reply = RedisModule_Call(ctx, "PUBLISH", "ss", object_id, publish); + RedisModuleCallReply *reply = + RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish); RedisModule_FreeString(ctx, publish); + RedisModule_FreeString(ctx, publish_topic); + RedisModule_CloseKey(table_key); if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); } - /* Clean up.
*/ - RedisModule_CloseKey(key); - RedisModule_CloseKey(table_key); - RedisModule_ReplyWithSimpleString(ctx, "OK"); return REDISMODULE_OK; } @@ -345,7 +360,7 @@ int ResultTableAdd_RedisCommand(RedisModuleCtx *ctx, } /** - * Add a new entry to the result table or update an existing one. + * Lookup an entry in the result table. * * This is called from a client with the command: * @@ -387,6 +402,214 @@ int ResultTableLookup_RedisCommand(RedisModuleCtx *ctx, return REDISMODULE_OK; } +int TaskTableWrite(RedisModuleCtx *ctx, + RedisModuleString *task_id, + RedisModuleString *state, + RedisModuleString *node_id, + RedisModuleString *task_spec) { + /* Pad the state integer to a fixed-width integer, and make sure it has width + * less than or equal to 2. */ + long long state_integer; + int status = RedisModule_StringToLongLong(state, &state_integer); + if (status != REDISMODULE_OK) { + return RedisModule_ReplyWithError( + ctx, "Invalid scheduling state (must be an integer)"); + } + state = RedisModule_CreateStringPrintf(ctx, "%2d", state_integer); + size_t length; + RedisModule_StringPtrLen(state, &length); + if (length != 2) { + return RedisModule_ReplyWithError( + ctx, "Invalid scheduling state width (must have width 2)"); + } + + /* Add the task to the task table. If no spec was provided, get the existing + * spec out of the task table so we can publish it. 
*/ + RedisModuleKey *key = + OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_WRITE); + if (task_spec == NULL) { + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node", + node_id, NULL); + RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "task_spec", &task_spec, + NULL); + if (task_spec == NULL) { + return RedisModule_ReplyWithError( + ctx, "Cannot update a task that doesn't exist yet"); + } + } else { + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node", + node_id, "task_spec", task_spec, NULL); + } + RedisModule_CloseKey(key); + + /* Build the PUBLISH topic and message for task table subscribers. The topic + * is a string in the format "TASK_PREFIX:<node ID>:<state>". The + * message is a string in the format: "<task ID> <state> <node ID> <task spec>". */ + RedisModuleString *publish_topic = + CreatePrefixedString(ctx, TASK_PREFIX, node_id); + RedisModule_StringAppendBuffer(ctx, publish_topic, ":", strlen(":")); + const char *state_string = RedisModule_StringPtrLen(state, &length); + RedisModule_StringAppendBuffer(ctx, publish_topic, state_string, length); + /* Append the fields to the PUBLISH message. */ + RedisModuleString *publish_message = + RedisModule_CreateStringFromString(ctx, task_id); + const char *publish_field; + /* Append the scheduling state. */ + publish_field = state_string; + RedisModule_StringAppendBuffer(ctx, publish_message, " ", strlen(" ")); + RedisModule_StringAppendBuffer(ctx, publish_message, publish_field, length); + /* Append the node ID. */ + publish_field = RedisModule_StringPtrLen(node_id, &length); + RedisModule_StringAppendBuffer(ctx, publish_message, " ", strlen(" ")); + RedisModule_StringAppendBuffer(ctx, publish_message, publish_field, length); + /* Append the task specification.
*/ + publish_field = RedisModule_StringPtrLen(task_spec, &length); + RedisModule_StringAppendBuffer(ctx, publish_message, " ", strlen(" ")); + RedisModule_StringAppendBuffer(ctx, publish_message, publish_field, length); + + RedisModuleCallReply *reply = + RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message); + if (reply == NULL) { + return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); + } + + RedisModule_FreeString(ctx, publish_message); + RedisModule_FreeString(ctx, publish_topic); + RedisModule_ReplyWithSimpleString(ctx, "ok"); + + return REDISMODULE_OK; +} + +/** + * Add a new entry to the task table. This will overwrite any existing entry + * with the same task ID. + * + * This is called from a client with the command: + * + * RAY.task_table_add <task ID> <state> <node ID> <task spec> + * + * @param task_id A string that is the ID of the task. + * @param state A string that is the current scheduling state (a + * scheduling_state enum instance). The string's value must be a + * nonnegative integer less than 100, so that it has width at most 2. If + * less than 2, the value will be left-padded with spaces to a width of + * 2. + * @param node_id A string that is the ID of the associated node, if any. + * @param task_spec A string that is the specification of the task, which can + * be cast to a `task_spec`. + * @return OK if the operation was successful. + */ +int TaskTableAddTask_RedisCommand(RedisModuleCtx *ctx, + RedisModuleString **argv, + int argc) { + if (argc != 5) { + return RedisModule_WrongArity(ctx); + } + + return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4]); +} + +/** + * Update an entry in the task table. This does not update the task + * specification in the table. + * + * This is called from a client with the command: + * + * RAY.task_table_update <task ID> <state> <node ID> + * + * @param task_id A string that is the ID of the task. + * @param state A string that is the current scheduling state (a + * scheduling_state enum instance).
The string's value must be a + * nonnegative integer less than 100, so that it has width at most 2. If + * less than 2, the value will be left-padded with spaces to a width of + * 2. + * @param node_id A string that is the ID of the associated node, if any. + * @return OK if the operation was successful. + */ +int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx, + RedisModuleString **argv, + int argc) { + if (argc != 4) { + return RedisModule_WrongArity(ctx); + } + + return TaskTableWrite(ctx, argv[1], argv[2], argv[3], NULL); +} + +/** + * Get an entry from the task table. + * + * This is called from a client with the command: + * + * RAY.task_table_get <task ID> + * + * @param task_id A string of the task ID to look up. + * @return An array of strings representing the task fields in the following + * order: 1) (integer) scheduling state 2) (string) associated node ID, + * if any 3) (string) the task specification, which can be cast to a + * task_spec. If the task ID is not in the table, returns nil. + */ +int TaskTableGetTask_RedisCommand(RedisModuleCtx *ctx, + RedisModuleString **argv, + int argc) { + if (argc != 2) { + return RedisModule_WrongArity(ctx); + } + + RedisModuleKey *key = + OpenPrefixedKey(ctx, TASK_PREFIX, argv[1], REDISMODULE_READ); + + int keytype = RedisModule_KeyType(key); + if (keytype != REDISMODULE_KEYTYPE_EMPTY) { + /* If the key exists, look up the fields and return them in an array. */ + RedisModuleString *state = NULL, *node = NULL, *task_spec = NULL; + RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &state, "node", + &node, "task_spec", &task_spec, NULL); + if (state == NULL || node == NULL || task_spec == NULL) { + /* We must have either all fields or no fields.
*/ + return RedisModule_ReplyWithError( + ctx, "Missing fields in the task table entry"); + } + + size_t state_length; + const char *state_string = RedisModule_StringPtrLen(state, &state_length); + int state_integer; + int scanned = sscanf(state_string, "%2d", &state_integer); + if (scanned != 1 || state_length != 2) { + return RedisModule_ReplyWithError(ctx, + "Found invalid scheduling state (must " + "be an integer of width 2"); + } + + RedisModule_ReplyWithArray(ctx, 3); + RedisModule_ReplyWithLongLong(ctx, state_integer); + RedisModule_ReplyWithString(ctx, node); + RedisModule_ReplyWithString(ctx, task_spec); + + RedisModule_FreeString(ctx, task_spec); + RedisModule_FreeString(ctx, node); + RedisModule_FreeString(ctx, state); + } else { + /* If the key does not exist, return nil. */ + RedisModule_ReplyWithNull(ctx); + } + + RedisModule_CloseKey(key); + + return REDISMODULE_OK; +} + +int TaskTableSubscribe_RedisCommand(RedisModuleCtx *ctx, + RedisModuleString **argv, + int argc) { + /* TODO(swang): Implement this. */ + REDISMODULE_NOT_USED(ctx); + REDISMODULE_NOT_USED(argv); + REDISMODULE_NOT_USED(argc); + return REDISMODULE_OK; +} + /* This function must be present on each Redis module. It is used in order to * register the commands into the Redis server. 
*/ int RedisModule_OnLoad(RedisModuleCtx *ctx, @@ -447,5 +670,29 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, return REDISMODULE_ERR; } + if (RedisModule_CreateCommand(ctx, "ray.task_table_add", + TaskTableAddTask_RedisCommand, "write pubsub", + 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + + if (RedisModule_CreateCommand(ctx, "ray.task_table_update", + TaskTableUpdate_RedisCommand, "write pubsub", 0, + 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + + if (RedisModule_CreateCommand(ctx, "ray.task_table_get", + TaskTableGetTask_RedisCommand, "readonly", 0, 0, + 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + + if (RedisModule_CreateCommand(ctx, "ray.task_table_subscribe", + TaskTableSubscribe_RedisCommand, "pubsub", 0, 0, + 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + return REDISMODULE_OK; } diff --git a/src/common/redis_module/runtest.py b/src/common/redis_module/runtest.py index 0ff9fdce360a..761051e877fb 100644 --- a/src/common/redis_module/runtest.py +++ b/src/common/redis_module/runtest.py @@ -18,6 +18,11 @@ module_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "ray_redis_module.so") print("path to the redis module is {}".format(module_path)) +OBJECT_INFO_PREFIX = "OI:" +OBJECT_LOCATION_PREFIX = "OL:" +OBJECT_SUBSCRIBE_PREFIX = "OS:" +TASK_PREFIX = "TT:" + class TestGlobalStateStore(unittest.TestCase): def setUp(self): @@ -72,7 +77,7 @@ def testObjectTableAddAndLookup(self): def testObjectTableSubscribe(self): p = self.redis.pubsub() # Subscribe to an object ID. - p.subscribe("object_id1") + p.psubscribe("{0}*".format(OBJECT_LOCATION_PREFIX)) self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1") # Receive the acknowledgement message. 
self.assertEqual(p.get_message()["data"], 1) @@ -92,5 +97,61 @@ def testResultTableAddAndLookup(self): response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2") self.assertEqual(response, b"task_id2") + def testInvalidTaskTableAdd(self): + # Check that Redis returns an error when RAY.TASK_TABLE_ADD is called with + # the wrong arguments. + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.TASK_TABLE_ADD") + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello") + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, "node_id") + with self.assertRaises(redis.ResponseError): + # Non-integer scheduling states should not be added. + self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", + "invalid_state", "node_id", "task_spec") + with self.assertRaises(redis.ResponseError): + # Scheduling states with invalid width should not be added. + self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 101, + "node_id", "task_spec") + with self.assertRaises(redis.ResponseError): + # Should not be able to update a non-existent task. + self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10, + "node_id") + + def testTaskTableAddAndLookup(self): + # Check that task table adds, updates, and lookups work correctly. + task_args = [1, "node_id", "task_spec"] + response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", + *task_args) + response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") + self.assertEqual(response, task_args) + + task_args[0] = 2 + self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2]) + response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") + self.assertEqual(response, task_args) + + def testTaskTableSubscribe(self): + scheduling_state = 1 + node_id = "node_id" + # Subscribe to the task table. 
+ p = self.redis.pubsub() + p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) + p.psubscribe("{prefix}*:{state: >2}".format(prefix=TASK_PREFIX, state=scheduling_state)) + p.psubscribe("{prefix}{node}:*".format(prefix=TASK_PREFIX, node=node_id)) + task_args = ["task_id", scheduling_state, node_id, "task_spec"] + self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) + # Receive the acknowledgement message. + self.assertEqual(p.get_message()["data"], 1) + self.assertEqual(p.get_message()["data"], 2) + self.assertEqual(p.get_message()["data"], 3) + # Receive the actual data. + for i in range(3): + message = p.get_message()["data"] + message = message.split() + message[1] = int(message[1]) + self.assertEqual(message, task_args) + if __name__ == "__main__": unittest.main(verbosity=2)