Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Actor Group API #7712

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
mock implementation of placement group
  • Loading branch information
yuyiming committed Apr 20, 2020
commit b530030004ad4332d0f142ed0f5c73c45e494c4c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.BaseActor;
import org.ray.api.PlacementGroup;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
Expand All @@ -20,7 +19,6 @@
import org.ray.api.id.ObjectId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.options.PlacementGroupOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.config.RayConfig;
Expand Down Expand Up @@ -153,12 +151,6 @@ public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
}

/**
 * Placeholder for placement group creation. It is not implemented yet and always returns
 * {@code null}, regardless of the given options.
 *
 * @param options the requested placement group options (currently ignored)
 * @return always {@code null}
 */
@Override
public PlacementGroup createPlacementGroup(PlacementGroupOptions options) {
// TODO(yuyiming): impl
return null;
}

private void checkPyArguments(Object[] args) {
for (Object arg : args) {
Preconditions.checkArgument(
Expand Down Expand Up @@ -213,6 +205,9 @@ private BaseActor createActorImpl(FunctionDescriptor functionDescriptor,
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
}
if (gcsClient != null) {
gcsClient.getGcsServer().placeActor(options);
}
BaseActor actor = taskSubmitter.createActor(functionDescriptor, functionArgs, options);
return actor;
}
Expand Down
7 changes: 7 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.BaseActor;
import org.ray.api.PlacementGroup;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.PlacementGroupOptions;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
Expand Down Expand Up @@ -53,6 +55,11 @@ public void killActor(BaseActor actor, boolean noReconstruction) {
throw new UnsupportedOperationException();
}

/**
 * Does not create anything: the given options are ignored and {@code null} is returned.
 *
 * <p>NOTE(review): callers receive {@code null} here instead of an exception such as the
 * {@link UnsupportedOperationException} thrown by {@code killActor} -- confirm this is the
 * intended behavior for this runtime.
 *
 * @param options the requested placement group options (ignored)
 * @return always {@code null}
 */
@Override
public PlacementGroup createPlacementGroup(PlacementGroupOptions options) {
return null;
}

@Override
public Object getAsyncContext() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,

@Override
public PlacementGroup createPlacementGroup(PlacementGroupOptions options) {
// TODO(yuyiming): impl
return null;
return getCurrentRuntime().createPlacementGroup(options);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.BaseActor;
import org.ray.api.PlacementGroup;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.PlacementGroupOptions;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.NativeWorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
Expand Down Expand Up @@ -139,6 +141,11 @@ public void killActor(BaseActor actor, boolean noReconstruction) {
nativeKillActor(nativeCoreWorkerPointer, actor.getId().getBytes(), noReconstruction);
}

/**
 * Creates a placement group by forwarding the request to the GCS client.
 *
 * @param options the requested placement group options, passed through unchanged
 * @return the placement group returned by the GCS client
 */
@Override
public PlacementGroup createPlacementGroup(PlacementGroupOptions options) {
  final PlacementGroup createdGroup = gcsClient.createPlacementGroup(options);
  return createdGroup;
}

@Override
public Object getAsyncContext() {
return null;
Expand Down
14 changes: 14 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.PlacementGroup;
import org.ray.api.id.ActorId;
import org.ray.api.id.BaseId;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.PlacementGroupOptions;
import org.ray.api.runtimecontext.NodeInfo;
import org.ray.runtime.generated.Gcs;
import org.ray.runtime.generated.Gcs.ActorCheckpointIdData;
import org.ray.runtime.generated.Gcs.GcsNodeInfo;
import org.ray.runtime.generated.Gcs.TablePrefix;
import org.ray.runtime.mockgcsserver.MockGcsServer;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -34,6 +37,8 @@ public class GcsClient {

private List<RedisClient> shards;

private MockGcsServer gcsServer;

public GcsClient(String redisAddress, String redisPassword) {
primary = new RedisClient(redisAddress, redisPassword);
int numShards = 0;
Expand All @@ -50,6 +55,8 @@ public GcsClient(String redisAddress, String redisPassword) {
shards = shardAddresses.stream().map((byte[] address) -> {
return new RedisClient(new String(address), redisPassword);
}).collect(Collectors.toList());

gcsServer = new MockGcsServer(this);
}

public List<NodeInfo> getAllNodeInfo() {
Expand Down Expand Up @@ -177,4 +184,11 @@ private RedisClient getShardClient(BaseId key) {
shards.size()));
}

/**
 * Creates a placement group by forwarding the request to the mock GCS server owned by this
 * client.
 *
 * @param options the requested placement group options, passed through unchanged
 * @return the placement group returned by the mock GCS server
 */
public PlacementGroup createPlacementGroup(PlacementGroupOptions options) {
  final PlacementGroup createdGroup = this.gcsServer.createPlacementGroup(options);
  return createdGroup;
}

/**
 * Returns the mock GCS server held by this client.
 *
 * @return the {@link MockGcsServer} instance stored in this client
 */
public MockGcsServer getGcsServer() {
  return this.gcsServer;
}
}
18 changes: 18 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/group/NativeBundle.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.ray.runtime.group;

import org.ray.api.Bundle;
import org.ray.api.id.GroupId;

/**
 * A {@link Bundle} implementation that is identified solely by its {@link GroupId}.
 */
public class NativeBundle implements Bundle {

  /** Identifier of this bundle; assigned once at construction. */
  private final GroupId bundleId;

  /**
   * Creates a bundle handle for the given id.
   *
   * @param id the identifier this bundle reports from {@link #getId()}
   */
  public NativeBundle(GroupId id) {
    bundleId = id;
  }

  /**
   * Returns the identifier supplied at construction.
   *
   * @return the id of this bundle
   */
  @Override
  public GroupId getId() {
    return this.bundleId;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.ray.runtime.group;

import java.util.List;
import org.ray.api.Bundle;
import org.ray.api.PlacementGroup;
import org.ray.api.id.GroupId;

/**
 * A {@link PlacementGroup} implementation that simply carries a name, an id and the list of
 * bundles it was constructed with.
 */
public class NativePlacementGroup implements PlacementGroup {

  /** Name given to this placement group at construction. */
  public final String name;

  /** Identifier of this placement group. */
  public final GroupId id;

  /** Bundles of this group; the list passed to the constructor is stored as-is (not copied). */
  public final List<Bundle> bundles;

  /**
   * Creates a placement group handle.
   *
   * @param name the name of the group
   * @param id the identifier of the group
   * @param bundles the bundles belonging to the group; the reference is kept, not copied
   */
  public NativePlacementGroup(String name, GroupId id, List<Bundle> bundles) {
    this.bundles = bundles;
    this.id = id;
    this.name = name;
  }

  /**
   * Returns the identifier supplied at construction.
   *
   * @return the id of this placement group
   */
  @Override
  public GroupId getId() {
    return this.id;
  }

  /**
   * Returns the bundle list supplied at construction.
   *
   * @return the same list instance that was passed to the constructor
   */
  @Override
  public List<Bundle> getBundles() {
    return this.bundles;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.ray.runtime.mockgcsserver;

import java.util.List;
import org.ray.api.id.GroupId;

/**
 * Plain record kept by the mock GCS server for one bundle: the bundle's id together with the
 * units that belong to it.
 */
public class BundleTable {

// Identifier of the bundle this record describes.
public final GroupId id;

// Units of this bundle; the list passed to the constructor is stored as-is (not copied).
public final List<UnitTable> units;

/**
 * Creates a bundle record.
 *
 * @param id the identifier of the bundle
 * @param units the units belonging to the bundle; the reference is kept, not copied
 */
public BundleTable(GroupId id, List<UnitTable> units) {
this.id = id;
this.units = units;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package org.ray.runtime.mockgcsserver;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.IntStream;
import org.ray.api.Bundle;
import org.ray.api.PlacementGroup;
import org.ray.api.Ray;
import org.ray.api.id.GroupId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.BundleOptions;
import org.ray.api.options.PlacementGroupOptions;
import org.ray.api.options.PlacementStrategy;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.group.NativeBundle;
import org.ray.runtime.group.NativePlacementGroup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * In-process mock of the server side of the placement group API.
 *
 * <p>The mock keeps its bookkeeping in plain in-memory maps: one entry per created placement
 * group and one entry per bundle. Creating a placement group reserves "units" of resources on
 * the alive nodes reported by the {@link GcsClient}; placing an actor then consumes the
 * resources of one unit of the requested bundle and pins the actor to that unit's node by adding
 * the node's label resource to the actor's resource requirements.
 *
 * <p>The maps are plain {@link HashMap}s and no synchronization is performed in this class.
 */
public class MockGcsServer {

  private static final Logger LOGGER = LoggerFactory.getLogger(MockGcsServer.class);

  /** Amount of the label resource that is set on every alive node (2^15). */
  public static final Double LABEL_RESOURCE_AMOUNT = Math.pow(2, 15);

  /** Client used to look up the nodes of the cluster. */
  private final GcsClient gcsClient;

  /** All placement groups created through this server, keyed by group id. */
  private final Map<GroupId, PlacementGroupTable> groups = new HashMap<>();

  /** All bundles of successfully created placement groups, keyed by bundle id. */
  private final Map<GroupId, BundleTable> bundles = new HashMap<>();

  /**
   * Creates a mock server that reads node information through the given client.
   *
   * @param gcsClient the client used to list the nodes of the cluster
   */
  public MockGcsServer(GcsClient gcsClient) {
    this.gcsClient = gcsClient;
  }

  /**
   * Reserves resources for an actor inside the bundle named by its creation options.
   *
   * <p>Does nothing when {@code options} is {@code null} or carries no bundle. Otherwise the
   * actor's resource requirements are deducted from the first unit of the bundle that can
   * satisfy them, and the label of that unit is added to {@code options.resources} with amount
   * {@code 1.0}. Note that this mutates the resource map of the passed options.
   *
   * @param options the actor creation options; may be {@code null}
   * @throws NullPointerException if the bundle is unknown to this server, or if no unit of the
   *     bundle has enough available resources (both raised through
   *     {@link Preconditions#checkNotNull})
   */
  public void placeActor(ActorCreationOptions options) {
    if (options == null || options.bundle == null) {
      return;
    }

    GroupId bundleId = options.bundle.getId();
    BundleTable bundleTable = bundles.get(bundleId);
    Preconditions.checkNotNull(bundleTable, "Bundle %s does not exist.", bundleId);
    String label = allocateResources(bundleTable, options.resources);
    Preconditions
        .checkNotNull(label, "There is not enough resources in Bundle %s.", bundleId);
    options.resources.put(label, 1.0);
    LOGGER.info("Placed Actor in Bundle {} in Node labeled {}.", bundleTable.id, label);
  }

  /**
   * Tries the units of a bundle in order and allocates from the first one that fits.
   *
   * @param bundleTable the bundle whose units are tried
   * @param requirements the resources to allocate, by resource name
   * @return the label of the unit that was used, or {@code null} if no unit fits
   */
  private String allocateResources(BundleTable bundleTable, Map<String, Double> requirements) {
    LOGGER.info("Trying to allocate resources in Bundle {}.", bundleTable.id);
    for (UnitTable unitTable : bundleTable.units) {
      String label = allocateResources(unitTable, requirements);
      if (label != null) {
        return label;
      }
    }
    return null;
  }

  /**
   * Allocates the requirements from a single unit if, and only if, every requirement fits.
   *
   * <p>A resource that the unit does not list counts as {@code 0.0} available. Nothing is
   * deducted unless all requirements can be satisfied.
   *
   * @param unitTable the unit to allocate from
   * @param requirements the resources to allocate, by resource name
   * @return the label of the unit on success, or {@code null} if some requirement does not fit
   */
  private String allocateResources(UnitTable unitTable, Map<String, Double> requirements) {
    for (Map.Entry<String, Double> entry : requirements.entrySet()) {
      double available = unitTable.availableResources.getOrDefault(entry.getKey(), 0.0);
      if (available < entry.getValue()) {
        LOGGER.info("Resource \"{}\" = {} is not enough for requirement {}.", entry.getKey(),
            available, entry.getValue());
        return null;
      }
    }

    requirements.forEach(
        (k, v) -> unitTable.availableResources.computeIfPresent(k, (name, amount) -> amount - v));
    return unitTable.label;
  }

  /**
   * Creates a placement group according to the given options and records it in this server.
   *
   * @param options the placement group options (name, strategy and bundles)
   * @return a handle describing the created group and its bundles
   * @throws RuntimeException if the cluster has no alive node or not enough resources
   */
  public PlacementGroup createPlacementGroup(PlacementGroupOptions options) {
    LOGGER.info("Creating a placement group with name \"{}\".", options.name);
    PlacementGroupTable groupTable = buildPlacementGroupTable(options);
    groups.put(groupTable.id, groupTable);
    LOGGER.info("Created placement group {} with name \"{}\".", groupTable.id, groupTable.name);
    return buildPlacementGroup(groupTable);
  }

  /**
   * Lays out every requested bundle on the alive nodes and builds the group's table.
   *
   * <p>Placement starts at a randomly chosen node and walks the node list round-robin. For each
   * bundle, as many units as fit are taken from the current node before moving on to the next
   * one. With {@link PlacementStrategy#SPREAD} the walk additionally advances by one node after
   * each bundle.
   *
   * <p>Bundles are only registered in {@link #bundles} once every bundle of the group has been
   * laid out, so a failed request leaves no orphaned bundle entries behind.
   *
   * @param options the placement group options
   * @return the table describing the new group
   * @throws RuntimeException if there is no alive node, or if some bundle does not fit
   */
  private PlacementGroupTable buildPlacementGroupTable(PlacementGroupOptions options) {
    // get node info
    List<NodeResource> nodes = getAllNodeResource();
    int nodeCount = nodes.size();
    if (nodeCount == 0) {
      // Random.nextInt(0) would otherwise fail with an unhelpful "bound must be positive".
      throw new RuntimeException("There are no alive nodes in this cluster.");
    }
    int nodeIndex = new Random().nextInt(nodeCount);

    // construct bundles
    List<BundleTable> bundleTables = new ArrayList<>();
    for (BundleOptions bundleOptions : options.bundles) {
      List<UnitTable> units = new ArrayList<>();
      // allocate resources one node by one node, visiting each node at most once per bundle
      int remainingUnits = bundleOptions.unitCount;
      for (int i = 0; i < nodeCount; i++) {
        NodeResource node = nodes.get(nodeIndex % nodeCount);
        int allocated = preallocateResources(node, bundleOptions.resources, remainingUnits);
        // every unit gets its own copy of the per-unit resources
        IntStream.range(0, allocated).forEach(j -> units
            .add(new UnitTable(new HashMap<>(bundleOptions.resources), node.getNodeLabel())));
        remainingUnits -= allocated;

        if (remainingUnits == 0) {
          break;
        }
        ++nodeIndex;
      }
      if (remainingUnits > 0) {
        throw new RuntimeException("There are not enough resources in this cluster.");
      }
      bundleTables.add(new BundleTable(GroupId.fromRandom(), units));

      if (options.strategy == PlacementStrategy.SPREAD) {
        ++nodeIndex;
      }
    }

    // register the bundles only now that the whole group is known to fit
    for (BundleTable bundleTable : bundleTables) {
      bundles.put(bundleTable.id, bundleTable);
      LOGGER.info("Added Bundle {}.", bundleTable.id);
    }

    return new PlacementGroupTable(GroupId.fromRandom(), options.name, options.strategy,
        bundleTables);
  }

  /**
   * Collects the resources of all alive nodes and sets the label resource on each of them.
   *
   * <p>A fresh {@link NodeResource} is built for every alive node on each call. As a side
   * effect, every alive node gets its label resource set to {@link #LABEL_RESOURCE_AMOUNT}.
   *
   * @return one entry per alive node, in the order reported by the GCS client
   */
  private List<NodeResource> getAllNodeResource() {
    List<NodeResource> nodes = new ArrayList<>();
    gcsClient.getAllNodeInfo().forEach(nodeInfo -> {
      if (!nodeInfo.isAlive) {
        return;
      }
      LOGGER.info("Detected Node {}.", nodeInfo.nodeId);
      NodeResource node = new NodeResource(nodeInfo.nodeId, nodeInfo.resources);
      nodes.add(node);

      // set node label
      Ray.internal().setResource(node.getNodeLabel(), LABEL_RESOURCE_AMOUNT, nodeInfo.nodeId);
    });
    return nodes;
  }

  /**
   * Reserves as many units as possible (up to {@code remainingUnits}) on one node.
   *
   * <p>The number of units that fit is the minimum, over all required resources, of the node's
   * remaining amount divided by the per-unit amount (truncated). A resource the node does not
   * list counts as {@code 0.0}. The reserved amount is deducted from the node's remaining
   * resources.
   *
   * @param node the node to reserve on
   * @param unitResources the resources required by a single unit, by resource name
   * @param remainingUnits the number of units still to be placed
   * @return the number of units reserved on this node; never negative
   */
  private int preallocateResources(NodeResource node, Map<String, Double> unitResources,
      int remainingUnits) {
    int minUnits = Integer.MAX_VALUE;
    for (Map.Entry<String, Double> entry : unitResources.entrySet()) {
      double available = node.remainingResources.getOrDefault(entry.getKey(), 0.0);
      minUnits = Math.min(minUnits, (int) (available / entry.getValue()));
      if (minUnits <= 0) {
        break;
      }
    }

    // Clamp at zero: a negative quotient must not be reported as a negative allocation,
    // otherwise the caller's count of remaining units would grow.
    final int allocatedUnits = Math.max(0, Math.min(minUnits, remainingUnits));
    if (allocatedUnits > 0) {
      unitResources.forEach((k, v) -> node.remainingResources
          .compute(k, (name, amount) -> amount - allocatedUnits * v));
    }

    LOGGER.info("Allocated {} / {} units in Node {}.", allocatedUnits, remainingUnits, node.nodeId);
    return allocatedUnits;
  }

  /**
   * Converts the internal table of a group into the handle returned to API callers.
   *
   * @param groupTable the internal description of the group
   * @return a {@link NativePlacementGroup} with one {@link NativeBundle} per bundle of the group
   */
  private PlacementGroup buildPlacementGroup(PlacementGroupTable groupTable) {
    List<Bundle> bundleHandles = new ArrayList<>();
    groupTable.bundles.forEach(bundle -> bundleHandles.add(new NativeBundle(bundle.id)));
    return new NativePlacementGroup(groupTable.name, groupTable.id, bundleHandles);
  }
}
Loading