1616import org .elasticsearch .client .internal .Client ;
1717import org .elasticsearch .client .internal .OriginSettingClient ;
1818import org .elasticsearch .client .internal .ParentTaskAssigningClient ;
19+ import org .elasticsearch .cluster .ClusterName ;
1920import org .elasticsearch .cluster .ClusterState ;
2021import org .elasticsearch .cluster .block .ClusterBlockException ;
2122import org .elasticsearch .cluster .block .ClusterBlockLevel ;
2223import org .elasticsearch .cluster .node .DiscoveryNode ;
2324import org .elasticsearch .cluster .node .DiscoveryNodeRole ;
25+ import org .elasticsearch .cluster .project .ProjectResolver ;
2426import org .elasticsearch .cluster .service .ClusterService ;
2527import org .elasticsearch .common .settings .ClusterSettings ;
2628import org .elasticsearch .common .unit .ByteSizeValue ;
3537import org .elasticsearch .xpack .core .ml .action .MlMemoryAction .Response .MlMemoryStats ;
3638import org .elasticsearch .xpack .core .ml .action .TrainedModelCacheInfoAction ;
3739import org .elasticsearch .xpack .core .ml .action .TrainedModelCacheInfoAction .Response .CacheInfo ;
40+ import org .elasticsearch .xpack .core .ml .inference .assignment .TrainedModelAssignmentMetadata ;
3841import org .elasticsearch .xpack .ml .job .NodeLoad ;
3942import org .elasticsearch .xpack .ml .job .NodeLoadDetector ;
4043import org .elasticsearch .xpack .ml .process .MlMemoryTracker ;
@@ -53,6 +56,7 @@ public class TransportMlMemoryAction extends TransportMasterNodeAction<MlMemoryA
5356
5457 private final Client client ;
5558 private final MlMemoryTracker memoryTracker ;
59+ private final ProjectResolver projectResolver ;
5660
5761 @ Inject
5862 public TransportMlMemoryAction (
@@ -61,7 +65,8 @@ public TransportMlMemoryAction(
6165 ThreadPool threadPool ,
6266 ActionFilters actionFilters ,
6367 Client client ,
64- MlMemoryTracker memoryTracker
68+ MlMemoryTracker memoryTracker ,
69+ ProjectResolver projectResolver
6570 ) {
6671 super (
6772 MlMemoryAction .NAME ,
@@ -75,6 +80,7 @@ public TransportMlMemoryAction(
7580 );
7681 this .client = new OriginSettingClient (client , ML_ORIGIN );
7782 this .memoryTracker = memoryTracker ;
83+ this .projectResolver = projectResolver ;
7884 }
7985
8086 @ Override
@@ -87,6 +93,10 @@ protected void masterOperation(
8793
8894 ClusterSettings clusterSettings = clusterService .getClusterSettings ();
8995
96+ var clusterName = state .getClusterName ();
97+ var projectMetadata = state .projectState (projectResolver .getProjectId ()).metadata ();
98+ var trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata .fromMetadata (projectMetadata );
99+ PersistentTasksCustomMetadata persistentTasksCustomMetadata = projectMetadata .custom (PersistentTasksCustomMetadata .TYPE );
90100 // Resolve the node specification to some concrete nodes
91101 String [] nodeIds = state .nodes ().resolveNodes (request .getNodeId ());
92102
@@ -112,7 +122,9 @@ protected void masterOperation(
112122 trainedModelCacheInfoRequest ,
113123 delegate2 .delegateFailureAndWrap (
114124 (l , trainedModelCacheInfoResponse ) -> handleResponses (
115- state ,
125+ clusterName ,
126+ persistentTasksCustomMetadata ,
127+ trainedModelAssignmentMetadata ,
116128 clusterSettings ,
117129 nodesStatsResponse ,
118130 trainedModelCacheInfoResponse ,
@@ -127,15 +139,14 @@ protected void masterOperation(
127139 if (memoryTracker .isEverRefreshed ()) {
128140 memoryTrackerRefreshListener .onResponse (null );
129141 } else {
130- memoryTracker .refresh (
131- state .getMetadata ().getProject ().custom (PersistentTasksCustomMetadata .TYPE ),
132- memoryTrackerRefreshListener
133- );
142+ memoryTracker .refresh (persistentTasksCustomMetadata , memoryTrackerRefreshListener );
134143 }
135144 }
136145
137146 void handleResponses (
138- ClusterState state ,
147+ ClusterName clusterName ,
148+ PersistentTasksCustomMetadata persistentTasks ,
149+ TrainedModelAssignmentMetadata assignmentMetadata ,
139150 ClusterSettings clusterSettings ,
140151 NodesStatsResponse nodesStatsResponse ,
141152 TrainedModelCacheInfoAction .Response trainedModelCacheInfoResponse ,
@@ -173,7 +184,8 @@ void handleResponses(
173184 ByteSizeValue mlNativeInference ;
174185 if (node .getRoles ().contains (DiscoveryNodeRole .ML_ROLE )) {
175186 NodeLoad nodeLoad = nodeLoadDetector .detectNodeLoad (
176- state ,
187+ persistentTasks ,
188+ assignmentMetadata ,
177189 node ,
178190 maxOpenJobsPerNode ,
179191 maxMachineMemoryPercent ,
@@ -219,7 +231,7 @@ void handleResponses(
219231 );
220232 }
221233
222- listener .onResponse (new MlMemoryAction .Response (state . getClusterName () , nodeResponses , failures ));
234+ listener .onResponse (new MlMemoryAction .Response (clusterName , nodeResponses , failures ));
223235 }
224236
225237 @ Override
0 commit comments