Skip to content

Commit b8da642

Browse files
authored
Refactor checks to ML Indices to return true when MultiTenancy enabled (opensearch-project#4089)
* refactor checks to ml indices to return true when MultiTenancy enabled Signed-off-by: Brian Flores <iflorbri@amazon.com> * add JavaDoc to doesMultiTenantIndexExists Signed-off-by: Brian Flores <iflorbri@amazon.com> * updates naming & adds UTs to doesMultiTenantIndexExist Signed-off-by: Brian Flores <iflorbri@amazon.com> * assert MLIndicesHandler has non-null MLFeatureEnabledSettingObject Signed-off-by: Brian Flores <iflorbri@amazon.com> * update JavaDoc with better grammar Signed-off-by: Brian Flores <iflorbri@amazon.com> * apply spotless Signed-off-by: Brian Flores <iflorbri@amazon.com> --------- Signed-off-by: Brian Flores <iflorbri@amazon.com>
1 parent b6952b9 commit b8da642

File tree

9 files changed

+105
-14
lines changed

9 files changed

+105
-14
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
import org.opensearch.ml.engine.Executable;
6969
import org.opensearch.ml.engine.annotation.Function;
7070
import org.opensearch.ml.engine.encryptor.Encryptor;
71+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
7172
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
7273
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
7374
import org.opensearch.ml.engine.tools.QueryPlanningTool;
@@ -173,7 +174,7 @@ public void execute(Input input, ActionListener<Output> listener) {
173174
.fetchSourceContext(fetchSourceContext)
174175
.build();
175176

176-
if (clusterService.state().metadata().hasIndex(ML_AGENT_INDEX)) {
177+
if (MLIndicesHandler.doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_AGENT_INDEX)) {
177178
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
178179
sdkClient
179180
.getDataObjectAsync(getDataObjectRequest, client.threadPool().executor("opensearch_ml_general"))

ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,13 @@
3030
import org.opensearch.ml.common.CommonValue;
3131
import org.opensearch.ml.common.MLIndex;
3232
import org.opensearch.ml.common.exception.MLException;
33+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3334
import org.opensearch.transport.client.Client;
3435

36+
import com.google.common.annotations.VisibleForTesting;
37+
3538
import lombok.AccessLevel;
39+
import lombok.NonNull;
3640
import lombok.RequiredArgsConstructor;
3741
import lombok.experimental.FieldDefaults;
3842
import lombok.extern.log4j.Log4j2;
@@ -41,9 +45,12 @@
4145
@RequiredArgsConstructor
4246
@Log4j2
4347
public class MLIndicesHandler {
44-
48+
@NonNull
4549
ClusterService clusterService;
50+
@NonNull
4651
Client client;
52+
@NonNull
53+
MLFeatureEnabledSetting mlFeatureEnabledSetting;
4754
private static final Map<String, AtomicBoolean> indexMappingUpdated = new HashMap<>();
4855

4956
static {
@@ -52,6 +59,21 @@ public class MLIndicesHandler {
5259
}
5360
}
5461

62+
/**
63+
* Determines whether an index exists on non-multi tenancy enabled environments. Otherwise,
64+
* returns true when multiTenancy is Enabled
65+
*
66+
* @param clusterService the cluster service
67+
* @param isMultiTenancyEnabled whether multi-tenancy is enabled
68+
* @param indexName - the index to search
69+
* @return boolean indicating the existence of an index. Returns true if multitenancy is enabled.
70+
* @implNote This method assumes if your environment enables multi tenancy, then your plugin indices are
71+
* pre-populated. If this is incorrect, it will result in unwanted early returns without checking the clusterService.
72+
*/
73+
public static boolean doesMultiTenantIndexExist(ClusterService clusterService, boolean isMultiTenancyEnabled, String indexName) {
74+
return isMultiTenancyEnabled || clusterService.state().metadata().hasIndex(indexName);
75+
}
76+
5577
public void initModelGroupIndexIfAbsent(ActionListener<Boolean> listener) {
5678
initMLIndexIfAbsent(MLIndex.MODEL_GROUP, listener);
5779
}
@@ -105,7 +127,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener<Boolean> listener)
105127
String mapping = index.getMapping();
106128
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
107129
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
108-
if (!clusterService.state().metadata().hasIndex(indexName)) {
130+
if (!MLIndicesHandler.doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), indexName)) {
109131
ActionListener<CreateIndexResponse> actionListener = ActionListener.wrap(r -> {
110132
if (r.isAcknowledged()) {
111133
log.info("create index:{}", indexName);
@@ -220,4 +242,9 @@ public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListen
220242
listener.onResponse(newVersion > oldVersion);
221243
}
222244

245+
@VisibleForTesting
246+
public boolean doesIndexExists(String indexName) {
247+
return MLIndicesHandler.doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), indexName);
248+
}
249+
223250
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package org.opensearch.ml.engine.indices;
22

33
import static org.junit.Assert.assertEquals;
4+
import static org.junit.Assert.assertFalse;
5+
import static org.junit.Assert.assertTrue;
46
import static org.mockito.ArgumentMatchers.any;
57
import static org.mockito.ArgumentMatchers.anyString;
68
import static org.mockito.ArgumentMatchers.isA;
@@ -11,6 +13,7 @@
1113
import static org.mockito.Mockito.when;
1214
import static org.opensearch.ml.common.CommonValue.META;
1315
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
16+
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
1417
import static org.opensearch.ml.common.CommonValue.ML_JOBS_INDEX;
1518
import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX;
1619
import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX;
@@ -35,6 +38,7 @@
3538
import org.opensearch.common.settings.Settings;
3639
import org.opensearch.common.util.concurrent.ThreadContext;
3740
import org.opensearch.core.action.ActionListener;
41+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3842
import org.opensearch.threadpool.ThreadPool;
3943
import org.opensearch.transport.client.AdminClient;
4044
import org.opensearch.transport.client.Client;
@@ -74,6 +78,9 @@ public class MLIndicesHandlerTest {
7478
@Mock
7579
private ThreadPool threadPool;
7680

81+
@Mock
82+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
83+
7784
Settings settings;
7885
ThreadContext threadContext;
7986
MLIndicesHandler indicesHandler;
@@ -102,7 +109,29 @@ public void setUp() {
102109
threadContext = new ThreadContext(settings);
103110
when(client.threadPool()).thenReturn(threadPool);
104111
when(threadPool.getThreadContext()).thenReturn(threadContext);
105-
indicesHandler = new MLIndicesHandler(clusterService, client);
112+
indicesHandler = new MLIndicesHandler(clusterService, client, mlFeatureEnabledSetting);
113+
}
114+
115+
@Test
116+
public void doesMultiTenantIndexExist_multiTenancyEnabled_returnsTrue() {
117+
assertTrue(MLIndicesHandler.doesMultiTenantIndexExist(null, true, null));
118+
MLIndicesHandler mlIndicesHandler = new MLIndicesHandler(clusterService, client, mlFeatureEnabledSetting);
119+
assertTrue(mlIndicesHandler.doesIndexExists(ML_CONFIG_INDEX));
120+
}
121+
122+
@Test
123+
public void doesMultiTenantIndexExist_multiTenancyDisabledSearchesClusterService_returnsValidSearchResult() {
124+
assertFalse(MLIndicesHandler.doesMultiTenantIndexExist(clusterService, false, null));
125+
126+
String sampleIndexName = "test-index";
127+
when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false);
128+
MLIndicesHandler mlIndicesHandler = new MLIndicesHandler(clusterService, client, mlFeatureEnabledSetting);
129+
130+
when(clusterService.state().metadata().hasIndex(sampleIndexName)).thenReturn(true);
131+
assertTrue(mlIndicesHandler.doesIndexExists(sampleIndexName));
132+
133+
when(clusterService.state().metadata().hasIndex(sampleIndexName)).thenReturn(false);
134+
assertFalse(mlIndicesHandler.doesIndexExists(sampleIndexName));
106135
}
107136

108137
@Test

plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.opensearch.ml.engine.MLEngineClassLoader;
2727
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
2828
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
29+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
2930
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
3031
import org.opensearch.script.ScriptService;
3132
import org.opensearch.tasks.Task;
@@ -74,7 +75,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
7475
String connectorId = executeConnectorRequest.getConnectorId();
7576
String connectorAction = ConnectorAction.ActionType.EXECUTE.name();
7677

77-
if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
78+
if (MLIndicesHandler
79+
.doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_CONNECTOR_INDEX)) {
7880
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
7981
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
8082
// adding tenantID as null, because we are not implement multi-tenancy for this feature yet.

plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import org.opensearch.ml.common.connector.HttpConnector;
4040
import org.opensearch.ml.common.exception.MLException;
4141
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
42+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
43+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
4244
import org.opensearch.ml.helper.ModelAccessControlHelper;
4345
import org.opensearch.ml.utils.RestActionUtils;
4446
import org.opensearch.remote.metadata.client.SdkClient;
@@ -65,17 +67,20 @@ public class MLSearchHandler {
6567
private ModelAccessControlHelper modelAccessControlHelper;
6668

6769
private ClusterService clusterService;
70+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
6871

6972
public MLSearchHandler(
7073
Client client,
7174
NamedXContentRegistry xContentRegistry,
7275
ModelAccessControlHelper modelAccessControlHelper,
73-
ClusterService clusterService
76+
ClusterService clusterService,
77+
MLFeatureEnabledSetting mlFeatureEnabledSetting
7478
) {
7579
this.modelAccessControlHelper = modelAccessControlHelper;
7680
this.client = client;
7781
this.xContentRegistry = xContentRegistry;
7882
this.clusterService = clusterService;
83+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
7984
}
8085

8186
/**
@@ -132,7 +137,12 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId,
132137
final ActionListener<SearchResponse> doubleWrapperListener = ActionListener
133138
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));
134139
if (modelAccessControlHelper.skipModelAccessControl(user)
135-
|| !clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
140+
|| !MLIndicesHandler
141+
.doesMultiTenantIndexExist(
142+
clusterService,
143+
mlFeatureEnabledSetting.isMultiTenancyEnabled(),
144+
CommonValue.ML_MODEL_GROUP_INDEX
145+
)) {
136146

137147
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
138148
.builder()

plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
5656
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
5757
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
58+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
5859
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
5960
import org.opensearch.ml.helper.ModelAccessControlHelper;
6061
import org.opensearch.ml.model.MLModelManager;
@@ -199,7 +200,12 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
199200
if (model.getConnector() != null) {
200201
Connector connector = model.getConnector();
201202
executeConnector(connector, mlInput, actionListener);
202-
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
203+
} else if (MLIndicesHandler
204+
.doesMultiTenantIndexExist(
205+
clusterService,
206+
mlFeatureEnabledSetting.isMultiTenancyEnabled(),
207+
ML_CONNECTOR_INDEX
208+
)) {
203209
ActionListener<Connector> listener = ActionListener
204210
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
205211
log.error("Failed to get connector {}", model.getConnectorId(), e);

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
8383
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
8484
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
85+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
8586
import org.opensearch.ml.engine.utils.S3Utils;
8687
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
8788
import org.opensearch.ml.helper.ModelAccessControlHelper;
@@ -390,7 +391,12 @@ private void processRemoteBatchPrediction(
390391
remoteJob,
391392
actionListener
392393
);
393-
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
394+
} else if (MLIndicesHandler
395+
.doesMultiTenantIndexExist(
396+
clusterService,
397+
mlFeatureEnabledSetting.isMultiTenancyEnabled(),
398+
ML_CONNECTOR_INDEX
399+
)) {
394400
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
395401
executeConnector(
396402
connector,

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,10 @@ public Collection<Object> createComponents(
599599
Settings settings = environment.settings();
600600
Path dataPath = environment.dataFiles()[0];
601601

602-
mlIndicesHandler = new MLIndicesHandler(clusterService, client);
602+
mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);
603+
mlFeatureEnabledSetting.addListener(mlTaskManager);
604+
605+
mlIndicesHandler = new MLIndicesHandler(clusterService, client, mlFeatureEnabledSetting);
603606

604607
SdkClient sdkClient = SdkClientFactory
605608
.createSdkClient(
@@ -665,8 +668,7 @@ public Collection<Object> createComponents(
665668
mlInputDatasetHandler = new MLInputDatasetHandler(client);
666669
modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings);
667670
connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings);
668-
mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);
669-
mlFeatureEnabledSetting.addListener(mlTaskManager);
671+
670672
mlModelManager = new MLModelManager(
671673
clusterService,
672674
scriptService,
@@ -803,7 +805,13 @@ public Collection<Object> createComponents(
803805
MLToolExecutor toolExecutor = new MLToolExecutor(client, sdkClient, settings, clusterService, xContentRegistry, toolFactories);
804806
MLEngineClassLoader.register(FunctionName.TOOL, toolExecutor);
805807

806-
MLSearchHandler mlSearchHandler = new MLSearchHandler(client, xContentRegistry, modelAccessControlHelper, clusterService);
808+
MLSearchHandler mlSearchHandler = new MLSearchHandler(
809+
client,
810+
xContentRegistry,
811+
modelAccessControlHelper,
812+
clusterService,
813+
mlFeatureEnabledSetting
814+
);
807815
MLModelAutoReDeployer mlModelAutoRedeployer = new MLModelAutoReDeployer(
808816
clusterService,
809817
client,

plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase {
114114
public void setup() {
115115
MockitoAnnotations.openMocks(this);
116116
sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap());
117-
mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService));
117+
mlSearchHandler = spy(
118+
new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService, mlFeatureEnabledSetting)
119+
);
118120
searchModelTransportAction = new SearchModelTransportAction(
119121
transportService,
120122
actionFilters,

0 commit comments

Comments
 (0)