7
7
8
8
import java .io .IOException ;
9
9
import java .util .HashMap ;
10
+ import java .util .HashSet ;
10
11
import java .util .List ;
11
12
import java .util .Map ;
13
+ import java .util .Set ;
14
+ import java .util .concurrent .CountDownLatch ;
15
+ import java .util .concurrent .TimeUnit ;
12
16
13
17
import org .opensearch .action .FailedNodeException ;
18
+ import org .opensearch .action .search .SearchRequest ;
19
+ import org .opensearch .action .search .SearchResponse ;
14
20
import org .opensearch .action .support .ActionFilters ;
15
21
import org .opensearch .action .support .nodes .TransportNodesAction ;
22
+ import org .opensearch .client .Client ;
16
23
import org .opensearch .cluster .service .ClusterService ;
17
24
import org .opensearch .common .inject .Inject ;
25
+ import org .opensearch .common .util .concurrent .ThreadContext ;
26
+ import org .opensearch .core .action .ActionListener ;
18
27
import org .opensearch .core .common .io .stream .StreamInput ;
19
28
import org .opensearch .env .Environment ;
29
+ import org .opensearch .index .query .BoolQueryBuilder ;
30
+ import org .opensearch .index .query .QueryBuilders ;
31
+ import org .opensearch .ml .common .CommonValue ;
20
32
import org .opensearch .ml .common .FunctionName ;
33
+ import org .opensearch .ml .common .MLModel ;
34
+ import org .opensearch .ml .model .MLModelManager ;
21
35
import org .opensearch .ml .stats .ActionName ;
22
36
import org .opensearch .ml .stats .MLActionStats ;
23
37
import org .opensearch .ml .stats .MLAlgoStats ;
26
40
import org .opensearch .ml .stats .MLStatLevel ;
27
41
import org .opensearch .ml .stats .MLStats ;
28
42
import org .opensearch .ml .stats .MLStatsInput ;
43
+ import org .opensearch .ml .utils .RestActionUtils ;
29
44
import org .opensearch .monitor .jvm .JvmService ;
45
+ import org .opensearch .search .SearchHit ;
30
46
import org .opensearch .threadpool .ThreadPool ;
31
47
import org .opensearch .transport .TransportService ;
32
48
49
+ import com .google .common .annotations .VisibleForTesting ;
50
+
51
+ import lombok .extern .log4j .Log4j2 ;
52
+
53
+ @ Log4j2
33
54
public class MLStatsNodesTransportAction extends
34
55
TransportNodesAction <MLStatsNodesRequest , MLStatsNodesResponse , MLStatsNodeRequest , MLStatsNodeResponse > {
35
56
private MLStats mlStats ;
36
57
private final JvmService jvmService ;
37
58
59
+ private final Client client ;
60
+
61
+ private final MLModelManager mlModelManager ;
62
+
38
63
/**
39
64
* Constructor
40
65
*
@@ -52,7 +77,9 @@ public MLStatsNodesTransportAction(
52
77
TransportService transportService ,
53
78
ActionFilters actionFilters ,
54
79
MLStats mlStats ,
55
- Environment environment
80
+ Environment environment ,
81
+ Client client ,
82
+ MLModelManager mlModelManager
56
83
) {
57
84
super (
58
85
MLStatsNodesAction .NAME ,
@@ -67,6 +94,8 @@ public MLStatsNodesTransportAction(
67
94
);
68
95
this .mlStats = mlStats ;
69
96
this .jvmService = new JvmService (environment .settings ());
97
+ this .client = client ;
98
+ this .mlModelManager = mlModelManager ;
70
99
}
71
100
72
101
@ Override
@@ -127,21 +156,88 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe
127
156
}
128
157
129
158
Map <String , MLModelStats > modelStats = new HashMap <>();
130
- // return model level stats
131
159
if (mlStatsInput .includeModelStats ()) {
132
- for (String modelId : mlStats .getAllModels ()) {
133
- if (mlStatsInput .retrieveStatsForModel (modelId )) {
134
- Map <ActionName , MLActionStats > actionStatsMap = new HashMap <>();
135
- for (Map .Entry <ActionName , MLActionStats > entry : mlStats .getModelStats (modelId ).entrySet ()) {
136
- if (mlStatsInput .retrieveStatsForAction (entry .getKey ())) {
137
- actionStatsMap .put (entry .getKey (), entry .getValue ());
160
+ CountDownLatch latch = new CountDownLatch (1 );
161
+ boolean isSuperAdmin = isSuperAdminUserWrapper (clusterService , client );
162
+ searchHiddenModels (ActionListener .wrap (hiddenModels -> {
163
+ for (String modelId : mlStats .getAllModels ()) {
164
+ if (isSuperAdmin || !hiddenModels .contains (modelId )) {
165
+ if (mlStatsInput .retrieveStatsForModel (modelId )) {
166
+ Map <ActionName , MLActionStats > actionStatsMap = new HashMap <>();
167
+ for (Map .Entry <ActionName , MLActionStats > entry : mlStats .getModelStats (modelId ).entrySet ()) {
168
+ if (mlStatsInput .retrieveStatsForAction (entry .getKey ())) {
169
+ actionStatsMap .put (entry .getKey (), entry .getValue ());
170
+ }
171
+ }
172
+ modelStats .put (modelId , new MLModelStats (actionStatsMap ));
138
173
}
139
174
}
140
- modelStats .put (modelId , new MLModelStats (actionStatsMap ));
141
175
}
176
+ }, e -> { log .error ("Search Hidden model wasn't successful" ); }), latch );
177
+ // Wait for the asynchronous call to complete
178
+ try {
179
+ latch .await (10 , TimeUnit .SECONDS );
180
+ } catch (InterruptedException e ) {
181
+ // Handle interruption if necessary
182
+ Thread .currentThread ().interrupt ();
142
183
}
143
184
}
144
-
145
185
return new MLStatsNodeResponse (clusterService .localNode (), statValues , algorithmStats , modelStats );
146
186
}
187
+
188
+ @ VisibleForTesting
189
+ void searchHiddenModels (ActionListener <Set <String >> listener , CountDownLatch latch ) {
190
+ SearchRequest searchRequest = buildHiddenModelSearchRequest ();
191
+ // Use a try-with-resources block to ensure resources are properly released
192
+ try (ThreadContext .StoredContext threadContext = client .threadPool ().getThreadContext ().stashContext ()) {
193
+ // Wrap the listener to restore thread context before calling it
194
+ ActionListener <Set <String >> internalListener = ActionListener .runAfter (listener , () -> {
195
+ latch .countDown ();
196
+ threadContext .restore ();
197
+ });
198
+ // Wrap the search response handler to handle success and failure cases
199
+ // Notify the listener of any search failures
200
+ ActionListener <SearchResponse > al = ActionListener .wrap (response -> {
201
+ // Initialize the result set
202
+ Set <String > result = new HashSet <>(response .getHits ().getHits ().length ); // Set initial capacity to the number of hits
203
+
204
+ // Iterate over the search hits and add their IDs to the result set
205
+ for (SearchHit hit : response .getHits ()) {
206
+ result .add (hit .getId ());
207
+ }
208
+ // Notify the listener of the search results
209
+ internalListener .onResponse (result );
210
+ }, internalListener ::onFailure );
211
+
212
+ // Execute the search request asynchronously
213
+ client .search (searchRequest , al );
214
+ } catch (Exception e ) {
215
+ // Notify the listener of any unexpected errors
216
+ listener .onFailure (e );
217
+ }
218
+ }
219
+
220
+ private SearchRequest buildHiddenModelSearchRequest () {
221
+ SearchRequest searchRequest = new SearchRequest (CommonValue .ML_MODEL_INDEX );
222
+ // Build the query
223
+ BoolQueryBuilder boolQueryBuilder = QueryBuilders .boolQuery ();
224
+ boolQueryBuilder
225
+ .filter (
226
+ QueryBuilders
227
+ .boolQuery ()
228
+ .must (QueryBuilders .termQuery (MLModel .IS_HIDDEN_FIELD , true ))
229
+ // Add the additional filter to exclude documents where "chunk_number" exists
230
+ .mustNot (QueryBuilders .existsQuery ("chunk_number" ))
231
+ );
232
+ searchRequest .source ().query (boolQueryBuilder );
233
+ // Specify the fields to include in the search results (only the "_id" field)
234
+ // No fields to exclude
235
+ searchRequest .source ().fetchSource (new String [] { "_id" }, new String [] {});
236
+ return searchRequest ;
237
+ }
238
+
239
+ @ VisibleForTesting
240
+ boolean isSuperAdminUserWrapper (ClusterService clusterService , Client client ) {
241
+ return RestActionUtils .isSuperAdminUser (clusterService , client );
242
+ }
147
243
}
0 commit comments