Skip to content

Commit

Permalink
[Core] Remove multiple core workers in one process 1/n. (ray-project#…
Browse files Browse the repository at this point in the history
…24147)

This is the first PR to remove the code path for multiple core workers in one process. This PR aims to remove the flags and APIs related to `num_workers`.
Once this PR is checked in, we no longer need to consider multiple core workers.

Follow-up PRs will cover the deeper logic refactoring, such as eliminating the gap between the core worker and the core worker process, and removing the logic related to multiple workers from the worker pool, GCS, etc.

**BREAKING CHANGE**
This PR removes these APIs:
- Ray.wrapRunnable();
- Ray.wrapCallable();
- Ray.setAsyncContext();
- Ray.getAsyncContext();

In addition, the following APIs must not be invoked from a user-created thread in local mode:
- Ray.getRuntimeContext().getCurrentActorId();
- Ray.getRuntimeContext().getCurrentTaskId();

Note that this PR shouldn't be merged to 1.x.
  • Loading branch information
jovany-wang authored May 18, 2022
1 parent 1d5e6d9 commit eb29895
Show file tree
Hide file tree
Showing 57 changed files with 137 additions and 1,200 deletions.
1 change: 0 additions & 1 deletion cpp/src/ray/util/process_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback)
options.node_manager_port = ConfigInternal::Instance().node_manager_port;
options.raylet_ip_address = node_ip;
options.driver_name = "cpp_worker";
options.num_workers = 1;
options.metrics_agent_port = -1;
options.task_execution_callback = callback;
options.startup_token = ConfigInternal::Instance().startup_token;
Expand Down
47 changes: 0 additions & 47 deletions java/api/src/main/java/io/ray/api/Ray.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import io.ray.api.runtimecontext.RuntimeContext;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;

