Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
ReplicationState.INIT_FOLLOW -> {
log.info("Starting shard tasks")
addIndexBlockForReplication()
startShardFollowTasks(emptyMap())
FollowingState(startNewOrMissingShardTasks())

}
ReplicationState.FOLLOWING -> {
if (currentTaskState is FollowingState) {
Expand All @@ -206,8 +207,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
// Tasks need to be started
state
} else {
state = pollShardTaskStatus((followingTaskState as FollowingState).shardReplicationTasks)
followingTaskState = startMissingShardTasks((followingTaskState as FollowingState).shardReplicationTasks)
state = pollShardTaskStatus()
followingTaskState = FollowingState(startNewOrMissingShardTasks())
when (state) {
is MonitoringState -> {
updateMetadata()
Expand Down Expand Up @@ -284,24 +285,7 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
clusterService.addListener(this)
}

private suspend fun startMissingShardTasks(shardTasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): IndexReplicationState {
val persistentTasks = clusterService.state().metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)

val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
.map { task -> task.params as ShardReplicationParams }
.collect(Collectors.toList())

val runningTasksForCurrentIndex = shardTasks.filter { entry -> runningShardTasks.find { task -> task.followerShardId == entry.key } != null}

val numMissingTasks = shardTasks.size - runningTasksForCurrentIndex.size
if (numMissingTasks > 0) {
log.info("Starting $numMissingTasks missing shard task(s)")
return startShardFollowTasks(runningTasksForCurrentIndex)
}
return FollowingState(shardTasks)
}

private suspend fun pollShardTaskStatus(shardTasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): IndexReplicationState {
private suspend fun pollShardTaskStatus(): IndexReplicationState {
val failedShardTasks = findAllReplicationFailedShardTasks(followerIndexName, clusterService.state())
if (failedShardTasks.isNotEmpty()) {
log.info("Failed shard tasks - ", failedShardTasks)
Expand Down Expand Up @@ -342,11 +326,16 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
registerCloseListeners()
val clusterState = clusterService.state()
val persistentTasks = clusterState.metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)
val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()

val followerShardIds = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
.map { shard -> shard.value.shardId }
.stream().collect(Collectors.toSet())
val runningShardTasksForIndex = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
.map { task -> task.params as ShardReplicationParams }
.filter {taskParam -> followerShardIds.contains(taskParam.followerShardId) }
.collect(Collectors.toList())

if (runningShardTasks.size == 0) {
if (runningShardTasksForIndex.size != followerShardIds.size) {
return InitFollowState
}

Expand Down Expand Up @@ -690,19 +679,27 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
}
}

private suspend fun
startShardFollowTasks(tasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): FollowingState {
suspend fun startNewOrMissingShardTasks(): Map<ShardId, PersistentTask<ShardReplicationParams>> {
assert(clusterService.state().routingTable.hasIndex(followerIndexName)) { "Can't find index $followerIndexName" }
val shards = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
val newTasks = shards.map {
val persistentTasks = clusterService.state().metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)
val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
.map { task -> task as PersistentTask<ShardReplicationParams> }
.filter { task -> task.params!!.followerShardId.indexName == followerIndexName}
.collect(Collectors.toMap(
{t: PersistentTask<ShardReplicationParams> -> t.params!!.followerShardId},
{t: PersistentTask<ShardReplicationParams> -> t}))

val tasks = shards.map {
it.value.shardId
}.associate { shardId ->
val task = tasks.getOrElse(shardId) {
val task = runningShardTasks.getOrElse(shardId) {
startReplicationTask(ShardReplicationParams(leaderAlias, ShardId(leaderIndex, shardId.id), shardId))
}
return@associate shardId to task
}
return FollowingState(newTasks)

return tasks
}

private suspend fun cancelRestore() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ import org.opensearch.tasks.TaskManager
import org.opensearch.test.ClusterServiceUtils
import org.opensearch.test.ClusterServiceUtils.setState
import org.opensearch.test.OpenSearchTestCase
import org.opensearch.test.OpenSearchTestCase.assertBusy
import org.opensearch.threadpool.TestThreadPool
import java.util.*
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -150,6 +149,66 @@ class IndexReplicationTaskTests : OpenSearchTestCase() {

}

fun testStartNewShardTasks() = runBlocking {
val replicationTask: IndexReplicationTask = spy(createIndexReplicationTask())
var taskManager = Mockito.mock(TaskManager::class.java)
replicationTask.setPersistent(taskManager)
var rc = ReplicationContext(followerIndex)
var rm = ReplicationMetadata(connectionName, ReplicationStoreMetadataType.INDEX.name, ReplicationOverallState.RUNNING.name, "reason", rc, rc, Settings.EMPTY)
replicationTask.setReplicationMetadata(rm)

// Build cluster state
val indices: MutableList<String> = ArrayList()
indices.add(followerIndex)
var metadata = Metadata.builder()
.put(IndexMetadata.builder(REPLICATION_CONFIG_SYSTEM_INDEX).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
.put(IndexMetadata.builder(followerIndex).settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0))
.build()
var routingTableBuilder = RoutingTable.builder()
.addAsNew(metadata.index(REPLICATION_CONFIG_SYSTEM_INDEX))
.addAsNew(metadata.index(followerIndex))
var newClusterState = ClusterState.builder(clusterService.state()).routingTable(routingTableBuilder.build()).build()
setState(clusterService, newClusterState)

// Try starting shard tasks
val shardTasks = replicationTask.startNewOrMissingShardTasks()
assert(shardTasks.size == 2)
}


fun testStartMissingShardTasks() = runBlocking {
val replicationTask: IndexReplicationTask = spy(createIndexReplicationTask())
var taskManager = Mockito.mock(TaskManager::class.java)
replicationTask.setPersistent(taskManager)
var rc = ReplicationContext(followerIndex)
var rm = ReplicationMetadata(connectionName, ReplicationStoreMetadataType.INDEX.name, ReplicationOverallState.RUNNING.name, "reason", rc, rc, Settings.EMPTY)
replicationTask.setReplicationMetadata(rm)

// Build cluster state
val indices: MutableList<String> = ArrayList()
indices.add(followerIndex)

val tasks = PersistentTasksCustomMetadata.builder()
var sId = ShardId(Index(followerIndex, "_na_"), 0)
tasks.addTask<PersistentTaskParams>( "replication:0", ShardReplicationExecutor.TASK_NAME, ShardReplicationParams("remoteCluster", sId, sId),
PersistentTasksCustomMetadata.Assignment("other_node_", "test assignment on other node"))

var metadata = Metadata.builder()
.put(IndexMetadata.builder(REPLICATION_CONFIG_SYSTEM_INDEX).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
.put(IndexMetadata.builder(followerIndex).settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0))
.putCustom(PersistentTasksCustomMetadata.TYPE, tasks.build())
.build()
var routingTableBuilder = RoutingTable.builder()
.addAsNew(metadata.index(REPLICATION_CONFIG_SYSTEM_INDEX))
.addAsNew(metadata.index(followerIndex))
var newClusterState = ClusterState.builder(clusterService.state()).routingTable(routingTableBuilder.build()).build()
setState(clusterService, newClusterState)

// Try starting shard tasks
val shardTasks = replicationTask.startNewOrMissingShardTasks()
assert(shardTasks.size == 2)
}

private fun createIndexReplicationTask() : IndexReplicationTask {
var threadPool = TestThreadPool("IndexReplicationTask")
//Hack Alert : Though it is meant to force rejection , this is to make overallTaskScope not null
Expand Down