From 6ca2c2a609744f8c752deab75230016eb947c2df Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 15 Apr 2018 20:36:01 -0700 Subject: [PATCH] Allow numpy arrays to be passed by value into tasks (and inlined in the task spec). (#1816) * Allow numpy arrays and larger objects to be passed by value in task specifications. * Fix bug. * Fix bug. Inline all bug numpy object arrays. * Increase size limit for inlining args in task spec. * Give numpy init different signatures in Python 2 and Python 3. * Simplify code. * Fix test. * Use import_array1 instead of import_array. --- python/ray/common/test/test.py | 9 +++-- src/common/cmake/Common.cmake | 1 + src/common/lib/python/common_extension.cc | 36 ++++++++++++++----- src/common/lib/python/common_extension.h | 2 ++ src/common/state/ray_config.h | 4 +-- src/local_scheduler/CMakeLists.txt | 1 + .../local_scheduler_extension.cc | 1 + 7 files changed, 41 insertions(+), 13 deletions(-) diff --git a/python/ray/common/test/test.py b/python/ray/common/test/test.py index b1e9427abad7..5892d289fa73 100644 --- a/python/ray/common/test/test.py +++ b/python/ray/common/test/test.py @@ -30,7 +30,9 @@ def random_task_id(): BASE_SIMPLE_OBJECTS = [ 0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"", - 990 * u"h" + 990 * u"h", + np.ones(3), + np.array([True, False]), None, True, False ] if sys.version_info < (3, 0): @@ -60,8 +62,9 @@ def __init__(self): BASE_COMPLEX_OBJECTS = [ - 999 * "h", 999 * u"h", lst, - Foo(), 10 * [10 * [10 * [1]]] + 15000 * "h", 15000 * u"h", lst, + Foo(), 100 * [100 * [10 * [1]]], + np.array([Foo()]) ] LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS] diff --git a/src/common/cmake/Common.cmake b/src/common/cmake/Common.cmake index 134f35680aa0..9e12ab9cc0bb 100644 --- a/src/common/cmake/Common.cmake +++ b/src/common/cmake/Common.cmake @@ -39,6 +39,7 @@ set(CMAKE_C_FLAGS "-g -Wall -Wextra -Werror=implicit-function-declaration -Wno-s # Code for finding Python find_package(PythonInterp REQUIRED) +find_package(NumPy REQUIRED) # Now find the Python include directories. execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "from distutils.sysconfig import *; print(get_python_inc())" diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc index 0ef0bb00fefa..230a576e60c6 100644 --- a/src/common/lib/python/common_extension.cc +++ b/src/common/lib/python/common_extension.cc @@ -2,6 +2,11 @@ #include "bytesobject.h" #include "node.h" +// Don't use the deprecated Numpy functions. +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include + #include "common.h" #include "common_extension.h" #include "common_protocol.h" @@ -23,6 +28,11 @@ PyObject *pickle_loads = NULL; PyObject *pickle_dumps = NULL; PyObject *pickle_protocol = NULL; +int init_numpy_module(void) { + import_array1(-1); + return 0; +} + void init_pickle_module(void) { #if PY_MAJOR_VERSION >= 3 pickle_module = PyImport_ImportModule("pickle"); @@ -783,16 +793,17 @@ PyObject *PyTask_make(TaskSpec *task_spec, int64_t task_size) { * objects recursively contained within this object will be added to the * value at this address. This is used to make sure that we do not * serialize objects that are too large. - * @return 0 if the object cannot be serialized in the task and 1 if it can. + * @return False if the object cannot be serialized in the task and true if it + * can. */ -int is_simple_value(PyObject *value, int *num_elements_contained) { +bool is_simple_value(PyObject *value, int *num_elements_contained) { *num_elements_contained += 1; if (*num_elements_contained >= RayConfig::instance().num_elements_limit()) { - return 0; + return false; } if (PyInt_Check(value) || PyLong_Check(value) || value == Py_False || value == Py_True || PyFloat_Check(value) || value == Py_None) { - return 1; + return true; } if (PyBytes_CheckExact(value)) { *num_elements_contained += PyBytes_Size(value); @@ -808,7 +819,7 @@ int is_simple_value(PyObject *value, int *num_elements_contained) { PyList_Size(value) < RayConfig::instance().size_limit()) { for (Py_ssize_t i = 0; i < PyList_Size(value); ++i) { if (!is_simple_value(PyList_GetItem(value, i), num_elements_contained)) { - return 0; + return false; } } return (*num_elements_contained < @@ -821,7 +832,7 @@ int is_simple_value(PyObject *value, int *num_elements_contained) { while (PyDict_Next(value, &pos, &key, &val)) { if (!is_simple_value(key, num_elements_contained) || !is_simple_value(val, num_elements_contained)) { - return 0; + return false; } } return (*num_elements_contained < @@ -831,13 +842,22 @@ int is_simple_value(PyObject *value, int *num_elements_contained) { PyTuple_Size(value) < RayConfig::instance().size_limit()) { for (Py_ssize_t i = 0; i < PyTuple_Size(value); ++i) { if (!is_simple_value(PyTuple_GetItem(value, i), num_elements_contained)) { - return 0; + return false; } } return (*num_elements_contained < RayConfig::instance().num_elements_limit()); } - return 0; + if (PyArray_CheckExact(value)) { + PyArrayObject *array = reinterpret_cast(value); + if (PyArray_TYPE(array) == NPY_OBJECT) { + return false; + } + *num_elements_contained += PyArray_NBYTES(array); + return (*num_elements_contained < + RayConfig::instance().num_elements_limit()); + } + return false; } PyObject *check_simple_value(PyObject *self, PyObject *args) { diff --git a/src/common/lib/python/common_extension.h b/src/common/lib/python/common_extension.h index 81a9f0f48887..b24e45a1f88d 100644 --- a/src/common/lib/python/common_extension.h +++ b/src/common/lib/python/common_extension.h @@ -43,6 +43,8 @@ extern PyObject *pickle_module; extern PyObject *pickle_dumps; extern PyObject *pickle_loads; +int init_numpy_module(void); + void init_pickle_module(void); extern TaskBuilder *g_task_builder; diff --git a/src/common/state/ray_config.h b/src/common/state/ray_config.h index ee2830fd9f91..3ee68e4a3a0d 100644 --- a/src/common/state/ray_config.h +++ b/src/common/state/ray_config.h @@ -105,8 +105,8 @@ class RayConfig { manager_timeout_milliseconds_(1000), buf_size_(80 * 1024), max_time_for_handler_milliseconds_(1000), - size_limit_(100), - num_elements_limit_(1000), + size_limit_(10000), + num_elements_limit_(10000), max_time_for_loop_(1000), redis_db_connect_retries_(50), redis_db_connect_wait_milliseconds_(100), diff --git a/src/local_scheduler/CMakeLists.txt b/src/local_scheduler/CMakeLists.txt index f4168ce3b59b..18054c581de5 100644 --- a/src/local_scheduler/CMakeLists.txt +++ b/src/local_scheduler/CMakeLists.txt @@ -19,6 +19,7 @@ if(APPLE) endif(APPLE) include_directories("${PYTHON_INCLUDE_DIRS}") +include_directories("${NUMPY_INCLUDE_DIR}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall") diff --git a/src/local_scheduler/local_scheduler_extension.cc b/src/local_scheduler/local_scheduler_extension.cc index 2bf8e8c31734..6e13de5a02ca 100644 --- a/src/local_scheduler/local_scheduler_extension.cc +++ b/src/local_scheduler/local_scheduler_extension.cc @@ -311,6 +311,7 @@ MOD_INIT(liblocal_scheduler_library) { "A module for the local scheduler."); #endif + init_numpy_module(); init_pickle_module(); Py_INCREF(&PyTaskType);