/** This class contains all public APIs of Ray. */
public final class Ray extends RayCall {
Expand Down Expand Up @@ -205,52 +204,6 @@ public static <T extends BaseActorHandle> Optional<T> getActor(String name, Stri
return internal().getActor(name, namespace);
}

/**
* If users want to use Ray API in their own threads, call this method to get the async context
* and then call {@link #setAsyncContext} at the beginning of the new thread.
*
* @return The async context.
*/
public static Object getAsyncContext() {
return internal().getAsyncContext();
}

/**
* Set the async context for the current thread.
*
* @param asyncContext The async context to set.
*/
public static void setAsyncContext(Object asyncContext) {
internal().setAsyncContext(asyncContext);
}

// TODO (kfstorm): add the `rollbackAsyncContext` API to allow rollbacking the async context of
// the current thread to the one before `setAsyncContext` is called.

// TODO (kfstorm): unify the `wrap*` methods.

/**
* If users want to use Ray API in their own threads, they should wrap their {@link Runnable}
* objects with this method.
*
* @param runnable The runnable to wrap.
* @return The wrapped runnable.
*/
public static Runnable wrapRunnable(Runnable runnable) {
return internal().wrapRunnable(runnable);
}

/**
* If users want to use Ray API in their own threads, they should wrap their {@link Callable}
* objects with this method.
*
* @param callable The callable to wrap.
* @return The wrapped callable.
*/
public static <T> Callable<T> wrapCallable(Callable<T> callable) {
return internal().wrapCallable(callable);
}

/** Get the underlying runtime instance. */
public static RayRuntime internal() {
if (runtime == null) {
Expand Down
21 changes: 0 additions & 21 deletions java/api/src/main/java/io/ray/api/runtime/RayRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;

/** Base interface of a Ray runtime. */
public interface RayRuntime {
Expand Down Expand Up @@ -207,26 +206,6 @@ <T> ActorHandle<T> createActor(

RuntimeContext getRuntimeContext();

Object getAsyncContext();

void setAsyncContext(Object asyncContext);

/**
* Wrap a {@link Runnable} with necessary context capture.
*
* @param runnable The runnable to wrap.
* @return The wrapped runnable.
*/
Runnable wrapRunnable(Runnable runnable);

/**
* Wrap a {@link Callable} with necessary context capture.
*
* @param callable The callable to wrap.
* @return The wrapped callable.
*/
<T> Callable<T> wrapCallable(Callable<T> callable);

/** Intentionally exit the current actor. */
void exitActor();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ public static void run(
boolean hasReturn,
boolean ignoreReturn,
int argSize,
boolean useDirectByteBuffer,
int numJavaWorkerPerProcess) {
System.setProperty(
"ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess));
boolean useDirectByteBuffer) {
System.setProperty("ray.raylet.startup-token", "0");
Ray.init();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ public static void main(String[] args) {
final int argSize = 0;
final boolean useDirectByteBuffer = false;
final boolean ignoreReturn = false;
final int numJavaWorkerPerProcess = 1;
ActorPerformanceTestBase.run(
args,
layers,
actorsPerLayer,
hasReturn,
ignoreReturn,
argSize,
useDirectByteBuffer,
numJavaWorkerPerProcess);
useDirectByteBuffer);
}
}
72 changes: 3 additions & 69 deletions java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.functionmanager.PyFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.generated.Common;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.ObjectRefImpl;
import io.ray.runtime.object.ObjectStore;
Expand All @@ -46,7 +45,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -67,13 +65,8 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {

private static ParallelActorContextImpl parallelActorContextImpl = new ParallelActorContextImpl();

/** Whether the required thread context is set on the current thread. */
final ThreadLocal<Boolean> isContextSet = ThreadLocal.withInitial(() -> false);

public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
setIsContextSet(rayConfig.workerMode == Common.WorkerType.DRIVER);
functionManager = new FunctionManager(rayConfig.codeSearchPath);
runtimeContext = new RuntimeContextImpl(this);
}

Expand Down Expand Up @@ -158,7 +151,7 @@ public <T> WaitResult<T> wait(

@Override
public ObjectRef call(RayFunc func, Object[] args, CallOptions options) {
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
RayFunction rayFunction = functionManager.getFunction(func);
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
Optional<Class<?>> returnType = rayFunction.getReturnType();
return callNormalFunction(functionDescriptor, args, returnType, options);
Expand All @@ -176,7 +169,7 @@ public ObjectRef call(PyFunction pyFunction, Object[] args, CallOptions options)
@Override
public ObjectRef callActor(
ActorHandle<?> actor, RayFunc func, Object[] args, CallOptions options) {
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
RayFunction rayFunction = functionManager.getFunction(func);
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
Optional<Class<?>> returnType = rayFunction.getReturnType();
return callActorFunction(actor, functionDescriptor, args, returnType, options);
Expand All @@ -201,8 +194,7 @@ public ObjectRef callActor(PyActorHandle pyActor, PyActorMethod pyActorMethod, O
public <T> ActorHandle<T> createActor(
RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) {
FunctionDescriptor functionDescriptor =
functionManager.getFunction(workerContext.getCurrentJobId(), actorFactoryFunc)
.functionDescriptor;
functionManager.getFunction(actorFactoryFunc).functionDescriptor;
return (ActorHandle<T>) createActorImpl(functionDescriptor, args, options);
}

Expand Down Expand Up @@ -256,31 +248,6 @@ public <T extends BaseActorHandle> T getActorHandle(ActorId actorId) {
return (T) taskSubmitter.getActor(actorId);
}

@Override
public void setAsyncContext(Object asyncContext) {
isContextSet.set(true);
}

@Override
public final Runnable wrapRunnable(Runnable runnable) {
Object asyncContext = getAsyncContext();
return () -> {
try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) {
runnable.run();
}
};
}

@Override
public final <T> Callable<T> wrapCallable(Callable<T> callable) {
Object asyncContext = getAsyncContext();
return () -> {
try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) {
return callable.call();
}
};
}

@Override
public ConcurrencyGroup createConcurrencyGroup(
String name, int maxConcurrency, List<RayFunc> funcs) {
Expand Down Expand Up @@ -387,34 +354,6 @@ private BaseActorHandle createActorImpl(
return actor;
}

/// An auto closable class that is used for updating the async context when invoking Ray APIs.
private static final class RayAsyncContextUpdater implements AutoCloseable {

private AbstractRayRuntime runtime;

private boolean oldIsContextSet;

private Object oldAsyncContext = null;

public RayAsyncContextUpdater(Object asyncContext, AbstractRayRuntime runtime) {
this.runtime = runtime;
oldIsContextSet = runtime.isContextSet.get();
if (oldIsContextSet) {
oldAsyncContext = runtime.getAsyncContext();
}
runtime.setAsyncContext(asyncContext);
}

@Override
public void close() {
if (oldIsContextSet) {
runtime.setAsyncContext(oldAsyncContext);
} else {
runtime.setIsContextSet(false);
}
}
}

abstract List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId);

@Override
Expand Down Expand Up @@ -446,11 +385,6 @@ public RuntimeContext getRuntimeContext() {
return runtimeContext;
}

@Override
public void setIsContextSet(boolean isContextSet) {
this.isContextSet.set(isContextSet);
}

/// A helper to validate if the prepared return ids is as expected.
void validatePreparedReturnIds(List<ObjectId> preparedReturnIds, List<ObjectId> realReturnIds) {
if (rayConfig.runMode == RunMode.CLUSTER) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ public ConcurrencyGroupImpl(String name, int maxConcurrency, List<RayFunc> funcs
funcs.forEach(
func -> {
RayFunction rayFunc =
((RayRuntimeInternal) Ray.internal())
.getFunctionManager()
.getFunction(Ray.getRuntimeContext().getCurrentJobId(), func);
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getFunction(func);
functionDescriptors.add(rayFunc.getFunctionDescriptor());
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ public RayRuntime createRayRuntime() {
rayConfig.runMode == RunMode.SINGLE_PROCESS
? new RayDevRuntime(rayConfig)
: new RayNativeRuntime(rayConfig);
RayRuntimeInternal runtime =
rayConfig.numWorkersPerProcess > 1
? RayRuntimeProxy.newInstance(innerRuntime)
: innerRuntime;
RayRuntimeInternal runtime = innerRuntime;
runtime.start();
return runtime;
} catch (Exception e) {
Expand Down
20 changes: 2 additions & 18 deletions java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.ray.runtime;

import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
Expand All @@ -10,6 +9,7 @@
import io.ray.api.runtimecontext.ResourceValue;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.context.LocalModeWorkerContext;
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.generated.Common.TaskSpec;
import io.ray.runtime.object.LocalModeObjectStore;
Expand All @@ -24,13 +24,9 @@
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RayDevRuntime extends AbstractRayRuntime {

private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class);

private AtomicInteger jobCounter = new AtomicInteger(0);

public RayDevRuntime(RayConfig rayConfig) {
Expand All @@ -49,6 +45,7 @@ public void start() {
taskExecutor = new LocalModeTaskExecutor(this);
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
objectStore = new LocalModeObjectStore(workerContext);
functionManager = new FunctionManager(rayConfig.codeSearchPath);
taskSubmitter =
new LocalModeTaskSubmitter(this, taskExecutor, (LocalModeObjectStore) objectStore);
((LocalModeObjectStore) objectStore)
Expand Down Expand Up @@ -90,19 +87,6 @@ public GcsClient getGcsClient() {
throw new UnsupportedOperationException("Ray doesn't have gcs client in local mode.");
}

@Override
public Object getAsyncContext() {
return new AsyncContext(((LocalModeWorkerContext) workerContext).getCurrentTask());
}

@Override
public void setAsyncContext(Object asyncContext) {
Preconditions.checkNotNull(asyncContext);
TaskSpec task = ((AsyncContext) asyncContext).task;
((LocalModeWorkerContext) workerContext).setCurrentTask(task);
super.setAsyncContext(asyncContext);
}

@Override
public Map<String, List<ResourceValue>> getAvailableResourceIds() {
throw new UnsupportedOperationException("Ray doesn't support get resources ids in local mode.");
Expand Down
Loading

0 comments on commit eb29895

Please sign in to comment.