Skip to content

Commit 3cc2fca

Browse files
filtering stats api for hidden model (#2307) (#2328)
* filtering stats api for hidden model Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressing comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * reverting to excludes instead of includes Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressing comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * spotlessApply Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * renamed method name Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressing comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * spotlessApply Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * fixing a test Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> (cherry picked from commit a83a78f) Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 9c61432 commit 3cc2fca

File tree

2 files changed

+291
-38
lines changed

2 files changed

+291
-38
lines changed

plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java

Lines changed: 106 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,31 @@
77

88
import java.io.IOException;
99
import java.util.HashMap;
10+
import java.util.HashSet;
1011
import java.util.List;
1112
import java.util.Map;
13+
import java.util.Set;
14+
import java.util.concurrent.CountDownLatch;
15+
import java.util.concurrent.TimeUnit;
1216

1317
import org.opensearch.action.FailedNodeException;
18+
import org.opensearch.action.search.SearchRequest;
19+
import org.opensearch.action.search.SearchResponse;
1420
import org.opensearch.action.support.ActionFilters;
1521
import org.opensearch.action.support.nodes.TransportNodesAction;
22+
import org.opensearch.client.Client;
1623
import org.opensearch.cluster.service.ClusterService;
1724
import org.opensearch.common.inject.Inject;
25+
import org.opensearch.common.util.concurrent.ThreadContext;
26+
import org.opensearch.core.action.ActionListener;
1827
import org.opensearch.core.common.io.stream.StreamInput;
1928
import org.opensearch.env.Environment;
29+
import org.opensearch.index.query.BoolQueryBuilder;
30+
import org.opensearch.index.query.QueryBuilders;
31+
import org.opensearch.ml.common.CommonValue;
2032
import org.opensearch.ml.common.FunctionName;
33+
import org.opensearch.ml.common.MLModel;
34+
import org.opensearch.ml.model.MLModelManager;
2135
import org.opensearch.ml.stats.ActionName;
2236
import org.opensearch.ml.stats.MLActionStats;
2337
import org.opensearch.ml.stats.MLAlgoStats;
@@ -26,15 +40,26 @@
2640
import org.opensearch.ml.stats.MLStatLevel;
2741
import org.opensearch.ml.stats.MLStats;
2842
import org.opensearch.ml.stats.MLStatsInput;
43+
import org.opensearch.ml.utils.RestActionUtils;
2944
import org.opensearch.monitor.jvm.JvmService;
45+
import org.opensearch.search.SearchHit;
3046
import org.opensearch.threadpool.ThreadPool;
3147
import org.opensearch.transport.TransportService;
3248

49+
import com.google.common.annotations.VisibleForTesting;
50+
51+
import lombok.extern.log4j.Log4j2;
52+
53+
@Log4j2
3354
public class MLStatsNodesTransportAction extends
3455
TransportNodesAction<MLStatsNodesRequest, MLStatsNodesResponse, MLStatsNodeRequest, MLStatsNodeResponse> {
3556
private MLStats mlStats;
3657
private final JvmService jvmService;
3758

59+
private final Client client;
60+
61+
private final MLModelManager mlModelManager;
62+
3863
/**
3964
* Constructor
4065
*
@@ -52,7 +77,9 @@ public MLStatsNodesTransportAction(
5277
TransportService transportService,
5378
ActionFilters actionFilters,
5479
MLStats mlStats,
55-
Environment environment
80+
Environment environment,
81+
Client client,
82+
MLModelManager mlModelManager
5683
) {
5784
super(
5885
MLStatsNodesAction.NAME,
@@ -67,6 +94,8 @@ public MLStatsNodesTransportAction(
6794
);
6895
this.mlStats = mlStats;
6996
this.jvmService = new JvmService(environment.settings());
97+
this.client = client;
98+
this.mlModelManager = mlModelManager;
7099
}
71100

72101
@Override
@@ -127,21 +156,88 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe
127156
}
128157

129158
Map<String, MLModelStats> modelStats = new HashMap<>();
130-
// return model level stats
131159
if (mlStatsInput.includeModelStats()) {
132-
for (String modelId : mlStats.getAllModels()) {
133-
if (mlStatsInput.retrieveStatsForModel(modelId)) {
134-
Map<ActionName, MLActionStats> actionStatsMap = new HashMap<>();
135-
for (Map.Entry<ActionName, MLActionStats> entry : mlStats.getModelStats(modelId).entrySet()) {
136-
if (mlStatsInput.retrieveStatsForAction(entry.getKey())) {
137-
actionStatsMap.put(entry.getKey(), entry.getValue());
160+
CountDownLatch latch = new CountDownLatch(1);
161+
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
162+
searchHiddenModels(ActionListener.wrap(hiddenModels -> {
163+
for (String modelId : mlStats.getAllModels()) {
164+
if (isSuperAdmin || !hiddenModels.contains(modelId)) {
165+
if (mlStatsInput.retrieveStatsForModel(modelId)) {
166+
Map<ActionName, MLActionStats> actionStatsMap = new HashMap<>();
167+
for (Map.Entry<ActionName, MLActionStats> entry : mlStats.getModelStats(modelId).entrySet()) {
168+
if (mlStatsInput.retrieveStatsForAction(entry.getKey())) {
169+
actionStatsMap.put(entry.getKey(), entry.getValue());
170+
}
171+
}
172+
modelStats.put(modelId, new MLModelStats(actionStatsMap));
138173
}
139174
}
140-
modelStats.put(modelId, new MLModelStats(actionStatsMap));
141175
}
176+
}, e -> { log.error("Search Hidden model wasn't successful"); }), latch);
177+
// Wait for the asynchronous call to complete
178+
try {
179+
latch.await(10, TimeUnit.SECONDS);
180+
} catch (InterruptedException e) {
181+
// Handle interruption if necessary
182+
Thread.currentThread().interrupt();
142183
}
143184
}
144-
145185
return new MLStatsNodeResponse(clusterService.localNode(), statValues, algorithmStats, modelStats);
146186
}
187+
188+
@VisibleForTesting
189+
void searchHiddenModels(ActionListener<Set<String>> listener, CountDownLatch latch) {
190+
SearchRequest searchRequest = buildHiddenModelSearchRequest();
191+
// Use a try-with-resources block to ensure resources are properly released
192+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
193+
// Wrap the listener to restore thread context before calling it
194+
ActionListener<Set<String>> internalListener = ActionListener.runAfter(listener, () -> {
195+
latch.countDown();
196+
threadContext.restore();
197+
});
198+
// Wrap the search response handler to handle success and failure cases
199+
// Notify the listener of any search failures
200+
ActionListener<SearchResponse> al = ActionListener.wrap(response -> {
201+
// Initialize the result set
202+
Set<String> result = new HashSet<>(response.getHits().getHits().length); // Set initial capacity to the number of hits
203+
204+
// Iterate over the search hits and add their IDs to the result set
205+
for (SearchHit hit : response.getHits()) {
206+
result.add(hit.getId());
207+
}
208+
// Notify the listener of the search results
209+
internalListener.onResponse(result);
210+
}, internalListener::onFailure);
211+
212+
// Execute the search request asynchronously
213+
client.search(searchRequest, al);
214+
} catch (Exception e) {
215+
// Notify the listener of any unexpected errors
216+
listener.onFailure(e);
217+
}
218+
}
219+
220+
private SearchRequest buildHiddenModelSearchRequest() {
221+
SearchRequest searchRequest = new SearchRequest(CommonValue.ML_MODEL_INDEX);
222+
// Build the query
223+
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
224+
boolQueryBuilder
225+
.filter(
226+
QueryBuilders
227+
.boolQuery()
228+
.must(QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, true))
229+
// Add the additional filter to exclude documents where "chunk_number" exists
230+
.mustNot(QueryBuilders.existsQuery("chunk_number"))
231+
);
232+
searchRequest.source().query(boolQueryBuilder);
233+
// Specify the fields to include in the search results (only the "_id" field)
234+
// No fields to exclude
235+
searchRequest.source().fetchSource(new String[] { "_id" }, new String[] {});
236+
return searchRequest;
237+
}
238+
239+
@VisibleForTesting
240+
boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
241+
return RestActionUtils.isSuperAdminUser(clusterService, client);
242+
}
147243
}

0 commit comments

Comments
 (0)