Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

package io.mantisrx.master.resourcecluster;

import akka.actor.AbstractActor;
import akka.actor.AbstractActorWithTimers;
import akka.actor.ActorRef;
import akka.actor.Cancellable;
import akka.actor.Props;
import akka.actor.Status;
import akka.japi.pf.ReceiveBuilder;
Expand Down Expand Up @@ -68,7 +67,7 @@
*/
@ToString(of = {"clusterID"})
@Slf4j
class ResourceClusterActor extends AbstractActor {
class ResourceClusterActor extends AbstractActorWithTimers {

private final Duration heartbeatTimeout;
private final Duration assignmentTimeout;
Expand All @@ -80,8 +79,8 @@ class ResourceClusterActor extends AbstractActor {
private final ClusterID clusterID;
private final MantisJobStore mantisJobStore;

static Props props(final ClusterID clusterID, final Duration heartbeatTimeout, Duration assignmentTimeout, Clock clock, RpcService rpcService) {
return Props.create(ResourceClusterActor.class, clusterID, heartbeatTimeout, assignmentTimeout, clock, rpcService);
static Props props(final ClusterID clusterID, final Duration heartbeatTimeout, Duration assignmentTimeout, Clock clock, RpcService rpcService, MantisJobStore mantisJobStore) {
return Props.create(ResourceClusterActor.class, clusterID, heartbeatTimeout, assignmentTimeout, clock, rpcService, mantisJobStore);
}

ResourceClusterActor(
Expand Down Expand Up @@ -113,6 +112,7 @@ public Receive createReceive() {
.match(GetTaskExecutorStatusRequest.class, req -> sender().tell(getTaskExecutorStatus(req.getTaskExecutorID()), self()))
.match(Ack.class, ack -> log.info("Received ack from {}", sender()))

.match(TaskExecutorAssignmentTimeout.class, this::onTaskExecutorAssignmentTimeout)
.match(TaskExecutorRegistration.class, this::onTaskExecutorRegistration)
.match(InitializeTaskExecutorRequest.class, this::onTaskExecutorInitialization)
.match(TaskExecutorHeartbeat.class, this::onHeartbeat)
Expand Down Expand Up @@ -278,15 +278,10 @@ private void onTaskExecutorAssignmentRequest(TaskExecutorAssignmentRequest reque
matchedExecutor.get().getValue().onAssignment(request.getWorkerId());
// let's give some time for the assigned executor to be scheduled work. otherwise, the assigned executor
// will be returned back to the pool.
context()
.system()
.scheduler()
.scheduleOnce(
assignmentTimeout,
self(),
new TaskExecutorAssignmentTimeout(matchedExecutor.get().getKey()),
getContext().getDispatcher(),
self());
getTimers().startSingleTimer(
"Assignment-" + matchedExecutor.get().getKey().toString(),
new TaskExecutorAssignmentTimeout(matchedExecutor.get().getKey()),
assignmentTimeout);
sender().tell(matchedExecutor.get().getKey(), self());
} else {
sender().tell(new Status.Failure(new NoResourceAvailableException(
Expand Down Expand Up @@ -350,10 +345,14 @@ private void disconnectTaskExecutor(TaskExecutorID taskExecutorID) {
boolean stateChange = state.onDisconnection();
if (stateChange) {
taskExecutorsReadyToPerformWork.remove(taskExecutorID);
state.setNextHeartbeatChecker(null);
getTimers().cancel(getHeartbeatTimerFor(taskExecutorID));
}
}

private String getHeartbeatTimerFor(TaskExecutorID taskExecutorID) {
return "Heartbeat-" + taskExecutorID.toString();
}

private void onTaskExecutorHeartbeatTimeout(HeartbeatTimeout timeout) {
setupTaskExecutorStateIfNecessary(timeout.getTaskExecutorID());
try {
Expand All @@ -376,17 +375,10 @@ private void setupTaskExecutorStateIfNecessary(TaskExecutorID taskExecutorID) {

private void updateHeartbeatTimeout(TaskExecutorID taskExecutorID) {
final TaskExecutorState state = taskExecutorStateMap.get(taskExecutorID);
final Cancellable nextHeartbeatChecker =
context()
.system()
.scheduler()
.scheduleOnce(
heartbeatTimeout,
self(),
new HeartbeatTimeout(taskExecutorID, state.getLastActivity()),
getContext().getDispatcher(),
self());
state.setNextHeartbeatChecker(nextHeartbeatChecker);
getTimers().startSingleTimer(
getHeartbeatTimerFor(taskExecutorID),
new HeartbeatTimeout(taskExecutorID, state.getLastActivity()),
heartbeatTimeout);
}

@Value
Expand Down Expand Up @@ -494,8 +486,6 @@ enum AvailabilityState {
private AvailabilityState availabilityState;
@Nullable
private WorkerId workerId;
@Nullable
private Cancellable nextHeartbeatChecker;
private Instant lastActivity;
private final Clock clock;
private final RpcService rpcService;
Expand All @@ -507,7 +497,6 @@ static TaskExecutorState of(Clock clock, RpcService rpcService) {
null,
null,
null,
null,
clock.instant(),
clock,
rpcService);
Expand Down Expand Up @@ -689,14 +678,6 @@ private void throwInvalidTransition(WorkerId workerId) throws IllegalStateExcept
this.availabilityState, this.workerId, workerId));
}

private void setNextHeartbeatChecker(@Nullable Cancellable nextHeartbeatChecker) {
if (this.nextHeartbeatChecker != null) {
this.nextHeartbeatChecker.cancel();
}

this.nextHeartbeatChecker = nextHeartbeatChecker;
}

private void updateTicker() {
this.lastActivity = clock.instant();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright 2022 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.mantisrx.master.resourcecluster;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.Props;
import akka.testkit.javadsl.TestKit;
import io.mantisrx.common.Ack;
import io.mantisrx.common.WorkerPorts;
import io.mantisrx.runtime.MachineDefinition;
import io.mantisrx.server.core.TestingRpcService;
import io.mantisrx.server.core.domain.WorkerId;
import io.mantisrx.server.master.persistence.MantisJobStore;
import io.mantisrx.server.master.resourcecluster.ClusterID;
import io.mantisrx.server.master.resourcecluster.ResourceCluster;
import io.mantisrx.server.master.resourcecluster.ResourceClusterTaskExecutorMapper;
import io.mantisrx.server.master.resourcecluster.TaskExecutorHeartbeat;
import io.mantisrx.server.master.resourcecluster.TaskExecutorID;
import io.mantisrx.server.master.resourcecluster.TaskExecutorRegistration;
import io.mantisrx.server.master.resourcecluster.TaskExecutorReport;
import io.mantisrx.server.worker.TaskExecutorGateway;
import io.mantisrx.shaded.com.google.common.collect.ImmutableList;
import java.time.Clock;
import java.time.Duration;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Matchers;

public class ResourceClusterActorTest {
private static final TaskExecutorID TASK_EXECUTOR_ID = TaskExecutorID.of("taskExecutorId");
private static final String TASK_EXECUTOR_ADDRESS = "address";
private static final ClusterID CLUSTER_ID = ClusterID.of("clusterId");
private static final Duration heartbeatTimeout = Duration.ofSeconds(10);
private static final Duration assignmentTimeout = Duration.ofSeconds(1);
private static final String HOST_NAME = "hostname";
private static final WorkerPorts WORKER_PORTS = new WorkerPorts(1, 2, 3, 4, 5);
private static final MachineDefinition MACHINE_DEFINITION =
new MachineDefinition(2f, 2014, 128.0, 1024, 1);
private static final TaskExecutorRegistration TASK_EXECUTOR_REGISTRATION =
new TaskExecutorRegistration(
TASK_EXECUTOR_ID,
CLUSTER_ID,
TASK_EXECUTOR_ADDRESS,
HOST_NAME,
WORKER_PORTS,
MACHINE_DEFINITION);
private static final WorkerId WORKER_ID =
WorkerId.fromIdUnsafe("late-sine-function-tutorial-1-worker-0-1");

static ActorSystem actorSystem;

private final TestingRpcService rpcService = new TestingRpcService();
private final TaskExecutorGateway gateway = mock(TaskExecutorGateway.class);

private MantisJobStore mantisJobStore;
private ResourceClusterTaskExecutorMapper mapper;
private ActorRef resourceClusterActor;
private ResourceCluster resourceCluster;

@BeforeClass
public static void setup() {
actorSystem = ActorSystem.create();
}

@AfterClass
public static void teardown() {
TestKit.shutdownActorSystem(actorSystem);
actorSystem = null;
}

@Before
public void setupRpcService() {
rpcService.registerGateway(TASK_EXECUTOR_ADDRESS, gateway);
mantisJobStore = mock(MantisJobStore.class);
mapper = ResourceClusterTaskExecutorMapper.inMemory();
}

@Before
public void setupActor() {
final Props props =
ResourceClusterActor.props(
CLUSTER_ID,
heartbeatTimeout,
assignmentTimeout,
Clock.systemDefaultZone(),
rpcService,
mantisJobStore);

resourceClusterActor = actorSystem.actorOf(props);
resourceCluster =
new ResourceClusterAkkaImpl(
resourceClusterActor,
Duration.ofSeconds(1),
CLUSTER_ID,
mapper);
}

@Test
public void testRegistration() throws Exception {
assertEquals(Ack.getInstance(), resourceCluster.registerTaskExecutor(TASK_EXECUTOR_REGISTRATION).get());
assertEquals(ImmutableList.of(TASK_EXECUTOR_ID), resourceCluster.getRegisteredTaskExecutors().get());
}

@Test
public void testInitializationAfterRestart() throws Exception {
when(mantisJobStore.getTaskExecutor(Matchers.eq(TASK_EXECUTOR_ID))).thenReturn(TASK_EXECUTOR_REGISTRATION);
assertEquals(
Ack.getInstance(),
resourceCluster.initializeTaskExecutor(TASK_EXECUTOR_ID, WORKER_ID).get());
assertEquals(ImmutableList.of(TASK_EXECUTOR_ID), resourceCluster.getBusyTaskExecutors().get());
}

@Test
public void testGetFreeTaskExecutors() throws Exception {
assertEquals(Ack.getInstance(), resourceCluster.registerTaskExecutor(TASK_EXECUTOR_REGISTRATION).get());
assertEquals(Ack.getInstance(),
resourceCluster
.heartBeatFromTaskExecutor(
new TaskExecutorHeartbeat(
TASK_EXECUTOR_ID,
CLUSTER_ID,
TaskExecutorReport.available())).get());
assertEquals(TASK_EXECUTOR_ID, resourceCluster.getTaskExecutorFor(MACHINE_DEFINITION, WORKER_ID).get());
assertEquals(ImmutableList.of(), resourceCluster.getAvailableTaskExecutors().get());
assertEquals(ImmutableList.of(TASK_EXECUTOR_ID), resourceCluster.getRegisteredTaskExecutors().get());
}

@Test
public void testAssignmentTimeout() throws Exception {
assertEquals(Ack.getInstance(), resourceCluster.registerTaskExecutor(TASK_EXECUTOR_REGISTRATION).get());
assertEquals(Ack.getInstance(),
resourceCluster
.heartBeatFromTaskExecutor(
new TaskExecutorHeartbeat(
TASK_EXECUTOR_ID,
CLUSTER_ID,
TaskExecutorReport.available())).get());
assertEquals(TASK_EXECUTOR_ID, resourceCluster.getTaskExecutorFor(MACHINE_DEFINITION, WORKER_ID).get());
assertEquals(ImmutableList.of(), resourceCluster.getAvailableTaskExecutors().get());
Thread.sleep(2000);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be done using a mock Clock?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I can make Akka schedule something (such as the assignment timeout in this case) based on a mock clock. Let me take a look into this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great if we could. All these Thread.sleep() start to add up over time.

assertEquals(ImmutableList.of(TASK_EXECUTOR_ID), resourceCluster.getAvailableTaskExecutors().get());
assertEquals(TASK_EXECUTOR_ID, resourceCluster.getTaskExecutorFor(MACHINE_DEFINITION, WORKER_ID).get());
}
}