@@ -183,7 +183,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
183183 ReplicationState .INIT_FOLLOW -> {
184184 log.info(" Starting shard tasks" )
185185 addIndexBlockForReplication()
186- startShardFollowTasks(emptyMap())
186+ FollowingState (startNewOrMissingShardTasks())
187+
187188 }
188189 ReplicationState .FOLLOWING -> {
189190 if (currentTaskState is FollowingState ) {
@@ -206,8 +207,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
206207 // Tasks need to be started
207208 state
208209 } else {
209- state = pollShardTaskStatus((followingTaskState as FollowingState ).shardReplicationTasks )
210- followingTaskState = startMissingShardTasks((followingTaskState as FollowingState ).shardReplicationTasks )
210+ state = pollShardTaskStatus()
211+ followingTaskState = FollowingState (startNewOrMissingShardTasks() )
211212 when (state) {
212213 is MonitoringState -> {
213214 updateMetadata()
@@ -285,24 +286,7 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
285286 clusterService.addListener(this )
286287 }
287288
288- private suspend fun startMissingShardTasks (shardTasks : Map <ShardId , PersistentTask <ShardReplicationParams >>): IndexReplicationState {
289- val persistentTasks = clusterService.state().metadata.custom<PersistentTasksCustomMetadata >(PersistentTasksCustomMetadata .TYPE )
290-
291- val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor .TASK_NAME , Predicate { true }).stream()
292- .map { task -> task.params as ShardReplicationParams }
293- .collect(Collectors .toList())
294-
295- val runningTasksForCurrentIndex = shardTasks.filter { entry -> runningShardTasks.find { task -> task.followerShardId == entry.key } != null }
296-
297- val numMissingTasks = shardTasks.size - runningTasksForCurrentIndex.size
298- if (numMissingTasks > 0 ) {
299- log.info(" Starting $numMissingTasks missing shard task(s)" )
300- return startShardFollowTasks(runningTasksForCurrentIndex)
301- }
302- return FollowingState (shardTasks)
303- }
304-
305- private suspend fun pollShardTaskStatus (shardTasks : Map <ShardId , PersistentTask <ShardReplicationParams >>): IndexReplicationState {
289+ private suspend fun pollShardTaskStatus (): IndexReplicationState {
306290 val failedShardTasks = findAllReplicationFailedShardTasks(followerIndexName, clusterService.state())
307291 if (failedShardTasks.isNotEmpty()) {
308292 log.info(" Failed shard tasks - " , failedShardTasks)
@@ -343,11 +327,16 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
343327 registerCloseListeners()
344328 val clusterState = clusterService.state()
345329 val persistentTasks = clusterState.metadata.custom<PersistentTasksCustomMetadata >(PersistentTasksCustomMetadata .TYPE )
346- val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor .TASK_NAME , Predicate { true }).stream()
330+
331+ val followerShardIds = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
332+ .map { shard -> shard.value.shardId }
333+ .stream().collect(Collectors .toSet())
334+ val runningShardTasksForIndex = persistentTasks.findTasks(ShardReplicationExecutor .TASK_NAME , Predicate { true }).stream()
347335 .map { task -> task.params as ShardReplicationParams }
336+ .filter {taskParam -> followerShardIds.contains(taskParam.followerShardId) }
348337 .collect(Collectors .toList())
349338
350- if (runningShardTasks .size == 0 ) {
339+ if (runningShardTasksForIndex .size != followerShardIds.size ) {
351340 return InitFollowState
352341 }
353342
@@ -696,19 +685,27 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
696685 }
697686 }
698687
699- private suspend fun
700- startShardFollowTasks (tasks : Map <ShardId , PersistentTask <ShardReplicationParams >>): FollowingState {
688+ suspend fun startNewOrMissingShardTasks (): Map <ShardId , PersistentTask <ShardReplicationParams >> {
701689 assert (clusterService.state().routingTable.hasIndex(followerIndexName)) { " Can't find index $followerIndexName " }
702690 val shards = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
703- val newTasks = shards.map {
691+ val persistentTasks = clusterService.state().metadata.custom<PersistentTasksCustomMetadata >(PersistentTasksCustomMetadata .TYPE )
692+ val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor .TASK_NAME , Predicate { true }).stream()
693+ .map { task -> task as PersistentTask <ShardReplicationParams > }
694+ .filter { task -> task.params!! .followerShardId.indexName == followerIndexName}
695+ .collect(Collectors .toMap(
696+ {t: PersistentTask <ShardReplicationParams > -> t.params!! .followerShardId},
697+ {t: PersistentTask <ShardReplicationParams > -> t}))
698+
699+ val tasks = shards.map {
704700 it.value.shardId
705701 }.associate { shardId ->
706- val task = tasks .getOrElse(shardId) {
702+ val task = runningShardTasks .getOrElse(shardId) {
707703 startReplicationTask(ShardReplicationParams (leaderAlias, ShardId (leaderIndex, shardId.id), shardId))
708704 }
709705 return @associate shardId to task
710706 }
711- return FollowingState (newTasks)
707+
708+ return tasks
712709 }
713710
714711 private suspend fun cancelRestore () {
0 commit comments