@@ -183,7 +183,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
183183                    ReplicationState .INIT_FOLLOW  ->  {
184184                        log.info(" Starting shard tasks" 
185185                        addIndexBlockForReplication()
186-                         startShardFollowTasks(emptyMap())
186+                         FollowingState (startNewOrMissingShardTasks())
187+ 
187188                    }
188189                    ReplicationState .FOLLOWING  ->  {
189190                        if  (currentTaskState is  FollowingState ) {
@@ -206,8 +207,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
206207                            //  Tasks need to be started
207208                            state
208209                        } else  {
209-                             state =  pollShardTaskStatus((followingTaskState  as   FollowingState ).shardReplicationTasks )
210-                             followingTaskState =  startMissingShardTasks((followingTaskState  as   FollowingState ).shardReplicationTasks )
210+                             state =  pollShardTaskStatus()
211+                             followingTaskState =  FollowingState (startNewOrMissingShardTasks() )
211212                            when  (state) {
212213                                is  MonitoringState  ->  {
213214                                    updateMetadata()
@@ -285,24 +286,7 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
285286        clusterService.addListener(this )
286287    }
287288
288-     private  suspend  fun  startMissingShardTasks (shardTasks :  Map <ShardId , PersistentTask <ShardReplicationParams >>): IndexReplicationState  {
289-         val  persistentTasks =  clusterService.state().metadata.custom<PersistentTasksCustomMetadata >(PersistentTasksCustomMetadata .TYPE )
290- 
291-         val  runningShardTasks =  persistentTasks.findTasks(ShardReplicationExecutor .TASK_NAME , Predicate  { true  }).stream()
292-                 .map { task ->  task.params as  ShardReplicationParams  }
293-                 .collect(Collectors .toList())
294- 
295-         val  runningTasksForCurrentIndex =  shardTasks.filter { entry ->  runningShardTasks.find { task ->  task.followerShardId ==  entry.key } !=  null }
296- 
297-         val  numMissingTasks =  shardTasks.size -  runningTasksForCurrentIndex.size
298-         if  (numMissingTasks >  0 ) {
299-             log.info(" Starting $numMissingTasks  missing shard task(s)" 
300-             return  startShardFollowTasks(runningTasksForCurrentIndex)
301-         }
302-         return  FollowingState (shardTasks)
303-     }
304- 
305-     private  suspend  fun  pollShardTaskStatus (shardTasks :  Map <ShardId , PersistentTask <ShardReplicationParams >>): IndexReplicationState  {
289+     private  suspend  fun  pollShardTaskStatus (): IndexReplicationState  {
306290        val  failedShardTasks =  findAllReplicationFailedShardTasks(followerIndexName, clusterService.state())
307291        if  (failedShardTasks.isNotEmpty()) {
308292            log.info(" Failed shard tasks - " 
@@ -343,11 +327,16 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
343327        registerCloseListeners()
344328        val  clusterState =  clusterService.state()
345329        val  persistentTasks =  clusterState.metadata.custom<PersistentTasksCustomMetadata >(PersistentTasksCustomMetadata .TYPE )
346-         val  runningShardTasks =  persistentTasks.findTasks(ShardReplicationExecutor .TASK_NAME , Predicate  { true  }).stream()
330+ 
331+         val  followerShardIds =  clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
332+             .map {  shard ->  shard.value.shardId }
333+             .stream().collect(Collectors .toSet())
334+         val  runningShardTasksForIndex =  persistentTasks.findTasks(ShardReplicationExecutor .TASK_NAME , Predicate  { true  }).stream()
347335                .map { task ->  task.params as  ShardReplicationParams  }
336+                 .filter {taskParam ->  followerShardIds.contains(taskParam.followerShardId) }
348337                .collect(Collectors .toList())
349338
350-         if  (runningShardTasks .size ==   0 ) {
339+         if  (runningShardTasksForIndex .size !=  followerShardIds.size ) {
351340            return  InitFollowState 
352341        }
353342
@@ -696,19 +685,27 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
696685        }
697686    }
698687
699-     private  suspend  fun 
700-             startShardFollowTasks (tasks :  Map <ShardId , PersistentTask <ShardReplicationParams >>): FollowingState  {
688+     suspend  fun  startNewOrMissingShardTasks ():  Map <ShardId , PersistentTask <ShardReplicationParams >> {
701689        assert (clusterService.state().routingTable.hasIndex(followerIndexName)) { " Can't find index $followerIndexName " 
702690        val  shards =  clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
703-         val  newTasks =  shards.map {
691+         val  persistentTasks =  clusterService.state().metadata.custom<PersistentTasksCustomMetadata >(PersistentTasksCustomMetadata .TYPE )
692+         val  runningShardTasks =  persistentTasks.findTasks(ShardReplicationExecutor .TASK_NAME , Predicate  { true  }).stream()
693+             .map { task ->  task as  PersistentTask <ShardReplicationParams > }
694+             .filter { task ->  task.params!! .followerShardId.indexName  ==  followerIndexName}
695+             .collect(Collectors .toMap(
696+                 {t:  PersistentTask <ShardReplicationParams > ->  t.params!! .followerShardId},
697+                 {t:  PersistentTask <ShardReplicationParams > ->  t}))
698+ 
699+         val  tasks =  shards.map {
704700            it.value.shardId
705701        }.associate { shardId -> 
706-             val  task =  tasks .getOrElse(shardId) {
702+             val  task =  runningShardTasks .getOrElse(shardId) {
707703                startReplicationTask(ShardReplicationParams (leaderAlias, ShardId (leaderIndex, shardId.id), shardId))
708704            }
709705            return @associate shardId to task
710706        }
711-         return  FollowingState (newTasks)
707+ 
708+         return  tasks
712709    }
713710
714711    private  suspend  fun  cancelRestore () {
0 commit comments