diff --git a/CMakeLists.txt b/CMakeLists.txt index a70da76f8f0b5..bc25da38d5e49 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,17 @@ enable_testing() include(ThirdpartyToolchain) + +# TODO(rkn): Fix all of this. This include is needed for the following +# reason. The local scheduler depends on tables.cc which depends on +# node_manager_generated.h which depends on gcs_generated.h. However, +# the include statement for gcs_generated.h doesn't include the file +# path, so we include the relevant directory here. +set(GCS_FBS_OUTPUT_DIRECTORY + "${CMAKE_CURRENT_LIST_DIR}/src/ray/gcs/format") +include_directories(${GCS_FBS_OUTPUT_DIRECTORY}) + + include_directories(SYSTEM ${ARROW_INCLUDE_DIR}) include_directories(SYSTEM ${PLASMA_INCLUDE_DIR}) include_directories("${CMAKE_CURRENT_LIST_DIR}/src/") diff --git a/doc/source/conf.py b/doc/source/conf.py index 68551d2fe1436..4c2488ce3010f 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -48,6 +48,7 @@ "ray.core.generated.GcsTableEntry", "ray.core.generated.HeartbeatTableData", "ray.core.generated.ErrorTableData", + "ray.core.generated.ProfileTableData", "ray.core.generated.ObjectTableData", "ray.core.generated.ray.protocol.Task", "ray.core.generated.TablePrefix", diff --git a/python/ray/__init__.py b/python/ray/__init__.py index b3a538e2ad884..f0a7fcec1b53d 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -48,7 +48,7 @@ from ray.local_scheduler import ObjectID, _config # noqa: E402 from ray.worker import (error_info, init, connect, disconnect, get, put, wait, - remote, log_event, log_span, flush_log, get_gpu_ids, + remote, profile, flush_profile_data, get_gpu_ids, get_resource_ids, get_webui_url, register_custom_serializer) # noqa: E402 from ray.worker import (SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, @@ -65,7 +65,7 @@ __all__ = [ "error_info", "init", "connect", "disconnect", "get", "put", "wait", - "remote", "log_event", "log_span", "flush_log", "actor", "method", + "remote", "profile", "flush_profile_data", "actor", "method", "get_gpu_ids", "get_resource_ids", "get_webui_url", "register_custom_serializer", "SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", "global_state", "ObjectID", "_config", "__version__" diff --git a/python/ray/actor.py b/python/ray/actor.py index b7499a95f3b33..9af6c6c36940d 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -14,6 +14,7 @@ import ray.signature as signature import ray.worker from ray.utils import ( + decode, _random_string, check_oversized_pickle, is_cython, @@ -292,10 +293,10 @@ def fetch_and_register_actor(actor_class_key, worker): "checkpoint_interval", "actor_method_names" ]) - class_name = class_name.decode("ascii") - module = module.decode("ascii") + class_name = decode(class_name) + module = decode(module) checkpoint_interval = int(checkpoint_interval) - actor_method_names = json.loads(actor_method_names.decode("ascii")) + actor_method_names = json.loads(decode(actor_method_names)) # Create a temporary actor with some temporary methods so that if the actor # fails to be unpickled, the temporary actor can be used (just to produce diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 42f1de938e2bf..cc10ddb5bc117 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -357,8 +357,6 @@ def _task_table(self, task_id): task_table_message = ray.gcs_utils.Task.GetRootAsTask( gcs_entries.Entries(i), 0) - task_table_message = ray.gcs_utils.Task.GetRootAsTask( - gcs_entries.Entries(0), 0) 
execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() task_spec = ray.local_scheduler.task_from_string(task_spec) @@ -487,11 +485,10 @@ def client_table(self): decode(value)) elif client_info[b"client_type"] == b"local_scheduler": # The remaining fields are resource types. - client_info_parsed[field.decode("ascii")] = float( + client_info_parsed[decode(field)] = float( decode(value)) else: - client_info_parsed[field.decode("ascii")] = decode( - value) + client_info_parsed[decode(field)] = decode(value) node_info[node_ip_address].append(client_info_parsed) @@ -513,21 +510,19 @@ def client_table(self): gcs_entry.Entries(i), 0)) resources = { - client.ResourcesTotalLabel(i).decode("ascii"): + decode(client.ResourcesTotalLabel(i)): client.ResourcesTotalCapacity(i) for i in range(client.ResourcesTotalLabelLength()) } node_info.append({ "ClientID": ray.utils.binary_to_hex(client.ClientId()), "IsInsertion": client.IsInsertion(), - "NodeManagerAddress": client.NodeManagerAddress().decode( - "ascii"), + "NodeManagerAddress": decode(client.NodeManagerAddress()), "NodeManagerPort": client.NodeManagerPort(), "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": client.ObjectStoreSocketName() - .decode("ascii"), - "RayletSocketName": client.RayletSocketName().decode( - "ascii"), + "ObjectStoreSocketName": decode( + client.ObjectStoreSocketName()), + "RayletSocketName": decode(client.RayletSocketName()), "Resources": resources }) return node_info @@ -543,14 +538,14 @@ def log_files(self): ip_filename_file = {} for filename in relevant_files: - filename = filename.decode("ascii") + filename = decode(filename) filename_components = filename.split(":") ip_addr = filename_components[1] file = self.redis_client.lrange(filename, 0, -1) file_str = [] for x in file: - y = x.decode("ascii") + y = decode(x) file_str.append(y) if ip_addr not in ip_filename_file: @@ -630,7 +625,7 @@ def task_profiles(self, num_tasks, start=None, end=None, fwd=True): event_log_set, **params) for (event, score) in event_list: - event_dict = json.loads(event.decode()) + event_dict = json.loads(decode(event)) task_id = "" for event in event_dict: if "task_id" in event[3]: @@ -643,31 +638,29 @@ def task_profiles(self, num_tasks, start=None, end=None, fwd=True): heap_size += 1 for event in event_dict: - if event[1] == "ray:get_task" and event[2] == 1: + if event[1] == "get_task" and event[2] == 1: task_info[task_id]["get_task_start"] = event[0] - if event[1] == "ray:get_task" and event[2] == 2: + if event[1] == "get_task" and event[2] == 2: task_info[task_id]["get_task_end"] = event[0] - if (event[1] == "ray:import_remote_function" + if (event[1] == "register_remote_function" and event[2] == 1): task_info[task_id]["import_remote_start"] = event[0] - if (event[1] == "ray:import_remote_function" + if (event[1] == "register_remote_function" and event[2] == 2): task_info[task_id]["import_remote_end"] = event[0] - if event[1] == "ray:acquire_lock" and event[2] == 1: - task_info[task_id]["acquire_lock_start"] = event[0] - if event[1] == "ray:acquire_lock" and event[2] == 2: - task_info[task_id]["acquire_lock_end"] = event[0] - if event[1] == "ray:task:get_arguments" and event[2] == 1: + if (event[1] == "task:deserialize_arguments" + and event[2] == 1): task_info[task_id]["get_arguments_start"] = event[0] - if event[1] == "ray:task:get_arguments" and event[2] == 2: + if (event[1] == "task:deserialize_arguments" + and event[2] == 2): 
task_info[task_id]["get_arguments_end"] = event[0] - if event[1] == "ray:task:execute" and event[2] == 1: + if event[1] == "task:execute" and event[2] == 1: task_info[task_id]["execute_start"] = event[0] - if event[1] == "ray:task:execute" and event[2] == 2: + if event[1] == "task:execute" and event[2] == 2: task_info[task_id]["execute_end"] = event[0] - if event[1] == "ray:task:store_outputs" and event[2] == 1: + if event[1] == "task:store_outputs" and event[2] == 1: task_info[task_id]["store_outputs_start"] = event[0] - if event[1] == "ray:task:store_outputs" and event[2] == 2: + if event[1] == "task:store_outputs" and event[2] == 2: task_info[task_id]["store_outputs_end"] = event[0] if "worker_id" in event[3]: task_info[task_id]["worker_id"] = event[3]["worker_id"] @@ -685,6 +678,173 @@ def task_profiles(self, num_tasks, start=None, end=None, fwd=True): return task_info + def _profile_table(self, component_id): + """Get the profile events for a given component. + + Args: + component_id: An identifier for a component. + + Returns: + A list of the profile events for the specified process. + """ + # TODO(rkn): This method should support limiting the number of log + # events and should also support returning a window of events. + message = self._execute_command(component_id, "RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.PROFILE, "", + component_id.id()) + + if message is None: + return [] + + gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + message, 0) + + profile_events = [] + for i in range(gcs_entries.EntriesLength()): + profile_table_message = ( + ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData( + gcs_entries.Entries(i), 0)) + + component_type = decode(profile_table_message.ComponentType()) + component_id = binary_to_hex(profile_table_message.ComponentId()) + node_ip_address = decode(profile_table_message.NodeIpAddress()) + + for j in range(profile_table_message.ProfileEventsLength()): + profile_event_message = profile_table_message.ProfileEvents(j) + + profile_event = { + "event_type": decode(profile_event_message.EventType()), + "component_id": component_id, + "node_ip_address": node_ip_address, + "component_type": component_type, + "start_time": profile_event_message.StartTime(), + "end_time": profile_event_message.EndTime(), + "extra_data": json.loads( + decode(profile_event_message.ExtraData())), + } + + profile_events.append(profile_event) + + return profile_events + + def profile_table(self): + if not self.use_raylet: + raise Exception("This method is only supported in the raylet " + "code path.") + + profile_table_keys = self._keys( + ray.gcs_utils.TablePrefix_PROFILE_string + "*") + component_identifiers_binary = [ + key[len(ray.gcs_utils.TablePrefix_PROFILE_string):] + for key in profile_table_keys + ] + + return { + binary_to_hex(component_id): self._profile_table( + binary_to_object_id(component_id)) + for component_id in component_identifiers_binary + } + + def chrome_tracing_dump(self, + include_task_data=False, + filename=None, + open_browser=False): + """Return a list of profiling events that can be viewed as a timeline. + + To view this information as a timeline, simply dump it as a json file + using json.dump, and then go to chrome://tracing in the Chrome + web browser and load the dumped file. Make sure to enable "Flow events" + in the "View Options" menu. + + Args: + include_task_data: If true, we will include more task metadata such + as the task specifications in the json.
+ filename: If a filename is provided, the timeline is dumped to that + file. + open_browser: If true, we will attempt to automatically open the + timeline visualization in Chrome. + + Returns: + If filename is not provided, this returns a list of profiling + events. Each profile event is a dictionary. + """ + # TODO(rkn): Support including the task specification data in the + # timeline. + # TODO(rkn): This should support viewing just a window of time or a + # limited number of events. + + if include_task_data: + raise NotImplementedError("This flag has not been implemented yet.") + + if open_browser: + raise NotImplementedError("This flag has not been implemented yet.") + + profile_table = self.profile_table() + all_events = [] + + # Colors are specified at + # https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501 + default_color_mapping = defaultdict( + lambda: "generic_work", { + "get_task": "cq_build_abandoned", + "task": "rail_response", + "task:deserialize_arguments": "rail_load", + "task:execute": "rail_animation", + "task:store_outputs": "rail_idle", + "wait_for_function": "detailed_memory_dump", + "ray.get": "good", + "ray.put": "terrible", + "ray.wait": "vsync_highlight_color", + "submit_task": "background_memory_dump", + "fetch_and_run_function": "detailed_memory_dump", + "register_remote_function": "detailed_memory_dump", + }) + + def seconds_to_microseconds(time_in_seconds): + time_in_microseconds = 10**6 * time_in_seconds + return time_in_microseconds + + for component_id_hex, component_events in profile_table.items(): + for event in component_events: + new_event = { + # The category of the event. + "cat": event["event_type"], + # The string displayed on the event. + "name": event["event_type"], + # The identifier for the group of rows that the event + # appears in. + "pid": event["node_ip_address"], + # The identifier for the row that the event appears in. + "tid": event["component_type"] + ":" + + event["component_id"], + # The start time in microseconds. + "ts": seconds_to_microseconds(event["start_time"]), + # The duration in microseconds. + "dur": seconds_to_microseconds(event["end_time"] - + event["start_time"]), + # The phase of the event. "X" denotes a complete event + # with a duration. + "ph": "X", + # This is the name of the color to display the box in. + "cname": default_color_mapping[event["event_type"]], + # The extra user-defined data. + "args": event["extra_data"], + } + + # Modify the json with the additional user-defined extra data. + # This can be used to add fields or override existing fields.
+ if "cname" in event["extra_data"]: + new_event["cname"] = event["extra_data"]["cname"] + if "name" in event["extra_data"]: + new_event["name"] = event["extra_data"]["name"] + + all_events.append(new_event) + + if filename is not None: + with open(filename, "w") as outfile: + json.dump(all_events, outfile) + else: + return all_events + def dump_catapult_trace(self, path, task_info, @@ -1047,21 +1207,20 @@ def workers(self): worker_id = binary_to_hex(worker_key[len("Workers:"):]) workers_data[worker_id] = { - "local_scheduler_socket": ( - worker_info[b"local_scheduler_socket"].decode("ascii")), - "node_ip_address": (worker_info[b"node_ip_address"] - .decode("ascii")), - "plasma_manager_socket": (worker_info[b"plasma_manager_socket"] - .decode("ascii")), - "plasma_store_socket": (worker_info[b"plasma_store_socket"] - .decode("ascii")) + "local_scheduler_socket": (decode( + worker_info[b"local_scheduler_socket"])), + "node_ip_address": decode(worker_info[b"node_ip_address"]), + "plasma_manager_socket": decode( + worker_info[b"plasma_manager_socket"]), + "plasma_store_socket": decode( + worker_info[b"plasma_store_socket"]) } if b"stderr_file" in worker_info: - workers_data[worker_id]["stderr_file"] = ( - worker_info[b"stderr_file"].decode("ascii")) + workers_data[worker_id]["stderr_file"] = decode( + worker_info[b"stderr_file"]) if b"stdout_file" in worker_info: - workers_data[worker_id]["stdout_file"] = ( - worker_info[b"stdout_file"].decode("ascii")) + workers_data[worker_id]["stdout_file"] = decode( + worker_info[b"stdout_file"]) return workers_data def actors(self): @@ -1155,8 +1314,8 @@ def _error_messages(self, job_id): error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( gcs_entries.Entries(i), 0) error_message = { - "type": error_data.Type().decode("ascii"), - "message": error_data.ErrorMessage().decode("ascii"), + "type": decode(error_data.Type()), + "message": decode(error_data.ErrorMessage()), "timestamp": error_data.Timestamp(), } error_messages.append(error_message) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 708093f212eb7..c9fa5e2c69d43 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -22,6 +22,7 @@ from ray.core.generated.GcsTableEntry import GcsTableEntry from ray.core.generated.ClientTableData import ClientTableData from ray.core.generated.ErrorTableData import ErrorTableData +from ray.core.generated.ProfileTableData import ProfileTableData from ray.core.generated.HeartbeatTableData import HeartbeatTableData from ray.core.generated.ObjectTableData import ObjectTableData from ray.core.generated.ray.protocol.Task import Task @@ -33,9 +34,9 @@ "SubscribeToNotificationsReply", "ResultTableReply", "TaskExecutionDependencies", "TaskReply", "DriverTableMessage", "LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo", - "GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData", - "ObjectTableData", "Task", "TablePrefix", "TablePubsub", - "construct_error_message" + "GcsTableEntry", "ClientTableData", "ErrorTableData", "ProfileTableData", + "HeartbeatTableData", "ObjectTableData", "Task", "TablePrefix", + "TablePubsub", "construct_error_message" ] # These prefixes must be kept up-to-date with the definitions in @@ -53,6 +54,7 @@ TablePrefix_RAYLET_TASK_string = "RAYLET_TASK" TablePrefix_OBJECT_string = "OBJECT" TablePrefix_ERROR_INFO_string = "ERROR_INFO" +TablePrefix_PROFILE_string = "PROFILE" def construct_error_message(error_type, message, timestamp): diff --git 
a/python/ray/log_monitor.py b/python/ray/log_monitor.py index 34ecfc68d30e4..fb577d9029c94 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -10,6 +10,7 @@ from ray.services import get_ip_address from ray.services import get_port from ray.services import logger +import ray.utils class LogMonitor(object): @@ -70,7 +71,7 @@ def check_log_files_and_push_updates(self): if len(new_lines) > 0: self.log_files[log_filename] += new_lines redis_key = "LOGFILE:{}:{}".format( - self.node_ip_address, log_filename.decode("ascii")) + self.node_ip_address, ray.utils.decode(log_filename)) self.redis_client.rpush(redis_key, *new_lines) # Pass if we already failed to open the log file. diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 530de154be2c5..7126f863491c4 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -10,6 +10,7 @@ import ray.services as services from ray.autoscaler.commands import (create_or_update_cluster, teardown_cluster, get_head_node_ip) +import ray.utils def check_no_existing_redis_clients(node_ip_address, redis_client): @@ -31,7 +32,7 @@ def check_no_existing_redis_clients(node_ip_address, redis_client): if deleted: continue - if info[b"node_ip_address"].decode("ascii") == node_ip_address: + if ray.utils.decode(info[b"node_ip_address"]) == node_ip_address: raise Exception("This Redis instance is already connected to " "clients with this IP address.") diff --git a/python/ray/services.py b/python/ray/services.py index 4d4624753e5f7..ccee26437d071 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -386,7 +386,7 @@ def check_version_info(redis_client): if redis_reply is None: return - true_version_info = tuple(json.loads(redis_reply.decode("ascii"))) + true_version_info = tuple(json.loads(ray.utils.decode(redis_reply))) version_info = _compute_version_info() if version_info != true_version_info: node_ip_address = ray.services.get_node_ip_address() @@ -776,7 +776,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): new_env["REDIS_ADDRESS"] = redis_address # We generate the token used for authentication ourselves to avoid # querying the jupyter server. - token = binascii.hexlify(os.urandom(24)).decode("ascii") + token = ray.utils.decode(binascii.hexlify(os.urandom(24))) command = [ "jupyter", "notebook", "--no-browser", "--port={}".format(port), "--NotebookApp.iopub_data_rate_limit=10000000000", @@ -1373,7 +1373,7 @@ def start_ray_processes(address_info=None, redis_client = redis.StrictRedis( host=redis_ip_address, port=redis_port) redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) - redis_shards = [shard.decode("ascii") for shard in redis_shards] + redis_shards = [ray.utils.decode(shard) for shard in redis_shards] address_info["redis_shards"] = redis_shards # Start the log monitor, if necessary. 
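Aside: the utils.py hunk that follows hardens ray.utils.decode, the helper that this diff substitutes for the scattered .decode("ascii") calls throughout the codebase. A minimal sketch of the pattern being standardized (the Redis payload below is a made-up example):

```python
import sys


def decode(byte_str):
    """Return a str on Python 3; return the bytes unchanged on Python 2."""
    if not isinstance(byte_str, bytes):
        raise ValueError("The argument must be a bytes object.")
    if sys.version_info >= (3, 0):
        return byte_str.decode("ascii")
    else:
        return byte_str


# Redis clients hand back bytes keys and values; decode() normalizes them in
# one place instead of calling .decode("ascii") at every call site.
info = {b"node_ip_address": b"127.0.0.1"}  # Hypothetical hgetall() result.
assert decode(info[b"node_ip_address"]) == "127.0.0.1"
```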
diff --git a/python/ray/utils.py b/python/ray/utils.py index db19f69bb5e72..82d475b94bff0 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -170,6 +170,8 @@ def random_string(): def decode(byte_str): """Make this unicode in Python 3, otherwise leave it as bytes.""" + if not isinstance(byte_str, bytes): + raise ValueError("The argument must be a bytes object.") if sys.version_info >= (3, 0): return byte_str.decode("ascii") else: diff --git a/python/ray/worker.py b/python/ray/worker.py index 515d65d160eeb..181b58147fd48 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -562,7 +562,7 @@ def submit_task(self, Returns: The return object IDs for this task. """ - with log_span("ray:submit_task", worker=self): + with profile("submit_task", worker=self): check_main_thread() if actor_id is None: assert actor_handle_id is None @@ -867,7 +867,7 @@ def _process_task(self, task): # Get task arguments from the object store. try: - with log_span("ray:task:get_arguments", worker=self): + with profile("task:deserialize_arguments", worker=self): arguments = self._get_arguments_for_execution( function_name, args) except (RayGetError, RayGetArgumentError) as e: @@ -882,7 +882,7 @@ def _process_task(self, task): # Execute the task. try: - with log_span("ray:task:execute", worker=self): + with profile("task:execute", worker=self): if task.actor_id().id() == NIL_ACTOR_ID: outputs = function_executor(*arguments) else: @@ -901,7 +901,7 @@ def _process_task(self, task): # Store the outputs in the local object store. try: - with log_span("ray:task:store_outputs", worker=self): + with profile("task:store_outputs", worker=self): # If this is an actor task, then the last object ID returned by # the task is a dummy output, not returned by the function # itself. Decrement to get the correct number of return values. @@ -976,7 +976,7 @@ def _wait_for_and_process_task(self, task): # Wait until the function to be executed has actually been registered # on this worker. We will push warnings to the user if we spend too # long in this loop. - with log_span("ray:wait_for_function", worker=self): + with profile("wait_for_function", worker=self): self._wait_for_function(function_id, driver_id) # Execute the task. @@ -984,22 +984,26 @@ def _wait_for_and_process_task(self, task): # warning to the user if we are waiting too long to acquire the lock # because that may indicate that the system is hanging, and it'd be # good to know where the system is hanging. - log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=self) with self.lock: - log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=self) function_name = (self.function_execution_info[driver_id][ function_id.id()]).function_name - contents = { - "function_name": function_name, - "task_id": task.task_id().hex(), - "worker_id": binary_to_hex(self.worker_id) - } - with log_span("ray:task", contents=contents, worker=self): + if not self.use_raylet: + extra_data = { + "function_name": function_name, + "task_id": task.task_id().hex(), + "worker_id": binary_to_hex(self.worker_id) + } + else: + extra_data = { + "name": function_name, + "task_id": task.task_id().hex() + } + with profile("task", extra_data=extra_data, worker=self): self._process_task(task) # Push all of the log events to the global state store. - flush_log() + flush_profile_data() # Increase the task execution counter. 
self.num_task_executions[driver_id][function_id.id()] += 1 @@ -1017,7 +1021,7 @@ def _get_next_task_from_local_scheduler(self): Returns: A task from the local scheduler. """ - with log_span("ray:get_task", worker=self): + with profile("get_task", worker=self): task = self.local_scheduler_client.get_task() # Automatically restrict the GPUs available to this task. @@ -1103,7 +1107,7 @@ def _webui_url_helper(client): The URL of the web UI as a string. """ result = client.hmget("webui", "url")[0] - return result.decode("ascii") if result is not None else result + return ray.utils.decode(result) if result is not None else result def get_webui_url(): @@ -1194,9 +1198,9 @@ def error_info(worker=global_worker): if error_applies_to_driver(error_key, worker=worker): error_contents = worker.redis_client.hgetall(error_key) error_contents = { - "type": error_contents[b"type"].decode("ascii"), - "message": error_contents[b"message"].decode("ascii"), - "data": error_contents[b"data"].decode("ascii") + "type": ray.utils.decode(error_contents[b"type"]), + "message": ray.utils.decode(error_contents[b"message"]), + "data": ray.utils.decode(error_contents[b"data"]) } errors.append(error_contents) @@ -1296,13 +1300,14 @@ def get_address_info_from_redis_helper(redis_address, assert b"ray_client_id" in info assert b"node_ip_address" in info assert b"client_type" in info - client_node_ip_address = info[b"node_ip_address"].decode("ascii") + client_node_ip_address = ray.utils.decode(info[b"node_ip_address"]) if (client_node_ip_address == node_ip_address or (client_node_ip_address == "127.0.0.1" and redis_ip_address == ray.services.get_node_ip_address())): - if info[b"client_type"].decode("ascii") == "plasma_manager": + if ray.utils.decode(info[b"client_type"]) == "plasma_manager": plasma_managers.append(info) - elif info[b"client_type"].decode("ascii") == "local_scheduler": + elif (ray.utils.decode( + info[b"client_type"]) == "local_scheduler"): local_schedulers.append(info) # Make sure that we got at least one plasma manager and local # scheduler. @@ -1311,16 +1316,16 @@ def get_address_info_from_redis_helper(redis_address, # Build the address information. 
object_store_addresses = [] for manager in plasma_managers: - address = manager[b"manager_address"].decode("ascii") + address = ray.utils.decode(manager[b"manager_address"]) port = services.get_port(address) object_store_addresses.append( services.ObjectStoreAddress( - name=manager[b"store_socket_name"].decode("ascii"), - manager_name=manager[b"manager_socket_name"].decode( - "ascii"), + name=ray.utils.decode(manager[b"store_socket_name"]), + manager_name=ray.utils.decode( + manager[b"manager_socket_name"]), manager_port=port)) scheduler_names = [ - scheduler[b"local_scheduler_socket_name"].decode("ascii") + ray.utils.decode(scheduler[b"local_scheduler_socket_name"]) for scheduler in local_schedulers ] client_info = { @@ -1343,8 +1348,8 @@ def get_address_info_from_redis_helper(redis_address, for client_message in clients: client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData( client_message, 0) - client_node_ip_address = client.NodeManagerAddress().decode( - "ascii") + client_node_ip_address = ray.utils.decode( + client.NodeManagerAddress()) if (client_node_ip_address == node_ip_address or (client_node_ip_address == "127.0.0.1" and redis_ip_address == ray.services.get_node_ip_address())): @@ -1352,12 +1357,12 @@ def get_address_info_from_redis_helper(redis_address, object_store_addresses = [ services.ObjectStoreAddress( - name=raylet.ObjectStoreSocketName().decode("ascii"), + name=ray.utils.decode(raylet.ObjectStoreSocketName()), manager_name=None, manager_port=None) for raylet in raylets ] raylet_socket_names = [ - raylet.RayletSocketName().decode("ascii") for raylet in raylets + ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets ] return { "node_ip_address": node_ip_address, @@ -1807,6 +1812,21 @@ def custom_excepthook(type, value, tb): sys.excepthook = custom_excepthook +def _flush_profile_events(worker): + """Drivers run this as a thread to flush profile data in the background.""" + # Note(rkn): This is run on a background thread in the driver. It uses the + # local scheduler client. This should be ok because it doesn't read from + # the local scheduler client and we have the GIL here. However, if either + # of those things changes, then we could run into issues. + try: + while True: + time.sleep(1) + flush_profile_data(worker=worker) + except AttributeError: + # This is to suppress errors that occur at shutdown. + pass + + def print_error_messages_raylet(worker): """Print error messages in the background on the driver. 
@@ -1858,7 +1878,7 @@ def print_error_messages_raylet(worker): if job_id not in [worker.task_driver_id.id(), NIL_JOB_ID]: continue - error_message = error_data.ErrorMessage().decode("ascii") + error_message = ray.utils.decode(error_data.ErrorMessage()) if error_message not in old_error_messages: logger.error(error_message) @@ -1900,8 +1920,8 @@ def print_error_messages(worker): error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) for error_key in error_keys: if error_applies_to_driver(error_key, worker=worker): - error_message = worker.redis_client.hget( - error_key, "message").decode("ascii") + error_message = ray.utils.decode( + worker.redis_client.hget(error_key, "message")) if error_message not in old_error_messages: logger.error(error_message) old_error_messages.add(error_message) @@ -1915,8 +1935,8 @@ def print_error_messages(worker): for error_key in worker.redis_client.lrange( "ErrorKeys", num_errors_received, -1): if error_applies_to_driver(error_key, worker=worker): - error_message = worker.redis_client.hget( - error_key, "message").decode("ascii") + error_message = ray.utils.decode( + worker.redis_client.hget(error_key, "message")) if error_message not in old_error_messages: logger.error(error_message) old_error_messages.add(error_message) @@ -1939,9 +1959,9 @@ def fetch_and_register_remote_function(key, worker=global_worker): "module", "resources", "max_calls" ]) function_id = ray.ObjectID(function_id_str) - function_name = function_name.decode("ascii") + function_name = ray.utils.decode(function_name) max_calls = int(max_calls) - module = module.decode("ascii") + module = ray.utils.decode(module) # This is a placeholder in case the function can't be unpickled. This will # be overwritten if the function is successfully registered. @@ -2031,15 +2051,18 @@ def import_thread(worker, mode): # Handle the driver case first. if mode != WORKER_MODE: if key.startswith(b"FunctionsToRun"): - fetch_and_execute_function_to_run(key, worker=worker) + with profile("fetch_and_run_function", worker=worker): + fetch_and_execute_function_to_run(key, worker=worker) # Continue because FunctionsToRun are the only things that the # driver should import. continue if key.startswith(b"RemoteFunction"): - fetch_and_register_remote_function(key, worker=worker) + with profile("register_remote_function", worker=worker): + fetch_and_register_remote_function(key, worker=worker) elif key.startswith(b"FunctionsToRun"): - fetch_and_execute_function_to_run(key, worker=worker) + with profile("fetch_and_run_function", worker=worker): + fetch_and_execute_function_to_run(key, worker=worker) elif key.startswith(b"ActorClass"): # Keep track of the fact that this actor class has been # exported so that we know it is safe to turn this worker into @@ -2063,9 +2086,8 @@ def import_thread(worker, mode): # Handle the driver case first. 
if mode != WORKER_MODE: if key.startswith(b"FunctionsToRun"): - with log_span( - "ray:import_function_to_run", - worker=worker): + with profile( + "fetch_and_run_function", worker=worker): fetch_and_execute_function_to_run( key, worker=worker) # Continue because FunctionsToRun are the only things @@ -2073,13 +2095,12 @@ continue if key.startswith(b"RemoteFunction"): - with log_span( - "ray:import_remote_function", worker=worker): + with profile( + "register_remote_function", worker=worker): fetch_and_register_remote_function( key, worker=worker) elif key.startswith(b"FunctionsToRun"): - with log_span( - "ray:import_function_to_run", worker=worker): + with profile("fetch_and_run_function", worker=worker): fetch_and_execute_function_to_run( key, worker=worker) elif key.startswith(b"ActorClass"): @@ -2333,6 +2354,13 @@ def connect(info, t.daemon = True t.start() + if mode in [SCRIPT_MODE, SILENT_MODE] and worker.use_raylet: + t = threading.Thread(target=_flush_profile_events, args=(worker, )) + # Making the thread a daemon causes it to exit when the main thread + # exits. + t.daemon = True + t.start() + if mode in [SCRIPT_MODE, SILENT_MODE]: # Add the directory containing the script that is running to the Python # paths of the workers. Also add the current directory. Note that this @@ -2526,7 +2554,8 @@ def __init__(self, event_type, contents=None, worker=global_worker): def __enter__(self): """Log the beginning of a span event.""" - log(event_type=self.event_type, + _log( + event_type=self.event_type, contents=self.contents, kind=LOG_SPAN_START, worker=self.worker) @@ -2534,11 +2563,13 @@ def __exit__(self, type, value, tb): """Log the end of a span event. Log any exception that occurred.""" if type is None: - log(event_type=self.event_type, + _log( + event_type=self.event_type, kind=LOG_SPAN_END, worker=self.worker) else: - log(event_type=self.event_type, + _log( + event_type=self.event_type, contents={ "type": str(type), "value": value, @@ -2548,19 +2579,109 @@ worker=self.worker) -def log_span(event_type, contents=None, worker=global_worker): - return RayLogSpan(event_type, contents=contents, worker=worker) +class RayLogSpanRaylet(object): + """An object used to enable logging a span of events with a with statement. + Attributes: + event_type (str): The type of the event being logged. + extra_data: Additional information to log. + """ -def log_event(event_type, contents=None, worker=global_worker): - log(event_type, kind=LOG_POINT, contents=contents, worker=worker) + def __init__(self, event_type, extra_data=None, worker=global_worker): + """Initialize a RayLogSpanRaylet object.""" + self.event_type = event_type + self.extra_data = extra_data if extra_data is not None else {} + self.worker = worker + def set_attribute(self, key, value): + """Add a key-value pair to the extra_data dict. -def log(event_type, kind, contents=None, worker=global_worker): + This can be used to add attributes that are not available when + ray.profile was called. + + Args: + key: The attribute name. + value: The attribute value. + """ + if not isinstance(key, str) or not isinstance(value, str): + raise ValueError("The extra_data argument must be a " + "dictionary mapping strings to strings.") + self.extra_data[key] = value + + def __enter__(self): + """Log the beginning of a span event. + + Returns: + The object itself is returned so that if the block is opened using + "with ray.profile(...)
as prof:", we can call + "prof.set_attribute" inside the block. + """ + self.start_time = time.time() + return self + + def __exit__(self, type, value, tb): + """Log the end of a span event. Log any exception that occurred.""" + for key, value in self.extra_data.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise ValueError("The extra_data argument must be a " + "dictionary mapping strings to strings.") + + event = { + "event_type": self.event_type, + "start_time": self.start_time, + "end_time": time.time(), + "extra_data": json.dumps(self.extra_data), + } + + if type is not None: + event["extra_data"] = json.dumps({ + "type": str(type), + "value": str(value), + "traceback": str(traceback.format_exc()), + }) + + self.worker.events.append(event) + + +def profile(event_type, extra_data=None, worker=global_worker): + """Profile a span of time so that it appears in the timeline visualization. + + This function can be used as follows (both on the driver or within a task). + + with ray.profile("custom event", extra_data={'key': 'value'}): + # Do some computation here. + + Optionally, a dictionary can be passed as the "extra_data" argument, and + it can have keys "name" and "cname" if you want to override the default + timeline display text and box color. Other values will appear at the bottom + of the chrome tracing GUI when you click on the box corresponding to this + profile span. + + Args: + event_type: A string describing the type of the event. + extra_data: This must be a dictionary mapping strings to strings. This + data will be added to the json objects that are used to populate + the timeline, so if you want to set a particular color, you can + simply set the "cname" attribute to an appropriate color. + Similarly, if you set the "name" attribute, then that will set the + text displayed on the box in the timeline. + + Returns: + An object that can profile a span of time via a "with" statement. + """ + if not worker.use_raylet: + return RayLogSpan(event_type, contents=extra_data, worker=worker) + else: + return RayLogSpanRaylet( + event_type, extra_data=extra_data, worker=worker) + + +def _log(event_type, kind, contents=None, worker=global_worker): """Log an event to the global state store. This adds the event to a buffer of events locally. The buffer can be - flushed and written to the global state store by calling flush_log(). + flushed and written to the global state store by calling + flush_profile_data(). Args: event_type (str): The type of the event. @@ -2571,6 +2692,9 @@ def log(event_type, kind, contents=None, worker=global_worker): time, and it is LOG_SPAN_END if we are finishing logging a span of time. """ + if worker.use_raylet: + raise Exception( + "This method is not supported in the raylet code path.") # TODO(rkn): This code currently takes around half a microsecond. Since we # call it tens of times per task, this adds up. We will need to redo the # logging code, perhaps in C. @@ -2584,13 +2708,32 @@ def log(event_type, kind, contents=None, worker=global_worker): worker.events.append((time.time(), event_type, kind, contents)) -def flush_log(worker=global_worker): - """Send the logged worker events to the global state store.""" - event_log_key = b"event_log:" + worker.worker_id - event_log_value = json.dumps(worker.events) +# TODO(rkn): Support calling this function in the middle of a task, and also +# call this periodically in the background from the driver. 
+def flush_profile_data(worker=global_worker): + """Push the logged profiling data to the global control store. + + By default, profiling information for a given task won't appear in the + timeline until after the task has completed. For very long-running tasks, + we may want profiling information to appear more quickly. In such cases, + this function can be called. Note that as an alternative, we could start + a thread in the background on workers that calls this automatically. + """ if not worker.use_raylet: + event_log_key = b"event_log:" + worker.worker_id + event_log_value = json.dumps(worker.events) worker.local_scheduler_client.log_event(event_log_key, event_log_value, time.time()) + else: + if worker.mode == WORKER_MODE: + component_type = "worker" + else: + component_type = "driver" + + worker.local_scheduler_client.push_profile_events( + component_type, ray.ObjectID(worker.worker_id), + worker.node_ip_address, worker.events) + worker.events = [] @@ -2611,7 +2754,7 @@ def get(object_ids, worker=global_worker): A Python object or a list of Python objects. """ worker.check_connected() - with log_span("ray:get", worker=worker): + with profile("ray.get", worker=worker): check_main_thread() if worker.mode == PYTHON_MODE: @@ -2644,7 +2787,7 @@ def put(value, worker=global_worker): The object ID assigned to this value. """ worker.check_connected() - with log_span("ray:put", worker=worker): + with profile("ray.put", worker=worker): check_main_thread() if worker.mode == PYTHON_MODE: @@ -2702,7 +2845,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): type(object_id))) worker.check_connected() - with log_span("ray:wait", worker=worker): + with profile("ray.wait", worker=worker): check_main_thread() # When Ray is run in PYTHON_MODE, all functions are run immediately, diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc index 22d7877ba7d6c..31178160d2023 100644 --- a/src/common/lib/python/common_extension.cc +++ b/src/common/lib/python/common_extension.cc @@ -165,7 +165,11 @@ static PyObject *PyObjectID_id(PyObject *self) { static PyObject *PyObjectID_hex(PyObject *self) { PyObjectID *s = (PyObjectID *) self; std::string hex_id = s->object_id.hex(); - PyObject *result = PyUnicode_FromString(hex_id.c_str()); +#if PY_MAJOR_VERSION >= 3 + PyObject *result = PyUnicode_FromStringAndSize(hex_id.data(), hex_id.size()); +#else + PyObject *result = PyBytes_FromStringAndSize(hex_id.data(), hex_id.size()); +#endif return result; } diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc index 28b9caf255971..e7edb1c0a1e74 100644 --- a/src/common/redis_module/ray_redis_module.cc +++ b/src/common/redis_module/ray_redis_module.cc @@ -695,6 +695,8 @@ int TableAppend_DoWrite(RedisModuleCtx *ctx, // Check that we actually add a new entry during the append. This is only // necessary since we implement the log with a sorted set, so all entries // must be unique, or else we will have gaps in the log. + // TODO(rkn): We need to get rid of this uniqueness requirement. We can + // easily have multiple log events with the same message. 
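Backing up to the Python side for a moment: in the raylet branch of flush_profile_data above, each entry of worker.events is a plain dict assembled by RayLogSpanRaylet.__exit__. A sketch of one such event, with hypothetical timestamps and extra data:

```python
import json
import time

start_time = time.time()
# ... the profiled block runs here ...
event = {
    "event_type": "task:execute",
    "start_time": start_time,
    "end_time": time.time(),
    # extra_data is serialized to JSON before the event is buffered.
    "extra_data": json.dumps({"name": "f"}),
}
```

Batches of these dicts are what push_profile_events, defined in the extension module below, converts into a ProfileTableData flatbuffer.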
RAY_CHECK(flags == REDISMODULE_ZADD_ADDED) << "Appended a duplicate entry"; return REDISMODULE_OK; } else { diff --git a/src/local_scheduler/lib/python/local_scheduler_extension.cc b/src/local_scheduler/lib/python/local_scheduler_extension.cc index bd9dc6f0fc094..c75994d6556ab 100644 --- a/src/local_scheduler/lib/python/local_scheduler_extension.cc +++ b/src/local_scheduler/lib/python/local_scheduler_extension.cc @@ -309,6 +309,109 @@ static PyObject *PyLocalSchedulerClient_push_error(PyObject *self, Py_RETURN_NONE; } +int PyBytes_or_PyUnicode_to_string(PyObject *py_string, std::string &out) { + // Handle the case where the key is a bytes object and the case where it + // is a unicode object. + if (PyUnicode_Check(py_string)) { + PyObject *ascii_string = PyUnicode_AsASCIIString(py_string); + out = + std::string(PyBytes_AsString(ascii_string), PyBytes_Size(ascii_string)); + Py_DECREF(ascii_string); + } else if (PyBytes_Check(py_string)) { + out = std::string(PyBytes_AsString(py_string), PyBytes_Size(py_string)); + } else { + return -1; + } + + return 0; +} + +static PyObject *PyLocalSchedulerClient_push_profile_events(PyObject *self, + PyObject *args) { + const char *component_type; + int component_type_length; + UniqueID component_id; + PyObject *profile_data; + const char *node_ip_address; + int node_ip_address_length; + + if (!PyArg_ParseTuple(args, "s#O&s#O", &component_type, + &component_type_length, &PyObjectToUniqueID, + &component_id, &node_ip_address, + &node_ip_address_length, &profile_data)) { + return NULL; + } + + ProfileTableDataT profile_info; + profile_info.component_type = + std::string(component_type, component_type_length); + profile_info.component_id = component_id.binary(); + profile_info.node_ip_address = + std::string(node_ip_address, node_ip_address_length); + + if (PyList_Size(profile_data) == 0) { + // Short circuit if there are no profile events. + Py_RETURN_NONE; + } + + for (int64_t i = 0; i < PyList_Size(profile_data); ++i) { + ProfileEventT profile_event; + PyObject *py_profile_event = PyList_GetItem(profile_data, i); + + if (!PyDict_CheckExact(py_profile_event)) { + return NULL; + } + + PyObject *key, *val; + Py_ssize_t pos = 0; + while (PyDict_Next(py_profile_event, &pos, &key, &val)) { + std::string key_string; + if (PyBytes_or_PyUnicode_to_string(key, key_string) == -1) { + return NULL; + } + + // TODO(rkn): If the dictionary is formatted incorrectly, that could lead + // to errors. E.g., if any of the strings are empty, that will cause + // segfaults in the node manager. + + if (key_string == std::string("event_type")) { + if (PyBytes_or_PyUnicode_to_string(val, profile_event.event_type) == + -1) { + return NULL; + } + if (profile_event.event_type.size() == 0) { + return NULL; + } + } else if (key_string == std::string("start_time")) { + profile_event.start_time = PyFloat_AsDouble(val); + } else if (key_string == std::string("end_time")) { + profile_event.end_time = PyFloat_AsDouble(val); + } else if (key_string == std::string("extra_data")) { + if (PyBytes_or_PyUnicode_to_string(val, profile_event.extra_data) == + -1) { + return NULL; + } + if (profile_event.extra_data.size() == 0) { + return NULL; + } + } else { + return NULL; + } + } + + // Note that profile_info.profile_events is a vector of unique pointers, so + // profile_event will be deallocated when profile_info goes out of scope. 
+ profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); + } + + local_scheduler_push_profile_events( + reinterpret_cast<PyLocalSchedulerClient *>(self) + ->local_scheduler_connection, + profile_info); + + Py_RETURN_NONE; +} + static PyMethodDef PyLocalSchedulerClient_methods[] = { {"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS, "Notify the local scheduler that this client is exiting gracefully."}, @@ -338,6 +441,9 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = { "Wait for a list of objects to be created."}, {"push_error", (PyCFunction) PyLocalSchedulerClient_push_error, METH_VARARGS, "Push an error message to the relevant driver."}, + {"push_profile_events", + (PyCFunction) PyLocalSchedulerClient_push_profile_events, METH_VARARGS, + "Store some profiling events in the GCS."}, {NULL} /* Sentinel */ }; diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc index f8ac3026bd7d9..a9b1113f80aae 100644 --- a/src/local_scheduler/local_scheduler_client.cc +++ b/src/local_scheduler/local_scheduler_client.cc @@ -322,3 +322,17 @@ void local_scheduler_push_error(LocalSchedulerConnection *conn, ray::protocol::MessageType::PushErrorRequest), fbb.GetSize(), fbb.GetBufferPointer()); } + +void local_scheduler_push_profile_events( + LocalSchedulerConnection *conn, + const ProfileTableDataT &profile_events) { + flatbuffers::FlatBufferBuilder fbb; + + auto message = CreateProfileTableData(fbb, &profile_events); + fbb.Finish(message); + + write_message(conn->conn, + static_cast<int64_t>( + ray::protocol::MessageType::PushProfileEventsRequest), + fbb.GetSize(), fbb.GetBufferPointer()); +} diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h index 95a3de0c073c7..e00342ce189b1 100644 --- a/src/local_scheduler/local_scheduler_client.h +++ b/src/local_scheduler/local_scheduler_client.h @@ -225,4 +225,13 @@ void local_scheduler_push_error(LocalSchedulerConnection *conn, const std::string &error_message, double timestamp); +/// Store some profile events in the GCS. +/// +/// \param conn The connection information. +/// \param profile_events A batch of profiling event information. +/// \return Void.
+void local_scheduler_push_profile_events( + LocalSchedulerConnection *conn, + const ProfileTableDataT &profile_events); + #endif diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 1700eac5f1e6e..ecae1f1a0f37d 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -17,6 +17,7 @@ AsyncGcsClient::AsyncGcsClient(const ClientID &client_id, CommandType command_ty task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this)); heartbeat_table_.reset(new HeartbeatTable(context_, this)); error_table_.reset(new ErrorTable(primary_context_, this)); + profile_table_.reset(new ProfileTable(context_, this)); command_type_ = command_type; } @@ -84,6 +85,8 @@ HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } ErrorTable &AsyncGcsClient::error_table() { return *error_table_; } +ProfileTable &AsyncGcsClient::profile_table() { return *profile_table_; } + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 4f249fca95c85..852db93c5cb7b 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -59,6 +59,7 @@ class RAY_EXPORT AsyncGcsClient { ClientTable &client_table(); HeartbeatTable &heartbeat_table(); ErrorTable &error_table(); + ProfileTable &profile_table(); // We also need something to export generic code to run on workers from the // driver (to set the PYTHONPATH) @@ -81,6 +82,7 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr<TaskReconstructionLog> task_reconstruction_log_; std::unique_ptr<HeartbeatTable> heartbeat_table_; std::unique_ptr<ErrorTable> error_table_; + std::unique_ptr<ProfileTable> profile_table_; std::unique_ptr<ClientTable> client_table_; // The following contexts write to the data shard std::shared_ptr<RedisContext> context_; diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 8f343437bbfa0..115a03349c6f0 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -15,6 +15,7 @@ enum TablePrefix:int { TASK_RECONSTRUCTION, HEARTBEAT, ERROR_INFO, + PROFILE, } // The channel that Add operations to the Table should be published on, if any. @@ -121,6 +122,33 @@ table CustomSerializerData { table ConfigTableData { } +table ProfileEvent { + // The type of the event. + event_type: string; + // The start time of the event. + start_time: double; + // The end time of the event. If the event is a point event, then this should + // be the same as the start time. + end_time: double; + // Additional data associated with the event. This data must be serialized + // using JSON. + extra_data: string; +} + +table ProfileTableData { + // The type of the component that generated the event, e.g., worker, + // object_manager, or node_manager. + component_type: string; + // An identifier for the component that generated the event. + component_id: string; + // An identifier for the node that generated the event. + node_ip_address: string; + // This is a batch of profiling events. We batch these together for + // performance reasons because a single task may generate many events, and + // we don't want each event to require a GCS command. + profile_events: [ProfileEvent]; +} + table RayResource { // The type of the resource.
resource_name: string; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index a499a30aa205e..d5ed2869946ce 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -219,6 +219,45 @@ Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &typ }); } +Status ProfileTable::AddProfileEvent(const std::string &event_type, + const std::string &component_type, + const UniqueID &component_id, + const std::string &node_ip_address, + double start_time, double end_time, + const std::string &extra_data) { + auto data = std::make_shared<ProfileTableDataT>(); + + ProfileEventT profile_event; + profile_event.event_type = event_type; + profile_event.start_time = start_time; + profile_event.end_time = end_time; + profile_event.extra_data = extra_data; + + data->component_type = component_type; + data->component_id = component_id.binary(); + data->node_ip_address = node_ip_address; + data->profile_events.emplace_back(new ProfileEventT(profile_event)); + + return Append(JobID::nil(), component_id, data, + [](ray::gcs::AsyncGcsClient *client, const JobID &id, + const ProfileTableDataT &data) { + RAY_LOG(DEBUG) << "Profile message pushed callback"; + }); +} + +Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { + auto data = std::make_shared<ProfileTableDataT>(); + // There is some room for optimization here because the Append function will just + // call "Pack" and undo the "UnPack". + profile_events.UnPackTo(data.get()); + + return Append(JobID::nil(), from_flatbuf(*profile_events.component_id()), data, + [](ray::gcs::AsyncGcsClient *client, const JobID &id, + const ProfileTableDataT &data) { + RAY_LOG(DEBUG) << "Profile message pushed callback"; + }); +} + void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) { client_added_callback_ = callback; // Call the callback for any added clients that are cached. @@ -371,6 +410,7 @@ template class Log; template class Table; template class Log; template class Log; +template class Log<UniqueID, ProfileTableData>; } // namespace gcs diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index e2e719de0bc74..b10efe40b4d27 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -12,6 +12,7 @@ #include "ray/gcs/format/gcs_generated.h" #include "ray/gcs/redis_context.h" +// TODO(rkn): Remove this include. #include "ray/raylet/format/node_manager_generated.h" // TODO(pcm): Remove this @@ -95,7 +96,8 @@ class Log : virtual public PubsubInterface<ID> { /// /// \param job_id The ID of the job (= driver). /// \param id The ID of the data that is added to the GCS. - /// \param data Data to append to the log. + /// \param data Data to append to the log. TODO(rkn): This can be made const, + /// right? /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status @@ -438,7 +440,8 @@ class ErrorTable : private Log<JobID, ErrorTableData> { /// Push an error message for a specific job. /// /// TODO(rkn): We need to make sure that the errors are unique because - /// duplicate messages currently cause failures (the GCS doesn't allow it). + /// duplicate messages currently cause failures (the GCS doesn't allow it). A + /// natural way to do this is to have finer-grained time stamps. /// /// \param job_id The ID of the job that generated the error. If the error /// should be pushed to all jobs, then this should be nil.
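Before the ProfileTable declaration in the next hunk, it may help to see the consumer's view: the Python accessors added earlier in this diff read this table back. A hedged sketch, assuming the raylet code path (RAY_USE_XRAY=1) and a running cluster:

```python
import ray

ray.init()

# Keys are hex component IDs (workers, drivers, and so on); values are
# lists of event dicts with event_type, start_time, end_time, extra_data.
for component_id, events in ray.global_state.profile_table().items():
    for event in events:
        print(component_id, event["event_type"],
              event["end_time"] - event["start_time"])
```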
@@ -450,6 +453,37 @@ class ErrorTable : private Log<JobID, ErrorTableData> { const std::string &error_message, double timestamp); }; +class ProfileTable : private Log<UniqueID, ProfileTableData> { + public: + ProfileTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client) + : Log(context, client) { + prefix_ = TablePrefix::PROFILE; + }; + + /// Add a single profile event to the profile table. + /// + /// \param event_type The type of the event. + /// \param component_type The type of the component that the event came from. + /// \param component_id An identifier for the component that generated the event. + /// \param node_ip_address The IP address of the node that generated the event. + /// \param start_time The timestamp of the event start. This should be in seconds since + /// the Unix epoch. + /// \param end_time The timestamp of the event end. This should be in seconds since + /// the Unix epoch. If the event is a point event, this should be equal to start_time. + /// \param extra_data Additional data to associate with the event. + /// \return Status. + Status AddProfileEvent(const std::string &event_type, const std::string &component_type, + const UniqueID &component_id, const std::string &node_ip_address, + double start_time, double end_time, + const std::string &extra_data); + + /// Add a batch of profiling events to the profile table. + /// + /// \param profile_events The profile events to record. + /// \return Status. + Status AddProfileEventBatch(const ProfileTableData &profile_events); +}; + using CustomSerializerTable = Table<ClassID, CustomSerializerData>; using ConfigTable = Table<ConfigID, ConfigTableData>; diff --git a/src/ray/raylet/CMakeLists.txt b/src/ray/raylet/CMakeLists.txt index 2185fd600cf93..735ed417419a1 100644 --- a/src/ray/raylet/CMakeLists.txt +++ b/src/ray/raylet/CMakeLists.txt @@ -12,7 +12,7 @@ add_custom_command( # flatbuffers message Message, which can be used to store deserialized # messages in data structures. This is currently used for ObjectInfo for # example.
- COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${NODE_MANAGER_FBS_SRC} --cpp --gen-object-api --gen-mutable --scoped-enums + COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} -I ${GCS_FBS_OUTPUT_DIRECTORY} ${NODE_MANAGER_FBS_SRC} --cpp --gen-object-api --gen-mutable --scoped-enums DEPENDS ${FBS_DEPENDS} COMMENT "Running flatc compiler on ${NODE_MANAGER_FBS_SRC}" VERBATIM) @@ -23,7 +23,7 @@ add_custom_target(gen_node_manager_fbs DEPENDS ${NODE_MANAGER_FBS_OUTPUT_FILES}) set(PYTHON_OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/../../../python/ray/core/generated/) add_custom_command( TARGET gen_node_manager_fbs - COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} ${NODE_MANAGER_FBS_SRC} + COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} -I ${GCS_FBS_OUTPUT_DIRECTORY} ${NODE_MANAGER_FBS_SRC} DEPENDS ${FBS_DEPENDS} COMMENT "Running flatc compiler on ${NODE_MANAGER_FBS_SRC}" VERBATIM) @@ -38,6 +38,7 @@ ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main p ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(task_dependency_manager_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) +include_directories(${GCS_FBS_OUTPUT_DIRECTORY}) add_library(rayletlib raylet.cc ${NODE_MANAGER_FBS_OUTPUT_FILES}) target_link_libraries(rayletlib ray_static ${Boost_SYSTEM_LIBRARY}) diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 29c635ad5b4bd..d66d43b5fd6c0 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -1,5 +1,8 @@ // Local scheduler protocol specification +include "gcs.fbs"; + + // TODO(swang): We put the flatbuffer types in a separate namespace for now to // avoid conflicts with legacy Ray types. namespace ray.protocol; @@ -62,6 +65,9 @@ enum MessageType:int { // Push an error to the relevant driver. This is sent from a worker to the // node manager. PushErrorRequest, + // Push some profiling events to the GCS. When sending this message to the + // node manager, the message itself is serialized as a ProfileTableData object. + PushProfileEventsRequest, } table TaskExecutionSpecification { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index ace79e71fa4af..d82b9553a7e31 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -552,6 +552,11 @@ void NodeManager::ProcessClientMessage( RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message, timestamp)); } break; + case protocol::MessageType::PushProfileEventsRequest: { + auto message = flatbuffers::GetRoot<ProfileTableData>(message_data); + + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message)); + } break; default: RAY_LOG(FATAL) << "Received unexpected message type " << message_type; diff --git a/test/jenkins_tests/multi_node_docker_test.py b/test/jenkins_tests/multi_node_docker_test.py index b757d942d048c..bb2a7d6d5c3d6 100644 --- a/test/jenkins_tests/multi_node_docker_test.py +++ b/test/jenkins_tests/multi_node_docker_test.py @@ -11,6 +11,18 @@ import sys +# This is duplicated from ray.utils so that we do not have to introduce a +# dependency on Ray to run this file.
+def decode(byte_str): + """Make this unicode in Python 3, otherwise leave it as bytes.""" + if not isinstance(byte_str, bytes): + raise ValueError("The argument must be a bytes object.") + if sys.version_info >= (3, 0): + return byte_str.decode("ascii") + else: + return byte_str + + def wait_for_output(proc): """This is a convenience method to parse a process's stdout and stderr. @@ -27,7 +39,7 @@ def wait_for_output(proc): # NOTE(rkn): This try/except block is here because I once saw an # exception raised here and want to print more information if that # happens again. - stdout_data = stdout_data.decode("ascii") + stdout_data = decode(stdout_data) except UnicodeDecodeError: raise Exception("Failed to decode stdout_data:", stdout_data) @@ -36,7 +48,7 @@ # NOTE(rkn): This try/except block is here because I once saw an # exception raised here and want to print more information if that # happens again. - stderr_data = stderr_data.decode("ascii") + stderr_data = decode(stderr_data) except UnicodeDecodeError: raise Exception("Failed to decode stderr_data:", stderr_data) diff --git a/test/multi_node_test.py b/test/multi_node_test.py index d31fb2dc48891..db9c76218e19d 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -24,7 +24,8 @@ def run_string_as_driver(driver_script): with tempfile.NamedTemporaryFile() as f: f.write(driver_script.encode("ascii")) f.flush() - out = subprocess.check_output([sys.executable, f.name]).decode("ascii") + out = ray.utils.decode( + subprocess.check_output([sys.executable, f.name])) return out diff --git a/test/runtest.py b/test/runtest.py index 83edc15713f4b..7cba4f010f9e4 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -974,7 +974,7 @@ def get_path2(): @unittest.skipIf( os.environ.get("RAY_USE_XRAY") == "1", - "This test does not work with xray yet.") + "This test does not work with xray (nor is it intended to).") def testLoggingAPI(self): self.init_ray(driver_mode=ray.SILENT_MODE) @@ -996,38 +996,76 @@ def wait_for_num_events(num_events, timeout=10): time.sleep(0.1) print("Timing out of wait.") - @ray.remote - def test_log_event(): - ray.log_event("event_type1", contents={"key": "val"}) - @ray.remote def test_log_span(): - with ray.log_span("event_type2", contents={"key": "val"}): + with ray.profile("event_type2", extra_data={"key": "val"}): pass - # Make sure that we can call ray.log_event in a remote function. - ray.get(test_log_event.remote()) - # Wait for the event to appear in the event log. - wait_for_num_events(1) - self.assertEqual(len(events()), 1) - # Make sure that we can call ray.log_span in a remote function. ray.get(test_log_span.remote()) # Wait for the events to appear in the event log. - wait_for_num_events(2) - self.assertEqual(len(events()), 2) + wait_for_num_events(1) + self.assertEqual(len(events()), 1) @ray.remote def test_log_span_exception(): - with ray.log_span("event_type2", contents={"key": "val"}): + with ray.profile("event_type2", extra_data={"key": "val"}): raise Exception("This failed.") # Make sure that logging a span works if an exception is thrown. test_log_span_exception.remote() # Wait for the events to appear in the event log.
- wait_for_num_events(3) - self.assertEqual(len(events()), 3) + wait_for_num_events(2) + self.assertEqual(len(events()), 2) + + @unittest.skipIf( + os.environ.get("RAY_USE_XRAY") != "1", + "This test only works with xray.") + def testProfilingAPI(self): + self.init_ray(num_cpus=2) + + @ray.remote + def f(): + with ray.profile( + "custom_event", + extra_data={"name": "custom name"}) as ray_prof: + ray_prof.set_attribute("key", "value") + + ray.put(1) + object_id = f.remote() + ray.wait([object_id]) + ray.get(object_id) + + # Wait until all of the profiling information appears in the profile + # table. + timeout_seconds = 20 + start_time = time.time() + while True: + if time.time() - start_time > timeout_seconds: + raise Exception("Timed out while waiting for information in " + "profile table.") + profile_data = ray.global_state.chrome_tracing_dump() + event_types = {event["cat"] for event in profile_data} + expected_types = [ + "get_task", + "task", + "task:deserialize_arguments", + "task:execute", + "task:store_outputs", + "wait_for_function", + "ray.get", + "ray.put", + "ray.wait", + "submit_task", + "fetch_and_run_function", + "register_remote_function", + "custom_event", # This is the custom one from ray.profile. + ] + + if all(expected_type in event_types + for expected_type in expected_types): + break def testIdenticalFunctionNames(self): # Define a bunch of remote functions and make sure that we don't @@ -1116,7 +1154,11 @@ def init_ray(self, **kwargs): if kwargs is None: kwargs = {} kwargs["start_ray_local"] = True - kwargs["num_redis_shards"] = 20 + if os.environ.get("RAY_USE_XRAY") == "1": + print("XRAY currently supports only a single Redis shard.") + kwargs["num_redis_shards"] = 1 + else: + kwargs["num_redis_shards"] = 20 kwargs["redirect_output"] = True ray.worker._init(**kwargs) @@ -2203,7 +2245,7 @@ def f(): @unittest.skipIf( os.environ.get("RAY_USE_XRAY") == "1", - "This test does not work with xray yet.") + "This test does not work with xray (nor is it intended to).") def testTaskProfileAPI(self): ray.init(redirect_output=True)
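Taken together, the pieces in this diff enable the workflow exercised by testProfilingAPI above. A hedged end-to-end sketch under the raylet code path (the filename and sleep durations are arbitrary):

```python
import time

import ray

ray.init()


@ray.remote
def f():
    with ray.profile("custom_event", extra_data={"cname": "good"}):
        time.sleep(0.1)


ray.get(f.remote())
time.sleep(2)  # Give the background flush thread a chance to push events.

# Write the catapult-format events to disk, then open chrome://tracing in
# Chrome, click "Load", and select timeline.json.
ray.global_state.chrome_tracing_dump(filename="timeline.json")
```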